Add support for accessing LLM models on Azure Foundry

This commit is contained in:
Leonard Zhang 2026-03-01 16:06:15 -08:00
parent 5fec171a1e
commit 61ba3ca702
9 changed files with 179 additions and 26 deletions

View File

@ -4,3 +4,4 @@ GOOGLE_API_KEY=
ANTHROPIC_API_KEY= ANTHROPIC_API_KEY=
XAI_API_KEY= XAI_API_KEY=
OPENROUTER_API_KEY= OPENROUTER_API_KEY=
AZURE_FOUNDRY_API_KEY=

View File

@ -536,28 +536,31 @@ def get_user_selections():
) )
selected_research_depth = select_research_depth() selected_research_depth = select_research_depth()
# Step 5: OpenAI backend # Step 5: LLM Provider
console.print( console.print(
create_question_box( create_question_box(
"Step 5: OpenAI backend", "Select which service to talk to" "Step 5: LLM Provider", "Select which service to talk to"
) )
) )
selected_llm_provider, backend_url = select_llm_provider() selected_llm_provider, backend_url = select_llm_provider()
# Normalize provider key for dict lookups: "Azure Foundry" -> "azure_foundry"
provider_key = selected_llm_provider.lower().replace(" ", "_")
# Step 6: Thinking agents # Step 6: Thinking agents
console.print( console.print(
create_question_box( create_question_box(
"Step 6: Thinking Agents", "Select your thinking agents for analysis" "Step 6: Thinking Agents", "Select your thinking agents for analysis"
) )
) )
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) selected_shallow_thinker = select_shallow_thinking_agent(provider_key)
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) selected_deep_thinker = select_deep_thinking_agent(provider_key)
# Step 7: Provider-specific thinking configuration # Step 7: Provider-specific thinking configuration
thinking_level = None thinking_level = None
reasoning_effort = None reasoning_effort = None
provider_lower = selected_llm_provider.lower() provider_lower = provider_key
if provider_lower == "google": if provider_lower == "google":
console.print( console.print(
create_question_box( create_question_box(
@ -580,7 +583,7 @@ def get_user_selections():
"analysis_date": analysis_date, "analysis_date": analysis_date,
"analysts": selected_analysts, "analysts": selected_analysts,
"research_depth": selected_research_depth, "research_depth": selected_research_depth,
"llm_provider": selected_llm_provider.lower(), "llm_provider": provider_key,
"backend_url": backend_url, "backend_url": backend_url,
"shallow_thinker": selected_shallow_thinker, "shallow_thinker": selected_shallow_thinker,
"deep_thinker": selected_deep_thinker, "deep_thinker": selected_deep_thinker,
@ -623,19 +626,19 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
analyst_parts = [] analyst_parts = []
if final_state.get("market_report"): if final_state.get("market_report"):
analysts_dir.mkdir(exist_ok=True) analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "market.md").write_text(final_state["market_report"]) (analysts_dir / "market.md").write_text(final_state["market_report"], encoding="utf-8")
analyst_parts.append(("Market Analyst", final_state["market_report"])) analyst_parts.append(("Market Analyst", final_state["market_report"]))
if final_state.get("sentiment_report"): if final_state.get("sentiment_report"):
analysts_dir.mkdir(exist_ok=True) analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"]) (analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"], encoding="utf-8")
analyst_parts.append(("Social Analyst", final_state["sentiment_report"])) analyst_parts.append(("Social Analyst", final_state["sentiment_report"]))
if final_state.get("news_report"): if final_state.get("news_report"):
analysts_dir.mkdir(exist_ok=True) analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "news.md").write_text(final_state["news_report"]) (analysts_dir / "news.md").write_text(final_state["news_report"], encoding="utf-8")
analyst_parts.append(("News Analyst", final_state["news_report"])) analyst_parts.append(("News Analyst", final_state["news_report"]))
if final_state.get("fundamentals_report"): if final_state.get("fundamentals_report"):
analysts_dir.mkdir(exist_ok=True) analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"]) (analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"], encoding="utf-8")
analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"])) analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"]))
if analyst_parts: if analyst_parts:
content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts) content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts)
@ -648,15 +651,15 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
research_parts = [] research_parts = []
if debate.get("bull_history"): if debate.get("bull_history"):
research_dir.mkdir(exist_ok=True) research_dir.mkdir(exist_ok=True)
(research_dir / "bull.md").write_text(debate["bull_history"]) (research_dir / "bull.md").write_text(debate["bull_history"], encoding="utf-8")
research_parts.append(("Bull Researcher", debate["bull_history"])) research_parts.append(("Bull Researcher", debate["bull_history"]))
if debate.get("bear_history"): if debate.get("bear_history"):
research_dir.mkdir(exist_ok=True) research_dir.mkdir(exist_ok=True)
(research_dir / "bear.md").write_text(debate["bear_history"]) (research_dir / "bear.md").write_text(debate["bear_history"], encoding="utf-8")
research_parts.append(("Bear Researcher", debate["bear_history"])) research_parts.append(("Bear Researcher", debate["bear_history"]))
if debate.get("judge_decision"): if debate.get("judge_decision"):
research_dir.mkdir(exist_ok=True) research_dir.mkdir(exist_ok=True)
(research_dir / "manager.md").write_text(debate["judge_decision"]) (research_dir / "manager.md").write_text(debate["judge_decision"], encoding="utf-8")
research_parts.append(("Research Manager", debate["judge_decision"])) research_parts.append(("Research Manager", debate["judge_decision"]))
if research_parts: if research_parts:
content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts) content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts)
@ -666,7 +669,7 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
if final_state.get("trader_investment_plan"): if final_state.get("trader_investment_plan"):
trading_dir = save_path / "3_trading" trading_dir = save_path / "3_trading"
trading_dir.mkdir(exist_ok=True) trading_dir.mkdir(exist_ok=True)
(trading_dir / "trader.md").write_text(final_state["trader_investment_plan"]) (trading_dir / "trader.md").write_text(final_state["trader_investment_plan"], encoding="utf-8")
sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}") sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}")
# 4. Risk Management # 4. Risk Management
@ -676,15 +679,15 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
risk_parts = [] risk_parts = []
if risk.get("aggressive_history"): if risk.get("aggressive_history"):
risk_dir.mkdir(exist_ok=True) risk_dir.mkdir(exist_ok=True)
(risk_dir / "aggressive.md").write_text(risk["aggressive_history"]) (risk_dir / "aggressive.md").write_text(risk["aggressive_history"], encoding="utf-8")
risk_parts.append(("Aggressive Analyst", risk["aggressive_history"])) risk_parts.append(("Aggressive Analyst", risk["aggressive_history"]))
if risk.get("conservative_history"): if risk.get("conservative_history"):
risk_dir.mkdir(exist_ok=True) risk_dir.mkdir(exist_ok=True)
(risk_dir / "conservative.md").write_text(risk["conservative_history"]) (risk_dir / "conservative.md").write_text(risk["conservative_history"], encoding="utf-8")
risk_parts.append(("Conservative Analyst", risk["conservative_history"])) risk_parts.append(("Conservative Analyst", risk["conservative_history"]))
if risk.get("neutral_history"): if risk.get("neutral_history"):
risk_dir.mkdir(exist_ok=True) risk_dir.mkdir(exist_ok=True)
(risk_dir / "neutral.md").write_text(risk["neutral_history"]) (risk_dir / "neutral.md").write_text(risk["neutral_history"], encoding="utf-8")
risk_parts.append(("Neutral Analyst", risk["neutral_history"])) risk_parts.append(("Neutral Analyst", risk["neutral_history"]))
if risk_parts: if risk_parts:
content = "\n\n".join(f"### {name}\n{text}" for name, text in risk_parts) content = "\n\n".join(f"### {name}\n{text}" for name, text in risk_parts)
@ -694,12 +697,12 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
if risk.get("judge_decision"): if risk.get("judge_decision"):
portfolio_dir = save_path / "5_portfolio" portfolio_dir = save_path / "5_portfolio"
portfolio_dir.mkdir(exist_ok=True) portfolio_dir.mkdir(exist_ok=True)
(portfolio_dir / "decision.md").write_text(risk["judge_decision"]) (portfolio_dir / "decision.md").write_text(risk["judge_decision"], encoding="utf-8")
sections.append(f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}") sections.append(f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}")
# Write consolidated report # Write consolidated report
header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
(save_path / "complete_report.md").write_text(header + "\n\n".join(sections)) (save_path / "complete_report.md").write_text(header + "\n\n".join(sections), encoding="utf-8")
return save_path / "complete_report.md" return save_path / "complete_report.md"
@ -907,7 +910,12 @@ def run_analysis():
config["quick_think_llm"] = selections["shallow_thinker"] config["quick_think_llm"] = selections["shallow_thinker"]
config["deep_think_llm"] = selections["deep_thinker"] config["deep_think_llm"] = selections["deep_thinker"]
config["backend_url"] = selections["backend_url"] config["backend_url"] = selections["backend_url"]
config["llm_provider"] = selections["llm_provider"].lower() config["llm_provider"] = selections["llm_provider"]
# Azure Foundry: use endpoint from config/env, not backend_url
if selections["llm_provider"] == "azure_foundry":
# Keep the default azure_foundry_endpoint from config (or env var)
# Don't overwrite it with the empty backend_url from the provider list
pass
# Provider-specific thinking configuration # Provider-specific thinking configuration
config["google_thinking_level"] = selections.get("google_thinking_level") config["google_thinking_level"] = selections.get("google_thinking_level")
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
@ -948,7 +956,7 @@ def run_analysis():
func(*args, **kwargs) func(*args, **kwargs)
timestamp, message_type, content = obj.messages[-1] timestamp, message_type, content = obj.messages[-1]
content = content.replace("\n", " ") # Replace newlines with spaces content = content.replace("\n", " ") # Replace newlines with spaces
with open(log_file, "a") as f: with open(log_file, "a", encoding="utf-8") as f:
f.write(f"{timestamp} [{message_type}] {content}\n") f.write(f"{timestamp} [{message_type}] {content}\n")
return wrapper return wrapper
@ -959,7 +967,7 @@ def run_analysis():
func(*args, **kwargs) func(*args, **kwargs)
timestamp, tool_name, args = obj.tool_calls[-1] timestamp, tool_name, args = obj.tool_calls[-1]
args_str = ", ".join(f"{k}={v}" for k, v in args.items()) args_str = ", ".join(f"{k}={v}" for k, v in args.items())
with open(log_file, "a") as f: with open(log_file, "a", encoding="utf-8") as f:
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n") f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
return wrapper return wrapper
@ -972,7 +980,7 @@ def run_analysis():
content = obj.report_sections[section_name] content = obj.report_sections[section_name]
if content: if content:
file_name = f"{section_name}.md" file_name = f"{section_name}.md"
with open(report_dir / file_name, "w") as f: with open(report_dir / file_name, "w", encoding="utf-8") as f:
f.write(content) f.write(content)
return wrapper return wrapper

View File

@ -160,6 +160,20 @@ def select_shallow_thinking_agent(provider) -> str:
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
], ],
"azure_foundry": [
("GPT-5 Mini - Cost-optimized reasoning", "gpt-5-mini"),
("GPT-5 Nano - Ultra-fast, high-throughput", "gpt-5-nano"),
("GPT-5.2 - Latest flagship", "gpt-5.2"),
("GPT-5.1 - Flexible reasoning", "gpt-5.1"),
("GPT-4.1 - Smartest non-reasoning, 1M context", "gpt-4.1"),
("Claude Haiku 4.5 - Fast + extended thinking", "claude-haiku-4-5"),
("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"),
("Claude Sonnet 4 - High-performance", "claude-sonnet-4-20250514"),
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"),
("Gemini 3 Pro - Reasoning-first", "gemini-3-pro-preview"),
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
],
} }
choice = questionary.select( choice = questionary.select(
@ -228,6 +242,22 @@ def select_deep_thinking_agent(provider) -> str:
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("Qwen3:latest (8B, local)", "qwen3:latest"), ("Qwen3:latest (8B, local)", "qwen3:latest"),
], ],
"azure_foundry": [
("GPT-5.2 - Latest flagship", "gpt-5.2"),
("GPT-5.1 - Flexible reasoning", "gpt-5.1"),
("GPT-5 - Advanced reasoning", "gpt-5"),
("GPT-4.1 - Smartest non-reasoning, 1M context", "gpt-4.1"),
("GPT-5 Mini - Cost-optimized reasoning", "gpt-5-mini"),
("GPT-5 Nano - Ultra-fast, high-throughput", "gpt-5-nano"),
("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"),
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
("Claude Opus 4.1 - Most capable model", "claude-opus-4-1-20250805"),
("Claude Haiku 4.5 - Fast + extended thinking", "claude-haiku-4-5"),
("Claude Sonnet 4 - High-performance", "claude-sonnet-4-20250514"),
("Gemini 3 Pro - Reasoning-first", "gemini-3-pro-preview"),
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"),
],
} }
choice = questionary.select( choice = questionary.select(
@ -260,6 +290,7 @@ def select_llm_provider() -> tuple[str, str]:
("Google", "https://generativelanguage.googleapis.com/v1"), ("Google", "https://generativelanguage.googleapis.com/v1"),
("Anthropic", "https://api.anthropic.com/"), ("Anthropic", "https://api.anthropic.com/"),
("xAI", "https://api.x.ai/v1"), ("xAI", "https://api.x.ai/v1"),
("Azure Foundry", ""), # Endpoint read from config/env var
("Openrouter", "https://openrouter.ai/api/v1"), ("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"), ("Ollama", "http://localhost:11434/v1"),
] ]

View File

@ -11,7 +11,7 @@ def create_risk_manager(llm, memory):
risk_debate_state = state["risk_debate_state"] risk_debate_state = state["risk_debate_state"]
market_research_report = state["market_report"] market_research_report = state["market_report"]
news_report = state["news_report"] news_report = state["news_report"]
fundamentals_report = state["news_report"] fundamentals_report = state["fundamentals_report"]
sentiment_report = state["sentiment_report"] sentiment_report = state["sentiment_report"]
trader_plan = state["investment_plan"] trader_plan = state["investment_plan"]

View File

@ -8,6 +8,7 @@ DEFAULT_CONFIG = {
"dataflows/data_cache", "dataflows/data_cache",
), ),
# LLM settings # LLM settings
# Supported providers: openai, anthropic, google, xai, ollama, openrouter, azure_foundry
"llm_provider": "openai", "llm_provider": "openai",
"deep_think_llm": "gpt-5.2", "deep_think_llm": "gpt-5.2",
"quick_think_llm": "gpt-5-mini", "quick_think_llm": "gpt-5-mini",
@ -15,7 +16,9 @@ DEFAULT_CONFIG = {
# Provider-specific thinking configuration # Provider-specific thinking configuration
"google_thinking_level": None, # "high", "minimal", etc. "google_thinking_level": None, # "high", "minimal", etc.
"openai_reasoning_effort": None, # "medium", "high", "low" "openai_reasoning_effort": None, # "medium", "high", "low"
# Debate and discussion settings # Azure Foundry settings (set env var AZURE_FOUNDRY_API_KEY in .env,
# and AZURE_FOUNDRY_ENDPOINT here or as env var)
"azure_foundry_endpoint": "https://<resource>.openai.azure.com/openai/v1/",
"max_debate_rounds": 1, "max_debate_rounds": 1,
"max_risk_discuss_rounds": 1, "max_risk_discuss_rounds": 1,
"max_recur_limit": 100, "max_recur_limit": 100,

View File

@ -145,6 +145,14 @@ class TradingAgentsGraph:
if reasoning_effort: if reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort kwargs["reasoning_effort"] = reasoning_effort
elif provider == "azure_foundry":
endpoint = self.config.get("azure_foundry_endpoint")
if endpoint:
kwargs["azure_foundry_endpoint"] = endpoint
api_key = self.config.get("azure_foundry_api_key")
if api_key:
kwargs["api_key"] = api_key
return kwargs return kwargs
def _create_tool_nodes(self) -> Dict[str, ToolNode]: def _create_tool_nodes(self) -> Dict[str, ToolNode]:

View File

@ -0,0 +1,73 @@
import os
from typing import Any, Optional
from langchain_openai import ChatOpenAI
from .base_client import BaseLLMClient
from .validators import validate_model
class AzureFoundryClient(BaseLLMClient):
    """Client for models hosted on Azure AI Foundry.

    Azure AI Foundry exposes an OpenAI-compatible chat completions endpoint,
    so this client wraps ChatOpenAI with a custom base_url pointing at the
    Foundry deployment.

    Required environment variables (unless passed explicitly):
        AZURE_FOUNDRY_ENDPOINT: Your Azure Foundry inference endpoint URI
            e.g. https://<resource>.services.ai.azure.com/models
                 https://<endpoint>.<region>.models.ai.azure.com/v1
        AZURE_FOUNDRY_API_KEY: Your Azure Foundry API key
    """

    def __init__(
        self,
        model: str,
        base_url: Optional[str] = None,
        **kwargs,
    ):
        # No Foundry-specific construction work; the base class stores the
        # model name, base_url, and extra kwargs for get_llm() to consume.
        super().__init__(model, base_url, **kwargs)

    def get_llm(self) -> Any:
        """Return a ChatOpenAI instance configured for Azure Foundry.

        Raises:
            ValueError: if no endpoint or no API key can be resolved.
        """
        # Endpoint precedence: explicit kwarg, then base_url, then env var.
        endpoint = self.kwargs.get("azure_foundry_endpoint")
        if not endpoint:
            endpoint = self.base_url
        if not endpoint:
            endpoint = os.environ.get("AZURE_FOUNDRY_ENDPOINT")
        if not endpoint:
            raise ValueError(
                "Azure Foundry endpoint is required. Set the AZURE_FOUNDRY_ENDPOINT "
                "environment variable, pass 'backend_url' in the config, or provide "
                "'azure_foundry_endpoint' in kwargs."
            )

        # API key precedence: explicit kwarg, then env var.
        api_key = self.kwargs.get("api_key")
        if not api_key:
            api_key = os.environ.get("AZURE_FOUNDRY_API_KEY")
        if not api_key:
            raise ValueError(
                "Azure Foundry API key is required. Set the AZURE_FOUNDRY_API_KEY "
                "environment variable or pass 'api_key' in kwargs."
            )

        # Forward only the optional tuning params the caller actually supplied.
        optional = {
            name: self.kwargs[name]
            for name in ("timeout", "max_retries", "temperature", "max_tokens", "callbacks")
            if name in self.kwargs
        }
        return ChatOpenAI(
            model=self.model,
            base_url=endpoint,
            api_key=api_key,
            **optional,
        )

    def validate_model(self) -> bool:
        """Validate model for Azure Foundry."""
        return validate_model("azure_foundry", self.model)

View File

@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
from .openai_client import OpenAIClient from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient from .anthropic_client import AnthropicClient
from .google_client import GoogleClient from .google_client import GoogleClient
from .azure_foundry_client import AzureFoundryClient
def create_llm_client( def create_llm_client(
@ -40,4 +41,7 @@ def create_llm_client(
if provider_lower == "google": if provider_lower == "google":
return GoogleClient(model, base_url, **kwargs) return GoogleClient(model, base_url, **kwargs)
if provider_lower == "azure_foundry":
return AzureFoundryClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}") raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@ -63,6 +63,31 @@ VALID_MODELS = {
"grok-4-fast-reasoning", "grok-4-fast-reasoning",
"grok-4-fast-non-reasoning", "grok-4-fast-non-reasoning",
], ],
# Azure Foundry can host any model from the catalog;
# list common ones here but any model name is accepted.
"azure_foundry": [
# OpenAI models on Azure
"gpt-4o",
"gpt-4o-mini",
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
"o4-mini",
"o3",
"o3-mini",
# Meta Llama models
"Meta-Llama-3.3-70B-Instruct",
"Meta-Llama-3.1-405B-Instruct",
# Mistral models
"Mistral-Large-2",
"Mistral-Small",
# Cohere models
"Cohere-command-r-plus",
"Cohere-command-r",
# DeepSeek models
"DeepSeek-R1",
"DeepSeek-V3",
],
} }
@ -73,7 +98,7 @@ def validate_model(provider: str, model: str) -> bool:
""" """
provider_lower = provider.lower() provider_lower = provider.lower()
if provider_lower in ("ollama", "openrouter"): if provider_lower in ("ollama", "openrouter", "azure_foundry"):
return True return True
if provider_lower not in VALID_MODELS: if provider_lower not in VALID_MODELS: