Add support for accessing LLM models on Azure Foundry
This commit is contained in:
parent
5fec171a1e
commit
61ba3ca702
|
|
@ -4,3 +4,4 @@ GOOGLE_API_KEY=
|
||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
XAI_API_KEY=
|
XAI_API_KEY=
|
||||||
OPENROUTER_API_KEY=
|
OPENROUTER_API_KEY=
|
||||||
|
AZURE_FOUNDRY_API_KEY=
|
||||||
|
|
|
||||||
54
cli/main.py
54
cli/main.py
|
|
@ -536,28 +536,31 @@ def get_user_selections():
|
||||||
)
|
)
|
||||||
selected_research_depth = select_research_depth()
|
selected_research_depth = select_research_depth()
|
||||||
|
|
||||||
# Step 5: OpenAI backend
|
# Step 5: LLM Provider
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 5: OpenAI backend", "Select which service to talk to"
|
"Step 5: LLM Provider", "Select which service to talk to"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
selected_llm_provider, backend_url = select_llm_provider()
|
selected_llm_provider, backend_url = select_llm_provider()
|
||||||
|
|
||||||
|
# Normalize provider key for dict lookups: "Azure Foundry" -> "azure_foundry"
|
||||||
|
provider_key = selected_llm_provider.lower().replace(" ", "_")
|
||||||
|
|
||||||
# Step 6: Thinking agents
|
# Step 6: Thinking agents
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
"Step 6: Thinking Agents", "Select your thinking agents for analysis"
|
"Step 6: Thinking Agents", "Select your thinking agents for analysis"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
selected_shallow_thinker = select_shallow_thinking_agent(provider_key)
|
||||||
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
|
selected_deep_thinker = select_deep_thinking_agent(provider_key)
|
||||||
|
|
||||||
# Step 7: Provider-specific thinking configuration
|
# Step 7: Provider-specific thinking configuration
|
||||||
thinking_level = None
|
thinking_level = None
|
||||||
reasoning_effort = None
|
reasoning_effort = None
|
||||||
|
|
||||||
provider_lower = selected_llm_provider.lower()
|
provider_lower = provider_key
|
||||||
if provider_lower == "google":
|
if provider_lower == "google":
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
|
|
@ -580,7 +583,7 @@ def get_user_selections():
|
||||||
"analysis_date": analysis_date,
|
"analysis_date": analysis_date,
|
||||||
"analysts": selected_analysts,
|
"analysts": selected_analysts,
|
||||||
"research_depth": selected_research_depth,
|
"research_depth": selected_research_depth,
|
||||||
"llm_provider": selected_llm_provider.lower(),
|
"llm_provider": provider_key,
|
||||||
"backend_url": backend_url,
|
"backend_url": backend_url,
|
||||||
"shallow_thinker": selected_shallow_thinker,
|
"shallow_thinker": selected_shallow_thinker,
|
||||||
"deep_thinker": selected_deep_thinker,
|
"deep_thinker": selected_deep_thinker,
|
||||||
|
|
@ -623,19 +626,19 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
||||||
analyst_parts = []
|
analyst_parts = []
|
||||||
if final_state.get("market_report"):
|
if final_state.get("market_report"):
|
||||||
analysts_dir.mkdir(exist_ok=True)
|
analysts_dir.mkdir(exist_ok=True)
|
||||||
(analysts_dir / "market.md").write_text(final_state["market_report"])
|
(analysts_dir / "market.md").write_text(final_state["market_report"], encoding="utf-8")
|
||||||
analyst_parts.append(("Market Analyst", final_state["market_report"]))
|
analyst_parts.append(("Market Analyst", final_state["market_report"]))
|
||||||
if final_state.get("sentiment_report"):
|
if final_state.get("sentiment_report"):
|
||||||
analysts_dir.mkdir(exist_ok=True)
|
analysts_dir.mkdir(exist_ok=True)
|
||||||
(analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"])
|
(analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"], encoding="utf-8")
|
||||||
analyst_parts.append(("Social Analyst", final_state["sentiment_report"]))
|
analyst_parts.append(("Social Analyst", final_state["sentiment_report"]))
|
||||||
if final_state.get("news_report"):
|
if final_state.get("news_report"):
|
||||||
analysts_dir.mkdir(exist_ok=True)
|
analysts_dir.mkdir(exist_ok=True)
|
||||||
(analysts_dir / "news.md").write_text(final_state["news_report"])
|
(analysts_dir / "news.md").write_text(final_state["news_report"], encoding="utf-8")
|
||||||
analyst_parts.append(("News Analyst", final_state["news_report"]))
|
analyst_parts.append(("News Analyst", final_state["news_report"]))
|
||||||
if final_state.get("fundamentals_report"):
|
if final_state.get("fundamentals_report"):
|
||||||
analysts_dir.mkdir(exist_ok=True)
|
analysts_dir.mkdir(exist_ok=True)
|
||||||
(analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"])
|
(analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"], encoding="utf-8")
|
||||||
analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"]))
|
analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"]))
|
||||||
if analyst_parts:
|
if analyst_parts:
|
||||||
content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts)
|
content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts)
|
||||||
|
|
@ -648,15 +651,15 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
||||||
research_parts = []
|
research_parts = []
|
||||||
if debate.get("bull_history"):
|
if debate.get("bull_history"):
|
||||||
research_dir.mkdir(exist_ok=True)
|
research_dir.mkdir(exist_ok=True)
|
||||||
(research_dir / "bull.md").write_text(debate["bull_history"])
|
(research_dir / "bull.md").write_text(debate["bull_history"], encoding="utf-8")
|
||||||
research_parts.append(("Bull Researcher", debate["bull_history"]))
|
research_parts.append(("Bull Researcher", debate["bull_history"]))
|
||||||
if debate.get("bear_history"):
|
if debate.get("bear_history"):
|
||||||
research_dir.mkdir(exist_ok=True)
|
research_dir.mkdir(exist_ok=True)
|
||||||
(research_dir / "bear.md").write_text(debate["bear_history"])
|
(research_dir / "bear.md").write_text(debate["bear_history"], encoding="utf-8")
|
||||||
research_parts.append(("Bear Researcher", debate["bear_history"]))
|
research_parts.append(("Bear Researcher", debate["bear_history"]))
|
||||||
if debate.get("judge_decision"):
|
if debate.get("judge_decision"):
|
||||||
research_dir.mkdir(exist_ok=True)
|
research_dir.mkdir(exist_ok=True)
|
||||||
(research_dir / "manager.md").write_text(debate["judge_decision"])
|
(research_dir / "manager.md").write_text(debate["judge_decision"], encoding="utf-8")
|
||||||
research_parts.append(("Research Manager", debate["judge_decision"]))
|
research_parts.append(("Research Manager", debate["judge_decision"]))
|
||||||
if research_parts:
|
if research_parts:
|
||||||
content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts)
|
content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts)
|
||||||
|
|
@ -666,7 +669,7 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
||||||
if final_state.get("trader_investment_plan"):
|
if final_state.get("trader_investment_plan"):
|
||||||
trading_dir = save_path / "3_trading"
|
trading_dir = save_path / "3_trading"
|
||||||
trading_dir.mkdir(exist_ok=True)
|
trading_dir.mkdir(exist_ok=True)
|
||||||
(trading_dir / "trader.md").write_text(final_state["trader_investment_plan"])
|
(trading_dir / "trader.md").write_text(final_state["trader_investment_plan"], encoding="utf-8")
|
||||||
sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}")
|
sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}")
|
||||||
|
|
||||||
# 4. Risk Management
|
# 4. Risk Management
|
||||||
|
|
@ -676,15 +679,15 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
||||||
risk_parts = []
|
risk_parts = []
|
||||||
if risk.get("aggressive_history"):
|
if risk.get("aggressive_history"):
|
||||||
risk_dir.mkdir(exist_ok=True)
|
risk_dir.mkdir(exist_ok=True)
|
||||||
(risk_dir / "aggressive.md").write_text(risk["aggressive_history"])
|
(risk_dir / "aggressive.md").write_text(risk["aggressive_history"], encoding="utf-8")
|
||||||
risk_parts.append(("Aggressive Analyst", risk["aggressive_history"]))
|
risk_parts.append(("Aggressive Analyst", risk["aggressive_history"]))
|
||||||
if risk.get("conservative_history"):
|
if risk.get("conservative_history"):
|
||||||
risk_dir.mkdir(exist_ok=True)
|
risk_dir.mkdir(exist_ok=True)
|
||||||
(risk_dir / "conservative.md").write_text(risk["conservative_history"])
|
(risk_dir / "conservative.md").write_text(risk["conservative_history"], encoding="utf-8")
|
||||||
risk_parts.append(("Conservative Analyst", risk["conservative_history"]))
|
risk_parts.append(("Conservative Analyst", risk["conservative_history"]))
|
||||||
if risk.get("neutral_history"):
|
if risk.get("neutral_history"):
|
||||||
risk_dir.mkdir(exist_ok=True)
|
risk_dir.mkdir(exist_ok=True)
|
||||||
(risk_dir / "neutral.md").write_text(risk["neutral_history"])
|
(risk_dir / "neutral.md").write_text(risk["neutral_history"], encoding="utf-8")
|
||||||
risk_parts.append(("Neutral Analyst", risk["neutral_history"]))
|
risk_parts.append(("Neutral Analyst", risk["neutral_history"]))
|
||||||
if risk_parts:
|
if risk_parts:
|
||||||
content = "\n\n".join(f"### {name}\n{text}" for name, text in risk_parts)
|
content = "\n\n".join(f"### {name}\n{text}" for name, text in risk_parts)
|
||||||
|
|
@ -694,12 +697,12 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
||||||
if risk.get("judge_decision"):
|
if risk.get("judge_decision"):
|
||||||
portfolio_dir = save_path / "5_portfolio"
|
portfolio_dir = save_path / "5_portfolio"
|
||||||
portfolio_dir.mkdir(exist_ok=True)
|
portfolio_dir.mkdir(exist_ok=True)
|
||||||
(portfolio_dir / "decision.md").write_text(risk["judge_decision"])
|
(portfolio_dir / "decision.md").write_text(risk["judge_decision"], encoding="utf-8")
|
||||||
sections.append(f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}")
|
sections.append(f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}")
|
||||||
|
|
||||||
# Write consolidated report
|
# Write consolidated report
|
||||||
header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
(save_path / "complete_report.md").write_text(header + "\n\n".join(sections))
|
(save_path / "complete_report.md").write_text(header + "\n\n".join(sections), encoding="utf-8")
|
||||||
return save_path / "complete_report.md"
|
return save_path / "complete_report.md"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -907,7 +910,12 @@ def run_analysis():
|
||||||
config["quick_think_llm"] = selections["shallow_thinker"]
|
config["quick_think_llm"] = selections["shallow_thinker"]
|
||||||
config["deep_think_llm"] = selections["deep_thinker"]
|
config["deep_think_llm"] = selections["deep_thinker"]
|
||||||
config["backend_url"] = selections["backend_url"]
|
config["backend_url"] = selections["backend_url"]
|
||||||
config["llm_provider"] = selections["llm_provider"].lower()
|
config["llm_provider"] = selections["llm_provider"]
|
||||||
|
# Azure Foundry: use endpoint from config/env, not backend_url
|
||||||
|
if selections["llm_provider"] == "azure_foundry":
|
||||||
|
# Keep the default azure_foundry_endpoint from config (or env var)
|
||||||
|
# Don't overwrite it with the empty backend_url from the provider list
|
||||||
|
pass
|
||||||
# Provider-specific thinking configuration
|
# Provider-specific thinking configuration
|
||||||
config["google_thinking_level"] = selections.get("google_thinking_level")
|
config["google_thinking_level"] = selections.get("google_thinking_level")
|
||||||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||||
|
|
@ -948,7 +956,7 @@ def run_analysis():
|
||||||
func(*args, **kwargs)
|
func(*args, **kwargs)
|
||||||
timestamp, message_type, content = obj.messages[-1]
|
timestamp, message_type, content = obj.messages[-1]
|
||||||
content = content.replace("\n", " ") # Replace newlines with spaces
|
content = content.replace("\n", " ") # Replace newlines with spaces
|
||||||
with open(log_file, "a") as f:
|
with open(log_file, "a", encoding="utf-8") as f:
|
||||||
f.write(f"{timestamp} [{message_type}] {content}\n")
|
f.write(f"{timestamp} [{message_type}] {content}\n")
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
@ -959,7 +967,7 @@ def run_analysis():
|
||||||
func(*args, **kwargs)
|
func(*args, **kwargs)
|
||||||
timestamp, tool_name, args = obj.tool_calls[-1]
|
timestamp, tool_name, args = obj.tool_calls[-1]
|
||||||
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
|
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
|
||||||
with open(log_file, "a") as f:
|
with open(log_file, "a", encoding="utf-8") as f:
|
||||||
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
|
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
@ -972,7 +980,7 @@ def run_analysis():
|
||||||
content = obj.report_sections[section_name]
|
content = obj.report_sections[section_name]
|
||||||
if content:
|
if content:
|
||||||
file_name = f"{section_name}.md"
|
file_name = f"{section_name}.md"
|
||||||
with open(report_dir / file_name, "w") as f:
|
with open(report_dir / file_name, "w", encoding="utf-8") as f:
|
||||||
f.write(content)
|
f.write(content)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
|
||||||
31
cli/utils.py
31
cli/utils.py
|
|
@ -160,6 +160,20 @@ def select_shallow_thinking_agent(provider) -> str:
|
||||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||||
],
|
],
|
||||||
|
"azure_foundry": [
|
||||||
|
("GPT-5 Mini - Cost-optimized reasoning", "gpt-5-mini"),
|
||||||
|
("GPT-5 Nano - Ultra-fast, high-throughput", "gpt-5-nano"),
|
||||||
|
("GPT-5.2 - Latest flagship", "gpt-5.2"),
|
||||||
|
("GPT-5.1 - Flexible reasoning", "gpt-5.1"),
|
||||||
|
("GPT-4.1 - Smartest non-reasoning, 1M context", "gpt-4.1"),
|
||||||
|
("Claude Haiku 4.5 - Fast + extended thinking", "claude-haiku-4-5"),
|
||||||
|
("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"),
|
||||||
|
("Claude Sonnet 4 - High-performance", "claude-sonnet-4-20250514"),
|
||||||
|
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||||
|
("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"),
|
||||||
|
("Gemini 3 Pro - Reasoning-first", "gemini-3-pro-preview"),
|
||||||
|
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
|
|
@ -228,6 +242,22 @@ def select_deep_thinking_agent(provider) -> str:
|
||||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||||
],
|
],
|
||||||
|
"azure_foundry": [
|
||||||
|
("GPT-5.2 - Latest flagship", "gpt-5.2"),
|
||||||
|
("GPT-5.1 - Flexible reasoning", "gpt-5.1"),
|
||||||
|
("GPT-5 - Advanced reasoning", "gpt-5"),
|
||||||
|
("GPT-4.1 - Smartest non-reasoning, 1M context", "gpt-4.1"),
|
||||||
|
("GPT-5 Mini - Cost-optimized reasoning", "gpt-5-mini"),
|
||||||
|
("GPT-5 Nano - Ultra-fast, high-throughput", "gpt-5-nano"),
|
||||||
|
("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"),
|
||||||
|
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
|
||||||
|
("Claude Opus 4.1 - Most capable model", "claude-opus-4-1-20250805"),
|
||||||
|
("Claude Haiku 4.5 - Fast + extended thinking", "claude-haiku-4-5"),
|
||||||
|
("Claude Sonnet 4 - High-performance", "claude-sonnet-4-20250514"),
|
||||||
|
("Gemini 3 Pro - Reasoning-first", "gemini-3-pro-preview"),
|
||||||
|
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||||
|
("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"),
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
|
|
@ -260,6 +290,7 @@ def select_llm_provider() -> tuple[str, str]:
|
||||||
("Google", "https://generativelanguage.googleapis.com/v1"),
|
("Google", "https://generativelanguage.googleapis.com/v1"),
|
||||||
("Anthropic", "https://api.anthropic.com/"),
|
("Anthropic", "https://api.anthropic.com/"),
|
||||||
("xAI", "https://api.x.ai/v1"),
|
("xAI", "https://api.x.ai/v1"),
|
||||||
|
("Azure Foundry", ""), # Endpoint read from config/env var
|
||||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||||
("Ollama", "http://localhost:11434/v1"),
|
("Ollama", "http://localhost:11434/v1"),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ def create_risk_manager(llm, memory):
|
||||||
risk_debate_state = state["risk_debate_state"]
|
risk_debate_state = state["risk_debate_state"]
|
||||||
market_research_report = state["market_report"]
|
market_research_report = state["market_report"]
|
||||||
news_report = state["news_report"]
|
news_report = state["news_report"]
|
||||||
fundamentals_report = state["news_report"]
|
fundamentals_report = state["fundamentals_report"]
|
||||||
sentiment_report = state["sentiment_report"]
|
sentiment_report = state["sentiment_report"]
|
||||||
trader_plan = state["investment_plan"]
|
trader_plan = state["investment_plan"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ DEFAULT_CONFIG = {
|
||||||
"dataflows/data_cache",
|
"dataflows/data_cache",
|
||||||
),
|
),
|
||||||
# LLM settings
|
# LLM settings
|
||||||
|
# Supported providers: openai, anthropic, google, xai, ollama, openrouter, azure_foundry
|
||||||
"llm_provider": "openai",
|
"llm_provider": "openai",
|
||||||
"deep_think_llm": "gpt-5.2",
|
"deep_think_llm": "gpt-5.2",
|
||||||
"quick_think_llm": "gpt-5-mini",
|
"quick_think_llm": "gpt-5-mini",
|
||||||
|
|
@ -15,7 +16,9 @@ DEFAULT_CONFIG = {
|
||||||
# Provider-specific thinking configuration
|
# Provider-specific thinking configuration
|
||||||
"google_thinking_level": None, # "high", "minimal", etc.
|
"google_thinking_level": None, # "high", "minimal", etc.
|
||||||
"openai_reasoning_effort": None, # "medium", "high", "low"
|
"openai_reasoning_effort": None, # "medium", "high", "low"
|
||||||
# Debate and discussion settings
|
# Azure Foundry settings (set env var AZURE_FOUNDRY_API_KEY in .env,
|
||||||
|
# and put your endpoint in "azure_foundry_endpoint" below; a value set here takes precedence over the AZURE_FOUNDRY_ENDPOINT env var)
|
||||||
|
"azure_foundry_endpoint": "https://<resource>.openai.azure.com/openai/v1/",
|
||||||
"max_debate_rounds": 1,
|
"max_debate_rounds": 1,
|
||||||
"max_risk_discuss_rounds": 1,
|
"max_risk_discuss_rounds": 1,
|
||||||
"max_recur_limit": 100,
|
"max_recur_limit": 100,
|
||||||
|
|
|
||||||
|
|
@ -145,6 +145,14 @@ class TradingAgentsGraph:
|
||||||
if reasoning_effort:
|
if reasoning_effort:
|
||||||
kwargs["reasoning_effort"] = reasoning_effort
|
kwargs["reasoning_effort"] = reasoning_effort
|
||||||
|
|
||||||
|
elif provider == "azure_foundry":
|
||||||
|
endpoint = self.config.get("azure_foundry_endpoint")
|
||||||
|
if endpoint:
|
||||||
|
kwargs["azure_foundry_endpoint"] = endpoint
|
||||||
|
api_key = self.config.get("azure_foundry_api_key")
|
||||||
|
if api_key:
|
||||||
|
kwargs["api_key"] = api_key
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
import os
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from .base_client import BaseLLMClient
|
||||||
|
from .validators import validate_model
|
||||||
|
|
||||||
|
|
||||||
|
class AzureFoundryClient(BaseLLMClient):
    """Client for models hosted on Azure AI Foundry.

    Azure AI Foundry exposes an OpenAI-compatible chat completions endpoint,
    so we use ChatOpenAI with a custom base_url pointing to your Foundry
    deployment.

    Endpoint resolution order (first usable value wins):
        1. ``azure_foundry_endpoint`` keyword argument
        2. ``base_url`` constructor argument
        3. ``AZURE_FOUNDRY_ENDPOINT`` environment variable

    A value is "usable" when it is non-blank and is not an unfilled template
    such as ``https://<resource>.openai.azure.com/openai/v1/``.  Skipping
    templates lets a placeholder left in the configuration fall through to
    the environment variable instead of being used as a literal host name.

    API key resolution order:
        1. ``api_key`` keyword argument
        2. ``AZURE_FOUNDRY_API_KEY`` environment variable

    Example endpoints:
        https://<resource>.services.ai.azure.com/models
        https://<endpoint>.<region>.models.ai.azure.com/v1
    """

    def __init__(
        self,
        model: str,
        base_url: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(model, base_url, **kwargs)

    @staticmethod
    def _usable_endpoint(value: Any) -> Optional[str]:
        """Return *value* as a stripped string, or None if it cannot be used.

        Blank values and values still containing ``<`` or ``>`` (template
        placeholders; angle brackets are never valid in a real URL) are
        rejected so the caller can fall back to the next source.
        """
        if value is None:
            return None
        candidate = str(value).strip()
        if not candidate or "<" in candidate or ">" in candidate:
            return None
        return candidate

    def get_llm(self) -> Any:
        """Return a ChatOpenAI instance configured for Azure Foundry.

        Raises:
            ValueError: If no usable endpoint or no API key can be resolved.
        """
        # Resolve endpoint: explicit kwarg > base_url > env var, skipping
        # blank values and unfilled "<...>" templates at every step.
        endpoint = (
            self._usable_endpoint(self.kwargs.get("azure_foundry_endpoint"))
            or self._usable_endpoint(self.base_url)
            or self._usable_endpoint(os.environ.get("AZURE_FOUNDRY_ENDPOINT"))
        )
        if not endpoint:
            raise ValueError(
                "Azure Foundry endpoint is required. Set the AZURE_FOUNDRY_ENDPOINT "
                "environment variable, pass 'backend_url' in the config, or provide "
                "'azure_foundry_endpoint' in kwargs. Blank values and unfilled "
                "'<...>' placeholders are ignored."
            )

        # Resolve API key: explicit kwarg > env var
        api_key = (
            self.kwargs.get("api_key")
            or os.environ.get("AZURE_FOUNDRY_API_KEY")
        )
        if not api_key:
            raise ValueError(
                "Azure Foundry API key is required. Set the AZURE_FOUNDRY_API_KEY "
                "environment variable or pass 'api_key' in kwargs."
            )

        llm_kwargs = {
            "model": self.model,
            "base_url": endpoint,
            "api_key": api_key,
        }

        # Forward only the optional params the caller actually supplied, so
        # ChatOpenAI keeps its own defaults for everything else.
        for key in ("timeout", "max_retries", "temperature", "max_tokens", "callbacks"):
            if key in self.kwargs:
                llm_kwargs[key] = self.kwargs[key]

        return ChatOpenAI(**llm_kwargs)

    def validate_model(self) -> bool:
        """Validate model for Azure Foundry."""
        # Resolves to the module-level validate_model imported from
        # .validators, not to this method.
        return validate_model("azure_foundry", self.model)
|
||||||
|
|
@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
|
||||||
from .openai_client import OpenAIClient
|
from .openai_client import OpenAIClient
|
||||||
from .anthropic_client import AnthropicClient
|
from .anthropic_client import AnthropicClient
|
||||||
from .google_client import GoogleClient
|
from .google_client import GoogleClient
|
||||||
|
from .azure_foundry_client import AzureFoundryClient
|
||||||
|
|
||||||
|
|
||||||
def create_llm_client(
|
def create_llm_client(
|
||||||
|
|
@ -40,4 +41,7 @@ def create_llm_client(
|
||||||
if provider_lower == "google":
|
if provider_lower == "google":
|
||||||
return GoogleClient(model, base_url, **kwargs)
|
return GoogleClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
|
if provider_lower == "azure_foundry":
|
||||||
|
return AzureFoundryClient(model, base_url, **kwargs)
|
||||||
|
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
|
||||||
|
|
@ -63,6 +63,31 @@ VALID_MODELS = {
|
||||||
"grok-4-fast-reasoning",
|
"grok-4-fast-reasoning",
|
||||||
"grok-4-fast-non-reasoning",
|
"grok-4-fast-non-reasoning",
|
||||||
],
|
],
|
||||||
|
# Azure Foundry can host any model from the catalog;
|
||||||
|
# list common ones here but any model name is accepted.
|
||||||
|
"azure_foundry": [
|
||||||
|
# OpenAI models on Azure
|
||||||
|
"gpt-4o",
|
||||||
|
"gpt-4o-mini",
|
||||||
|
"gpt-4.1",
|
||||||
|
"gpt-4.1-mini",
|
||||||
|
"gpt-4.1-nano",
|
||||||
|
"o4-mini",
|
||||||
|
"o3",
|
||||||
|
"o3-mini",
|
||||||
|
# Meta Llama models
|
||||||
|
"Meta-Llama-3.3-70B-Instruct",
|
||||||
|
"Meta-Llama-3.1-405B-Instruct",
|
||||||
|
# Mistral models
|
||||||
|
"Mistral-Large-2",
|
||||||
|
"Mistral-Small",
|
||||||
|
# Cohere models
|
||||||
|
"Cohere-command-r-plus",
|
||||||
|
"Cohere-command-r",
|
||||||
|
# DeepSeek models
|
||||||
|
"DeepSeek-R1",
|
||||||
|
"DeepSeek-V3",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -73,7 +98,7 @@ def validate_model(provider: str, model: str) -> bool:
|
||||||
"""
|
"""
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
if provider_lower in ("ollama", "openrouter"):
|
if provider_lower in ("ollama", "openrouter", "azure_foundry"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if provider_lower not in VALID_MODELS:
|
if provider_lower not in VALID_MODELS:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue