diff --git a/cli/utils.py b/cli/utils.py
index 5a8ec16c..5ae24515 100644
--- a/cli/utils.py
+++ b/cli/utils.py
@@ -14,6 +14,22 @@ ANALYST_ORDER = [
     ("Fundamentals Analyst", AnalystType.FUNDAMENTALS),
 ]
 
+# Shared model lists for providers with identical options in shallow/deep selections
+# Azure: deployment names are user-specific, so we provide common defaults
+AZURE_MODELS = [
+    ("GPT-4o (your deployment)", "gpt-4o"),
+    ("GPT-4o-mini (your deployment)", "gpt-4o-mini"),
+    ("GPT-4 Turbo (your deployment)", "gpt-4-turbo"),
+    ("Custom deployment name", "__custom__"),
+]
+
+# Ollama: same models for both shallow and deep (local inference)
+OLLAMA_MODELS = [
+    ("Qwen3:latest (8B, local)", "qwen3:latest"),
+    ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
+    ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
+]
+
 
 def get_ticker() -> str:
     """Prompt the user to enter a ticker symbol."""
@@ -159,11 +175,8 @@ def select_shallow_thinking_agent(provider) -> str:
             ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
             ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
         ],
-        "ollama": [
-            ("Qwen3:latest (8B, local)", "qwen3:latest"),
-            ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
-            ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
-        ],
+        "ollama": OLLAMA_MODELS,
+        "azure": AZURE_MODELS,
     }
 
     choice = questionary.select(
@@ -226,11 +239,8 @@ def select_deep_thinking_agent(provider) -> str:
             ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
             ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
         ],
-        "ollama": [
-            ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
-            ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
-            ("Qwen3:latest (8B, local)", "qwen3:latest"),
-        ],
+        "ollama": OLLAMA_MODELS,
+        "azure": AZURE_MODELS,
     }
 
     choice = questionary.select(
@@ -265,6 +275,7 @@ def select_llm_provider() -> tuple[str, str]:
         ("xAI", "https://api.x.ai/v1"),
         ("Openrouter", "https://openrouter.ai/api/v1"),
         ("Ollama", "http://localhost:11434/v1"),
+        ("Azure OpenAI", "azure"),
     ]
 
     choice = questionary.select(
diff --git a/tradingagents/llm_clients/azure_client.py b/tradingagents/llm_clients/azure_client.py
new file mode 100644
index 00000000..b46e8bd6
--- /dev/null
+++ b/tradingagents/llm_clients/azure_client.py
@@ -0,0 +1,70 @@
+"""Azure OpenAI client implementation."""
+
+import os
+from typing import Any, Optional
+
+from langchain_openai import AzureChatOpenAI
+
+from .base_client import BaseLLMClient
+
+
+class AzureClient(BaseLLMClient):
+    """Client for Azure OpenAI provider.
+
+    Required environment variables (unless passed as keyword arguments):
+        AZURE_OPENAI_API_KEY: Your Azure OpenAI API key
+        AZURE_OPENAI_ENDPOINT: Your Azure endpoint (e.g., https://your-resource.openai.azure.com/)
+
+    Optional environment variables:
+        OPENAI_API_VERSION: API version (defaults to 2024-02-15-preview)
+        AZURE_DEPLOYMENT_NAME: Deployment name used when the model argument is
+            empty or is the CLI's "__custom__" placeholder
+    """
+
+    # Value of the CLI's "Custom deployment name" menu entry (AZURE_MODELS in cli/utils.py)
+    CUSTOM_DEPLOYMENT_PLACEHOLDER = "__custom__"
+
+    def __init__(
+        self,
+        model: str,
+        base_url: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(model, base_url, **kwargs)
+        # In Azure, model is the deployment name. The CLI placeholder is not a
+        # real deployment, so resolve it (or an empty name) from the environment.
+        if not model or model == self.CUSTOM_DEPLOYMENT_PLACEHOLDER:
+            model = os.environ.get("AZURE_DEPLOYMENT_NAME", "")
+        self.deployment_name = model
+
+    def get_llm(self) -> Any:
+        """Return configured AzureChatOpenAI instance."""
+        llm_kwargs = {
+            "azure_deployment": self.deployment_name,
+        }
+
+        # Pass through Azure-specific kwargs
+        for key in ("timeout", "max_retries", "callbacks", "api_version",
+                    "azure_endpoint", "api_key", "http_client", "http_async_client"):
+            if key in self.kwargs:
+                llm_kwargs[key] = self.kwargs[key]
+
+        # Environment variable fallbacks (DRY pattern)
+        for kwarg_name, env_var in {
+            "azure_endpoint": "AZURE_OPENAI_ENDPOINT",
+            "api_key": "AZURE_OPENAI_API_KEY",
+        }.items():
+            if kwarg_name not in llm_kwargs:
+                value = os.environ.get(env_var)
+                if value:
+                    llm_kwargs[kwarg_name] = value
+
+        if "api_version" not in llm_kwargs:
+            api_version = os.environ.get("OPENAI_API_VERSION", "2024-02-15-preview")
+            llm_kwargs["api_version"] = api_version
+
+        return AzureChatOpenAI(**llm_kwargs)
+
+    def validate_model(self) -> bool:
+        """Return True when a non-empty deployment name is set (Azure validates it at runtime)."""
+        return bool(self.deployment_name)
diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py
index 028c88a2..dc57d423 100644
--- a/tradingagents/llm_clients/factory.py
+++ b/tradingagents/llm_clients/factory.py
@@ -4,6 +4,7 @@ from .base_client import BaseLLMClient
 from .openai_client import OpenAIClient
 from .anthropic_client import AnthropicClient
 from .google_client import GoogleClient
+from .azure_client import AzureClient
 
 
 def create_llm_client(
@@ -15,8 +16,8 @@ def create_llm_client(
     """Create an LLM client for the specified provider.
 
     Args:
-        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
-        model: Model name/identifier
+        provider: LLM provider (openai, anthropic, google, azure, xai, ollama, openrouter)
+        model: Model name/identifier (for Azure, use deployment name)
         base_url: Optional base URL for API endpoint
         **kwargs: Additional provider-specific arguments
 
@@ -40,4 +41,7 @@ def create_llm_client(
     if provider_lower == "google":
         return GoogleClient(model, base_url, **kwargs)
 
+    if provider_lower == "azure":
+        return AzureClient(model, base_url, **kwargs)
+
     raise ValueError(f"Unsupported LLM provider: {provider}")
diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py
index 924f24b0..f628c6ab 100644
--- a/tradingagents/llm_clients/openai_client.py
+++ b/tradingagents/llm_clients/openai_client.py
@@ -54,7 +54,17 @@ class OpenAIClient(BaseLLMClient):
             if api_key:
                 llm_kwargs["api_key"] = api_key
         elif self.provider == "ollama":
-            llm_kwargs["base_url"] = "http://localhost:11434/v1"
+            # Support custom Ollama URL via base_url or OLLAMA_HOST env var
+            # (an empty OLLAMA_HOST falls back to the local default)
+            ollama_url = self.base_url or os.environ.get("OLLAMA_HOST") or "http://localhost:11434"
+            # OLLAMA_HOST is commonly a bare host:port (e.g. 127.0.0.1:11434); add a scheme
+            if "://" not in ollama_url:
+                ollama_url = f"http://{ollama_url}"
+            # Ensure /v1 suffix for OpenAI-compatible API (handle trailing slash edge case)
+            ollama_url = ollama_url.rstrip("/")
+            if not ollama_url.endswith("/v1"):
+                ollama_url += "/v1"
+            llm_kwargs["base_url"] = ollama_url
             llm_kwargs["api_key"] = "ollama"  # Ollama doesn't require auth
         elif self.base_url:
             llm_kwargs["base_url"] = self.base_url