From 5bc20dfe9087316ff63f5ba96bb3fbb7a20ed12f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=98=B3=E8=99=8E?= Date: Mon, 16 Mar 2026 08:18:36 +0800 Subject: [PATCH] feat: add Azure OpenAI support - Add AzureClient class using langchain_openai.AzureChatOpenAI - Add Azure to LLM provider factory - Add Azure deployment options to CLI selectors - Support environment variables: AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, OPENAI_API_VERSION Fixes #334 --- cli/utils.py | 13 +++++ tradingagents/llm_clients/azure_client.py | 64 +++++++++++++++++++++++ tradingagents/llm_clients/factory.py | 8 ++- 3 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 tradingagents/llm_clients/azure_client.py diff --git a/cli/utils.py b/cli/utils.py index 5a8ec16c..3bfc80a6 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -164,6 +164,12 @@ def select_shallow_thinking_agent(provider) -> str: ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), ], + "azure": [ + ("GPT-4o (your deployment)", "gpt-4o"), + ("GPT-4o-mini (your deployment)", "gpt-4o-mini"), + ("GPT-4 Turbo (your deployment)", "gpt-4-turbo"), + ("Custom deployment name", "__custom__"), + ], } choice = questionary.select( @@ -231,6 +237,12 @@ def select_deep_thinking_agent(provider) -> str: ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), ("Qwen3:latest (8B, local)", "qwen3:latest"), ], + "azure": [ + ("GPT-4o (your deployment)", "gpt-4o"), + ("GPT-4o-mini (your deployment)", "gpt-4o-mini"), + ("GPT-4 Turbo (your deployment)", "gpt-4-turbo"), + ("Custom deployment name", "__custom__"), + ], } choice = questionary.select( @@ -265,6 +277,7 @@ def select_llm_provider() -> tuple[str, str]: ("xAI", "https://api.x.ai/v1"), ("Openrouter", "https://openrouter.ai/api/v1"), ("Ollama", "http://localhost:11434/v1"), + ("Azure OpenAI", "azure"), ] choice = questionary.select( diff --git a/tradingagents/llm_clients/azure_client.py b/tradingagents/llm_clients/azure_client.py 
new file mode 100644 index 00000000..8ea6c909 --- /dev/null +++ b/tradingagents/llm_clients/azure_client.py @@ -0,0 +1,64 @@ +"""Azure OpenAI client implementation.""" + +import os +from typing import Any, Optional + +from langchain_openai import AzureChatOpenAI + +from .base_client import BaseLLMClient + + +class AzureClient(BaseLLMClient): + """Client for Azure OpenAI provider. + + Required environment variables: + AZURE_OPENAI_API_KEY: Your Azure OpenAI API key + AZURE_OPENAI_ENDPOINT: Your Azure endpoint (e.g., https://your-resource.openai.azure.com/) + OPENAI_API_VERSION: API version (e.g., 2024-02-15-preview) + + Note: + This client does not read AZURE_DEPLOYMENT_NAME from the environment; pass your deployment name as the model parameter + """ + + def __init__( + self, + model: str, + base_url: Optional[str] = None, + **kwargs, + ): + super().__init__(model, base_url, **kwargs) + # In Azure, model is the deployment name + self.deployment_name = model + + def get_llm(self) -> Any: + """Return configured AzureChatOpenAI instance.""" + llm_kwargs = { + "azure_deployment": self.deployment_name, + } + + # Pass through Azure-specific kwargs + for key in ("timeout", "max_retries", "callbacks", "api_version", + "azure_endpoint", "api_key", "http_client", "http_async_client"): + if key in self.kwargs: + llm_kwargs[key] = self.kwargs[key] + + # Environment variable fallbacks + if "azure_endpoint" not in llm_kwargs: + endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT") + if endpoint: + llm_kwargs["azure_endpoint"] = endpoint + + if "api_key" not in llm_kwargs: + api_key = os.environ.get("AZURE_OPENAI_API_KEY") + if api_key: + llm_kwargs["api_key"] = api_key + + if "api_version" not in llm_kwargs: + api_version = os.environ.get("OPENAI_API_VERSION", "2024-02-15-preview") + llm_kwargs["api_version"] = api_version + + return AzureChatOpenAI(**llm_kwargs) + + def validate_model(self) -> bool: + """Validate Azure deployment name (always valid, Azure validates at runtime).""" + return
bool(self.deployment_name) \ No newline at end of file diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index 93c2a7d3..7c640e7d 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -4,6 +4,7 @@ from .base_client import BaseLLMClient from .openai_client import OpenAIClient from .anthropic_client import AnthropicClient from .google_client import GoogleClient +from .azure_client import AzureClient def create_llm_client( @@ -15,8 +16,8 @@ def create_llm_client( """Create an LLM client for the specified provider. Args: - provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter) - model: Model name/identifier + provider: LLM provider (openai, anthropic, google, azure, xai, ollama, openrouter) + model: Model name/identifier (for Azure, use deployment name) base_url: Optional base URL for API endpoint **kwargs: Additional provider-specific arguments - http_client: Custom httpx.Client for SSL proxy or certificate customization @@ -46,4 +47,7 @@ def create_llm_client( if provider_lower == "google": return GoogleClient(model, base_url, **kwargs) + if provider_lower == "azure": + return AzureClient(model, base_url, **kwargs) + raise ValueError(f"Unsupported LLM provider: {provider}")