diff --git a/.env.example b/.env.example index 1328b838..e1c5c2f1 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,8 @@ # LLM Providers (set the one you use) OPENAI_API_KEY= +AZURE_OPENAI_API_KEY= +AZURE_OPENAI_ENDPOINT= +AZURE_OPENAI_API_VERSION=2024-10-21 GOOGLE_API_KEY= ANTHROPIC_API_KEY= XAI_API_KEY= diff --git a/README.md b/README.md index 34310010..b7801de5 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,7 @@ An interface will appear showing results as they load, letting you track the age ### Implementation Details -We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, OpenRouter, and Ollama. +We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Azure OpenAI, Google, Anthropic, xAI, OpenRouter, and Ollama. ### Python Usage @@ -186,7 +186,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG config = DEFAULT_CONFIG.copy() -config["llm_provider"] = "openai" # openai, google, anthropic, xai, openrouter, ollama +config["llm_provider"] = "openai" # openai, azure, google, anthropic, xai, openrouter, ollama config["deep_think_llm"] = "gpt-5.2" # Model for complex reasoning config["quick_think_llm"] = "gpt-5-mini" # Model for quick tasks config["max_debate_rounds"] = 2 diff --git a/cli/main.py b/cli/main.py index fb97d189..ee5abcf5 100644 --- a/cli/main.py +++ b/cli/main.py @@ -536,10 +536,10 @@ def get_user_selections(): ) selected_research_depth = select_research_depth() - # Step 5: OpenAI backend + # Step 5: LLM provider console.print( create_question_box( - "Step 5: OpenAI backend", "Select which service to talk to" + "Step 5: LLM Provider", "Select which service to talk to" ) ) selected_llm_provider, backend_url = select_llm_provider() @@ -556,6 +556,8 @@ def get_user_selections(): # Step 7: 
Provider-specific thinking configuration thinking_level = None reasoning_effort = None + azure_endpoint = None + azure_api_version = None provider_lower = selected_llm_provider.lower() if provider_lower == "google": @@ -574,6 +576,25 @@ def get_user_selections(): ) ) reasoning_effort = ask_openai_reasoning_effort() + elif provider_lower == "azure": + console.print( + create_question_box( + "Step 7: Azure OpenAI", + "Configure endpoint, API version, and deployment names" + ) + ) + azure_endpoint = ask_azure_endpoint(backend_url) + azure_api_version = ask_azure_api_version() + selected_shallow_thinker = ask_azure_deployment_name( + "Quick-Thinking LLM", + selected_shallow_thinker, + ) + selected_deep_thinker = ask_azure_deployment_name( + "Deep-Thinking LLM", + selected_deep_thinker, + ) + backend_url = azure_endpoint + reasoning_effort = ask_openai_reasoning_effort() return { "ticker": selected_ticker, @@ -586,6 +607,8 @@ def get_user_selections(): "deep_thinker": selected_deep_thinker, "google_thinking_level": thinking_level, "openai_reasoning_effort": reasoning_effort, + "azure_endpoint": azure_endpoint, + "azure_api_version": azure_api_version, } @@ -911,6 +934,8 @@ def run_analysis(): # Provider-specific thinking configuration config["google_thinking_level"] = selections.get("google_thinking_level") config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") + config["azure_endpoint"] = selections.get("azure_endpoint") + config["azure_api_version"] = selections.get("azure_api_version") # Create stats callback handler for tracking LLM/tool calls stats_handler = StatsCallbackHandler() diff --git a/cli/utils.py b/cli/utils.py index aa097fb5..4e01f3b5 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -134,6 +134,13 @@ def select_shallow_thinking_agent(provider) -> str: ("GPT-5.1 - Flexible reasoning", "gpt-5.1"), ("GPT-4.1 - Smartest non-reasoning, 1M context", "gpt-4.1"), ], + "azure": [ + ("GPT-5 Mini deployment", "gpt-5-mini"), + ("GPT-5 Nano 
deployment", "gpt-5-nano"), + ("GPT-5.2 deployment", "gpt-5.2"), + ("GPT-5.1 deployment", "gpt-5.1"), + ("GPT-4.1 deployment", "gpt-4.1"), + ], "anthropic": [ ("Claude Haiku 4.5 - Fast + extended thinking", "claude-haiku-4-5"), ("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"), @@ -200,6 +207,14 @@ def select_deep_thinking_agent(provider) -> str: ("GPT-5 Mini - Cost-optimized reasoning", "gpt-5-mini"), ("GPT-5 Nano - Ultra-fast, high-throughput", "gpt-5-nano"), ], + "azure": [ + ("GPT-5.2 deployment", "gpt-5.2"), + ("GPT-5.1 deployment", "gpt-5.1"), + ("GPT-5 deployment", "gpt-5"), + ("GPT-4.1 deployment", "gpt-4.1"), + ("GPT-5 Mini deployment", "gpt-5-mini"), + ("GPT-5 Nano deployment", "gpt-5-nano"), + ], "anthropic": [ ("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"), ("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"), @@ -253,10 +268,11 @@ def select_deep_thinking_agent(provider) -> str: return choice def select_llm_provider() -> tuple[str, str]: - """Select the OpenAI api url using interactive selection.""" - # Define OpenAI api options with their corresponding endpoints + """Select an LLM provider and default API endpoint.""" + # Define provider options with their corresponding endpoints BASE_URLS = [ ("OpenAI", "https://api.openai.com/v1"), + ("Azure", "https://YOUR-RESOURCE.openai.azure.com/"), ("Google", "https://generativelanguage.googleapis.com/v1"), ("Anthropic", "https://api.anthropic.com/"), ("xAI", "https://api.x.ai/v1"), @@ -281,7 +297,7 @@ def select_llm_provider() -> tuple[str, str]: ).ask() if choice is None: - console.print("\n[red]no OpenAI backend selected. Exiting...[/red]") + console.print("\n[red]No LLM provider selected. 
Exiting...[/red]") exit(1) display_name, url = choice @@ -326,3 +342,30 @@ def ask_gemini_thinking_config() -> str | None: ("pointer", "fg:green noinherit"), ]), ).ask() + + +def ask_azure_endpoint(default: str = "https://YOUR-RESOURCE.openai.azure.com/") -> str: + """Ask for Azure OpenAI endpoint URL.""" + return questionary.text( + "Enter Azure OpenAI Endpoint URL:", + default=default, + validate=lambda x: len(x.strip()) > 0 or "Endpoint URL is required.", + ).ask() + + +def ask_azure_api_version(default: str = "2024-10-21") -> str: + """Ask for Azure OpenAI API version.""" + return questionary.text( + "Enter Azure OpenAI API Version:", + default=default, + validate=lambda x: len(x.strip()) > 0 or "API version is required.", + ).ask() + + +def ask_azure_deployment_name(label: str, default: str) -> str: + """Ask for Azure OpenAI deployment name.""" + return questionary.text( + f"Enter Azure deployment for {label}:", + default=default, + validate=lambda x: len(x.strip()) > 0 or "Deployment name is required.", + ).ask() diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index ecf0dc29..70c7797e 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -15,6 +15,8 @@ DEFAULT_CONFIG = { # Provider-specific thinking configuration "google_thinking_level": None, # "high", "minimal", etc. "openai_reasoning_effort": None, # "medium", "high", "low" + "azure_endpoint": None, # e.g. 
https://YOUR-RESOURCE.openai.azure.com/ + "azure_api_version": "2024-10-21", # Debate and discussion settings "max_debate_rounds": 1, "max_risk_discuss_rounds": 1, diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 44ecca0c..17a771ec 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -144,6 +144,18 @@ class TradingAgentsGraph: reasoning_effort = self.config.get("openai_reasoning_effort") if reasoning_effort: kwargs["reasoning_effort"] = reasoning_effort + elif provider == "azure": + azure_endpoint = self.config.get("azure_endpoint") or self.config.get("backend_url") + if azure_endpoint: + kwargs["azure_endpoint"] = azure_endpoint + + api_version = self.config.get("azure_api_version") + if api_version: + kwargs["api_version"] = api_version + + reasoning_effort = self.config.get("openai_reasoning_effort") + if reasoning_effort: + kwargs["reasoning_effort"] = reasoning_effort return kwargs diff --git a/tradingagents/llm_clients/azure_openai_client.py b/tradingagents/llm_clients/azure_openai_client.py new file mode 100644 index 00000000..cf5a3802 --- /dev/null +++ b/tradingagents/llm_clients/azure_openai_client.py @@ -0,0 +1,53 @@ +import os +from typing import Any, Optional + +from langchain_openai import AzureChatOpenAI + +from .base_client import BaseLLMClient +from .validators import validate_model + + +class AzureOpenAIClient(BaseLLMClient): + """Client for Azure OpenAI models via AzureChatOpenAI.""" + + def __init__( + self, + model: str, + base_url: Optional[str] = None, + **kwargs, + ): + super().__init__(model, base_url, **kwargs) + + def get_llm(self) -> Any: + """Return configured AzureChatOpenAI instance.""" + azure_endpoint = ( + self.kwargs.get("azure_endpoint") + or self.base_url + or os.environ.get("AZURE_OPENAI_ENDPOINT") + ) + api_version = self.kwargs.get("api_version") or os.environ.get( + "AZURE_OPENAI_API_VERSION", + "2024-10-21", + ) + api_key = self.kwargs.get("api_key") 
or os.environ.get("AZURE_OPENAI_API_KEY") + + llm_kwargs = { + "azure_deployment": self.model, + "model": self.model, + "api_version": api_version, + } + + if azure_endpoint: + llm_kwargs["azure_endpoint"] = azure_endpoint + if api_key: + llm_kwargs["api_key"] = api_key + + for key in ("timeout", "max_retries", "reasoning_effort", "callbacks"): + if key in self.kwargs: + llm_kwargs[key] = self.kwargs[key] + + return AzureChatOpenAI(**llm_kwargs) + + def validate_model(self) -> bool: + """Validate model for Azure OpenAI.""" + return validate_model("azure", self.model) diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index 028c88a2..1b19b908 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -2,6 +2,7 @@ from typing import Optional from .base_client import BaseLLMClient from .openai_client import OpenAIClient +from .azure_openai_client import AzureOpenAIClient from .anthropic_client import AnthropicClient from .google_client import GoogleClient @@ -15,7 +16,7 @@ def create_llm_client( """Create an LLM client for the specified provider. 
Args: - provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter) + provider: LLM provider (openai, azure, anthropic, google, xai, ollama, openrouter) model: Model name/identifier base_url: Optional base URL for API endpoint **kwargs: Additional provider-specific arguments @@ -31,6 +32,9 @@ def create_llm_client( if provider_lower in ("openai", "ollama", "openrouter"): return OpenAIClient(model, base_url, provider=provider_lower, **kwargs) + if provider_lower == "azure": + return AzureOpenAIClient(model, base_url, **kwargs) + if provider_lower == "xai": return OpenAIClient(model, base_url, provider="xai", **kwargs) diff --git a/tradingagents/llm_clients/validators.py b/tradingagents/llm_clients/validators.py index 3c0f2290..85de716c 100644 --- a/tradingagents/llm_clients/validators.py +++ b/tradingagents/llm_clients/validators.py @@ -69,11 +69,11 @@ VALID_MODELS = { def validate_model(provider: str, model: str) -> bool: """Check if model name is valid for the given provider. - For ollama, openrouter - any model is accepted. + For ollama, openrouter, azure - any model/deployment is accepted. """ provider_lower = provider.lower() - if provider_lower in ("ollama", "openrouter"): + if provider_lower in ("ollama", "openrouter", "azure"): return True if provider_lower not in VALID_MODELS: