diff --git a/.env.enterprise.example b/.env.enterprise.example
new file mode 100644
index 00000000..4f7bda3d
--- /dev/null
+++ b/.env.enterprise.example
@@ -0,0 +1,5 @@
+# Azure OpenAI
+AZURE_OPENAI_API_KEY=
+AZURE_OPENAI_ENDPOINT=https://your-resource-name.openai.azure.com/
+AZURE_OPENAI_DEPLOYMENT_NAME=
+# OPENAI_API_VERSION=2024-10-21 # optional, required for non-v1 API
diff --git a/.env.example b/.env.example
index 1328b838..be9bf13e 100644
--- a/.env.example
+++ b/.env.example
@@ -3,4 +3,7 @@ OPENAI_API_KEY=
 GOOGLE_API_KEY=
 ANTHROPIC_API_KEY=
 XAI_API_KEY=
+DEEPSEEK_API_KEY=
+DASHSCOPE_API_KEY=
+ZHIPU_API_KEY=
 OPENROUTER_API_KEY=
diff --git a/README.md b/README.md
index 9a92bff9..97cbde48 100644
--- a/README.md
+++ b/README.md
@@ -140,10 +140,15 @@ export OPENAI_API_KEY=... # OpenAI (GPT)
 export GOOGLE_API_KEY=... # Google (Gemini)
 export ANTHROPIC_API_KEY=... # Anthropic (Claude)
 export XAI_API_KEY=... # xAI (Grok)
+export DEEPSEEK_API_KEY=... # DeepSeek
+export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope)
+export ZHIPU_API_KEY=... # GLM (Zhipu)
 export OPENROUTER_API_KEY=... # OpenRouter
 export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
 ```
 
+For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
+
 For local models, configure Ollama with `llm_provider: "ollama"` in your config.
 
 Alternatively, copy `.env.example` to `.env` and fill in your keys:
diff --git a/cli/main.py b/cli/main.py
index 29294d8d..52e8a332 100644
--- a/cli/main.py
+++ b/cli/main.py
@@ -6,8 +6,9 @@ from functools import wraps
 from rich.console import Console
 from dotenv import load_dotenv
 
-# Load environment variables from .env file
+# Load environment variables
 load_dotenv()
+load_dotenv(".env.enterprise", override=False)
 from rich.panel import Panel
 from rich.spinner import Spinner
 from rich.live import Live
diff --git a/cli/utils.py b/cli/utils.py
index e071ce06..85c282ed 100644
--- a/cli/utils.py
+++ b/cli/utils.py
@@ -174,17 +174,30 @@ def select_openrouter_model() -> str:
     return choice
 
 
-def select_shallow_thinking_agent(provider) -> str:
-    """Select shallow thinking llm engine using an interactive selection."""
+def _prompt_custom_model_id() -> str:
+    """Prompt user to type a custom model ID."""
+    return questionary.text(
+        "Enter model ID:",
+        validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
+    ).ask().strip()
+
+def _select_model(provider: str, mode: str) -> str:
+    """Select a model for the given provider and mode (quick/deep)."""
 
     if provider.lower() == "openrouter":
         return select_openrouter_model()
 
+    if provider.lower() == "azure":
+        return questionary.text(
+            f"Enter Azure deployment name ({mode}-thinking):",
+            validate=lambda x: len(x.strip()) > 0 or "Please enter a deployment name.",
+        ).ask().strip()
+
     choice = questionary.select(
-        "Select Your [Quick-Thinking LLM Engine]:",
+        f"Select Your [{mode.title()}-Thinking LLM Engine]:",
         choices=[
             questionary.Choice(display, value=value)
-            for display, value in get_model_options(provider, "quick")
+            for display, value in get_model_options(provider, mode)
         ],
         instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
         style=questionary.Style(
@@ -197,58 +210,45 @@
     ).ask()
 
     if choice is None:
-        console.print(
-            "\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
-        )
+        console.print(f"\n[red]No {mode} thinking llm engine selected. Exiting...[/red]")
         exit(1)
 
+    if choice == "custom":
+        return _prompt_custom_model_id()
+
     return choice
 
+def select_shallow_thinking_agent(provider) -> str:
+    """Select shallow thinking llm engine using an interactive selection."""
+    return _select_model(provider, "quick")
+
+
 def select_deep_thinking_agent(provider) -> str:
     """Select deep thinking llm engine using an interactive selection."""
-
-    if provider.lower() == "openrouter":
-        return select_openrouter_model()
-
-    choice = questionary.select(
-        "Select Your [Deep-Thinking LLM Engine]:",
-        choices=[
-            questionary.Choice(display, value=value)
-            for display, value in get_model_options(provider, "deep")
-        ],
-        instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
-        style=questionary.Style(
-            [
-                ("selected", "fg:magenta noinherit"),
-                ("highlighted", "fg:magenta noinherit"),
-                ("pointer", "fg:magenta noinherit"),
-            ]
-        ),
-    ).ask()
-
-    if choice is None:
-        console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
-        exit(1)
-
-    return choice
+    return _select_model(provider, "deep")
 
 
 def select_llm_provider() -> tuple[str, str | None]:
     """Select the LLM provider and its API endpoint."""
-    BASE_URLS = [
-        ("OpenAI", "https://api.openai.com/v1"),
-        ("Google", None), # google-genai SDK manages its own endpoint
-        ("Anthropic", "https://api.anthropic.com/"),
-        ("xAI", "https://api.x.ai/v1"),
-        ("Openrouter", "https://openrouter.ai/api/v1"),
-        ("Ollama", "http://localhost:11434/v1"),
+    # (display_name, provider_key, base_url)
+    PROVIDERS = [
+        ("OpenAI", "openai", "https://api.openai.com/v1"),
+        ("Google", "google", None),
+        ("Anthropic", "anthropic", "https://api.anthropic.com/"),
+        ("xAI", "xai", "https://api.x.ai/v1"),
+        ("DeepSeek", "deepseek", "https://api.deepseek.com"),
+        ("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
+        ("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
+        ("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
+        ("Azure OpenAI", "azure", None),
+        ("Ollama", "ollama", "http://localhost:11434/v1"),
     ]
-    
+
     choice = questionary.select(
         "Select your LLM Provider:",
         choices=[
-            questionary.Choice(display, value=(display, value))
-            for display, value in BASE_URLS
+            questionary.Choice(display, value=(provider_key, url))
+            for display, provider_key, url in PROVIDERS
         ],
         instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
         style=questionary.Style(
@@ -261,13 +261,11 @@
     ).ask()
 
     if choice is None:
-        console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
+        console.print("\n[red]No LLM provider selected. Exiting...[/red]")
         exit(1)
 
-
-    display_name, url = choice
-    print(f"You selected: {display_name}\tURL: {url}")
-    return display_name, url
+    provider, url = choice
+    return provider, url
 
 
 def ask_openai_reasoning_effort() -> str:
diff --git a/tradingagents/llm_clients/azure_client.py b/tradingagents/llm_clients/azure_client.py
new file mode 100644
index 00000000..0c0ae5a4
--- /dev/null
+++ b/tradingagents/llm_clients/azure_client.py
@@ -0,0 +1,52 @@
+import os
+from typing import Any, Optional
+
+from langchain_openai import AzureChatOpenAI
+
+from .base_client import BaseLLMClient, normalize_content
+from .validators import validate_model
+
+_PASSTHROUGH_KWARGS = (
+    "timeout", "max_retries", "api_key", "reasoning_effort",
+    "callbacks", "http_client", "http_async_client",
+)
+
+
+class NormalizedAzureChatOpenAI(AzureChatOpenAI):
+    """AzureChatOpenAI with normalized content output."""
+
+    def invoke(self, input, config=None, **kwargs):
+        return normalize_content(super().invoke(input, config, **kwargs))
+
+
+class AzureOpenAIClient(BaseLLMClient):
+    """Client for Azure OpenAI deployments.
+
+    Requires environment variables:
+        AZURE_OPENAI_API_KEY: API key
+        AZURE_OPENAI_ENDPOINT: Endpoint URL (e.g. https://<resource-name>.openai.azure.com/)
+        AZURE_OPENAI_DEPLOYMENT_NAME: Deployment name
+        OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)
+    """
+
+    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
+        super().__init__(model, base_url, **kwargs)
+
+    def get_llm(self) -> Any:
+        """Return configured AzureChatOpenAI instance."""
+        self.warn_if_unknown_model()
+
+        llm_kwargs = {
+            "model": self.model,
+            "azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", self.model),
+        }
+
+        for key in _PASSTHROUGH_KWARGS:
+            if key in self.kwargs:
+                llm_kwargs[key] = self.kwargs[key]
+
+        return NormalizedAzureChatOpenAI(**llm_kwargs)
+
+    def validate_model(self) -> bool:
+        """Azure accepts any deployed model name."""
+        return True
diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py
index 93c2a7d3..a9a7e83d 100644
--- a/tradingagents/llm_clients/factory.py
+++ b/tradingagents/llm_clients/factory.py
@@ -4,6 +4,12 @@ from .base_client import BaseLLMClient
 from .openai_client import OpenAIClient
 from .anthropic_client import AnthropicClient
 from .google_client import GoogleClient
+from .azure_client import AzureOpenAIClient
+
+# Providers that use the OpenAI-compatible chat completions API
+_OPENAI_COMPATIBLE = (
+    "openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter",
+)
 
 
 def create_llm_client(
@@ -15,16 +21,10 @@
     """Create an LLM client for the specified provider.
 
     Args:
-        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
+        provider: LLM provider name
         model: Model name/identifier
         base_url: Optional base URL for API endpoint
         **kwargs: Additional provider-specific arguments
-            - http_client: Custom httpx.Client for SSL proxy or certificate customization
-            - http_async_client: Custom httpx.AsyncClient for async operations
-            - timeout: Request timeout in seconds
-            - max_retries: Maximum retry attempts
-            - api_key: API key for the provider
-            - callbacks: LangChain callbacks
 
     Returns:
         Configured BaseLLMClient instance
@@ -34,16 +34,16 @@
     """
     provider_lower = provider.lower()
 
-    if provider_lower in ("openai", "ollama", "openrouter"):
+    if provider_lower in _OPENAI_COMPATIBLE:
         return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
 
-    if provider_lower == "xai":
-        return OpenAIClient(model, base_url, provider="xai", **kwargs)
-
     if provider_lower == "anthropic":
         return AnthropicClient(model, base_url, **kwargs)
 
     if provider_lower == "google":
         return GoogleClient(model, base_url, **kwargs)
 
+    if provider_lower == "azure":
+        return AzureOpenAIClient(model, base_url, **kwargs)
+
     raise ValueError(f"Unsupported LLM provider: {provider}")
diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py
index fd91c66d..a2c57ed8 100644
--- a/tradingagents/llm_clients/model_catalog.py
+++ b/tradingagents/llm_clients/model_catalog.py
@@ -63,8 +63,43 @@ MODEL_OPTIONS: ProviderModeOptions = {
             ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
         ],
     },
-    # OpenRouter models are fetched dynamically at CLI runtime.
-    # No static entries needed; any model ID is accepted by the validator.
+    "deepseek": {
+        "quick": [
+            ("DeepSeek V3.2", "deepseek-chat"),
+            ("Custom model ID", "custom"),
+        ],
+        "deep": [
+            ("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
+            ("DeepSeek V3.2", "deepseek-chat"),
+            ("Custom model ID", "custom"),
+        ],
+    },
+    "qwen": {
+        "quick": [
+            ("Qwen 3.5 Flash", "qwen3.5-flash"),
+            ("Qwen Plus", "qwen-plus"),
+            ("Custom model ID", "custom"),
+        ],
+        "deep": [
+            ("Qwen 3.6 Plus", "qwen3.6-plus"),
+            ("Qwen 3.5 Plus", "qwen3.5-plus"),
+            ("Qwen 3 Max", "qwen3-max"),
+            ("Custom model ID", "custom"),
+        ],
+    },
+    "glm": {
+        "quick": [
+            ("GLM-4.7", "glm-4.7"),
+            ("GLM-5", "glm-5"),
+            ("Custom model ID", "custom"),
+        ],
+        "deep": [
+            ("GLM-5.1", "glm-5.1"),
+            ("GLM-5", "glm-5"),
+            ("Custom model ID", "custom"),
+        ],
+    },
+    # OpenRouter: fetched dynamically. Azure: any deployed model name.
     "ollama": {
         "quick": [
             ("Qwen3:latest (8B, local)", "qwen3:latest"),
diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py
index 4f2e1b32..f943124a 100644
--- a/tradingagents/llm_clients/openai_client.py
+++ b/tradingagents/llm_clients/openai_client.py
@@ -27,6 +27,9 @@ _PASSTHROUGH_KWARGS = (
 # Provider base URLs and API key env vars
 _PROVIDER_CONFIG = {
     "xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
+    "deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
+    "qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
+    "glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
    "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
    "ollama": ("http://localhost:11434/v1", None),
 }