diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py
index 40cdff75..9b274697 100644
--- a/tradingagents/graph/trading_graph.py
+++ b/tradingagents/graph/trading_graph.py
@@ -6,12 +6,10 @@
 import json
 from datetime import date
 from typing import Dict, Any, Tuple, List, Optional
 
-from langchain_openai import ChatOpenAI
-from langchain_anthropic import ChatAnthropic
-from langchain_google_genai import ChatGoogleGenerativeAI
-
 from langgraph.prebuilt import ToolNode
+from tradingagents.llm_clients import create_llm_client
+
 from tradingagents.agents import *
 from tradingagents.default_config import DEFAULT_CONFIG
 from tradingagents.agents.utils.memory import FinancialSituationMemory
@@ -72,17 +70,18 @@ class TradingAgentsGraph:
         )
 
         # Initialize LLMs
-        if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
-            self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
-            self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
-        elif self.config["llm_provider"].lower() == "anthropic":
-            self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
-            self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
-        elif self.config["llm_provider"].lower() == "google":
-            self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
-            self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
-        else:
-            raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
+        deep_client = create_llm_client(
+            provider=self.config["llm_provider"],
+            model=self.config["deep_think_llm"],
+            base_url=self.config.get("backend_url"),
+        )
+        quick_client = create_llm_client(
+            provider=self.config["llm_provider"],
+            model=self.config["quick_think_llm"],
+            base_url=self.config.get("backend_url"),
+        )
+        self.deep_thinking_llm = deep_client.get_llm()
+        self.quick_thinking_llm = quick_client.get_llm()
 
         # Initialize memories
         self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
diff --git a/tradingagents/llm_clients/TODO.md b/tradingagents/llm_clients/TODO.md
new file mode 100644
index 00000000..d5b5ac9c
--- /dev/null
+++ b/tradingagents/llm_clients/TODO.md
@@ -0,0 +1,24 @@
+# LLM Clients - Consistency Improvements
+
+## Issues to Fix
+
+### 1. `validate_model()` is never called
+- Add validation call in `get_llm()` with warning (not error) for unknown models
+
+### 2. Inconsistent parameter handling
+| Client | API Key Param | Special Params |
+|--------|---------------|----------------|
+| OpenAI | `api_key` | `reasoning_effort` |
+| Anthropic | `api_key` | `thinking_config` → `thinking` |
+| Google | `google_api_key` | `thinking_budget` |
+
+**Fix:** Standardize with unified `api_key` that maps to provider-specific keys
+
+### 3. `base_url` accepted but ignored
+- `AnthropicClient`: accepts `base_url` but never uses it
+- `GoogleClient`: accepts `base_url` but never uses it (correct - Google doesn't support it)
+
+**Fix:** Remove unused `base_url` from clients that don't support it
+
+### 4. Update validators.py with models from CLI
+- Sync `VALID_MODELS` dict with CLI model options after Feature 2 is complete
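For reviewers: the `TradingAgentsGraph.__init__` change above swaps the provider `if/elif` chain for the factory. A minimal usage sketch of the new call path; the `config` literal is illustrative (real values come from `DEFAULT_CONFIG` and the CLI), and the call shape mirrors the diff:

```python
from tradingagents.llm_clients import create_llm_client

# Illustrative config; real values come from DEFAULT_CONFIG / CLI flags.
config = {
    "llm_provider": "openai",
    "deep_think_llm": "o3-mini",
    "quick_think_llm": "gpt-4o-mini",
    "backend_url": "https://api.openai.com/v1",
}

# Same shape as the new TradingAgentsGraph code: build a client,
# then ask it for the configured LangChain chat model.
deep_client = create_llm_client(
    provider=config["llm_provider"],
    model=config["deep_think_llm"],
    base_url=config.get("backend_url"),
)
deep_thinking_llm = deep_client.get_llm()
```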
diff --git a/tradingagents/llm_clients/__init__.py b/tradingagents/llm_clients/__init__.py
new file mode 100644
index 00000000..e528eabe
--- /dev/null
+++ b/tradingagents/llm_clients/__init__.py
@@ -0,0 +1,4 @@
+from .base_client import BaseLLMClient
+from .factory import create_llm_client
+
+__all__ = ["BaseLLMClient", "create_llm_client"]
diff --git a/tradingagents/llm_clients/anthropic_client.py b/tradingagents/llm_clients/anthropic_client.py
new file mode 100644
index 00000000..64cc0f43
--- /dev/null
+++ b/tradingagents/llm_clients/anthropic_client.py
@@ -0,0 +1,33 @@
+from typing import Any, Optional
+
+from langchain_anthropic import ChatAnthropic
+
+from .base_client import BaseLLMClient
+from .validators import validate_model
+
+
+class AnthropicClient(BaseLLMClient):
+    """Client for Anthropic Claude models."""
+
+    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
+        super().__init__(model, base_url, **kwargs)
+
+    def get_llm(self) -> Any:
+        """Return configured ChatAnthropic instance."""
+        llm_kwargs = {
+            "model": self.model,
+            "max_tokens": self.kwargs.get("max_tokens", 4096),
+        }
+
+        for key in ("timeout", "max_retries", "api_key"):
+            if key in self.kwargs:
+                llm_kwargs[key] = self.kwargs[key]
+
+        if "thinking_config" in self.kwargs:
+            llm_kwargs["thinking"] = self.kwargs["thinking_config"]
+
+        return ChatAnthropic(**llm_kwargs)
+
+    def validate_model(self) -> bool:
+        """Validate model for Anthropic."""
+        return validate_model("anthropic", self.model)
diff --git a/tradingagents/llm_clients/base_client.py b/tradingagents/llm_clients/base_client.py
new file mode 100644
index 00000000..43845575
--- /dev/null
+++ b/tradingagents/llm_clients/base_client.py
@@ -0,0 +1,21 @@
+from abc import ABC, abstractmethod
+from typing import Any, Optional
+
+
+class BaseLLMClient(ABC):
+    """Abstract base class for LLM clients."""
+
+    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
+        self.model = model
+        self.base_url = base_url
+        self.kwargs = kwargs
+
+    @abstractmethod
+    def get_llm(self) -> Any:
+        """Return the configured LLM instance."""
+        pass
+
+    @abstractmethod
+    def validate_model(self) -> bool:
+        """Validate that the model is supported by this client."""
+        pass
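`BaseLLMClient` above is the whole extension contract: a new provider only needs `get_llm()` and `validate_model()`. As a hedged illustration (not part of this PR), a hypothetical `GroqClient` could reuse `ChatOpenAI` the way the xAI branch does, since Groq exposes an OpenAI-compatible endpoint:

```python
from typing import Any

from langchain_openai import ChatOpenAI  # already a dependency of this package

from .base_client import BaseLLMClient


class GroqClient(BaseLLMClient):
    """Hypothetical example client, shown only to illustrate the contract."""

    def get_llm(self) -> Any:
        llm_kwargs = {
            "model": self.model,
            # Fall back to Groq's OpenAI-compatible endpoint if none given.
            "base_url": self.base_url or "https://api.groq.com/openai/v1",
        }
        # Same pass-through pattern the real clients use.
        for key in ("timeout", "max_retries", "api_key"):
            if key in self.kwargs:
                llm_kwargs[key] = self.kwargs[key]
        return ChatOpenAI(**llm_kwargs)

    def validate_model(self) -> bool:
        # No curated list, so accept any model (as ollama/openrouter do).
        return True
```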
diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py
new file mode 100644
index 00000000..e10e83da
--- /dev/null
+++ b/tradingagents/llm_clients/factory.py
@@ -0,0 +1,47 @@
+from typing import Optional
+
+from .base_client import BaseLLMClient
+from .openai_client import OpenAIClient
+from .anthropic_client import AnthropicClient
+from .google_client import GoogleClient
+from .vllm_client import VLLMClient
+
+
+def create_llm_client(
+    provider: str,
+    model: str,
+    base_url: Optional[str] = None,
+    **kwargs,
+) -> BaseLLMClient:
+    """Create an LLM client for the specified provider.
+
+    Args:
+        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
+        model: Model name/identifier
+        base_url: Optional base URL for API endpoint
+        **kwargs: Additional provider-specific arguments
+
+    Returns:
+        Configured BaseLLMClient instance
+
+    Raises:
+        ValueError: If provider is not supported
+    """
+    provider_lower = provider.lower()
+
+    if provider_lower in ("openai", "ollama", "openrouter"):
+        return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
+
+    if provider_lower == "xai":
+        return OpenAIClient(model, base_url, provider="xai", **kwargs)
+
+    if provider_lower == "anthropic":
+        return AnthropicClient(model, base_url, **kwargs)
+
+    if provider_lower == "google":
+        return GoogleClient(model, base_url, **kwargs)
+
+    if provider_lower == "vllm":
+        return VLLMClient(model, base_url, **kwargs)
+
+    raise ValueError(f"Unsupported LLM provider: {provider}")
diff --git a/tradingagents/llm_clients/google_client.py b/tradingagents/llm_clients/google_client.py
new file mode 100644
index 00000000..2ebc19e5
--- /dev/null
+++ b/tradingagents/llm_clients/google_client.py
@@ -0,0 +1,34 @@
+from typing import Any, Optional
+
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+from .base_client import BaseLLMClient
+from .validators import validate_model
+
+
+class GoogleClient(BaseLLMClient):
+    """Client for Google Gemini models."""
+
+    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
+        super().__init__(model, base_url, **kwargs)
+
+    def get_llm(self) -> Any:
+        """Return configured ChatGoogleGenerativeAI instance."""
+        llm_kwargs = {"model": self.model}
+
+        for key in ("timeout", "max_retries", "google_api_key"):
+            if key in self.kwargs:
+                llm_kwargs[key] = self.kwargs[key]
+
+        if "thinking_budget" in self.kwargs and self._is_preview_model():
+            llm_kwargs["thinking_budget"] = self.kwargs["thinking_budget"]
+
+        return ChatGoogleGenerativeAI(**llm_kwargs)
+
+    def _is_preview_model(self) -> bool:
+        """Check if this is a preview model that supports thinking budget."""
+        return "preview" in self.model.lower()
+
+    def validate_model(self) -> bool:
+        """Validate model for Google."""
+        return validate_model("google", self.model)
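The PR ships no tests, so here is a pytest-style sketch of the factory dispatch and the preview-only `thinking_budget` gating; the test names are suggestions, and the call to the private `_is_preview_model()` exists purely to show the gate:

```python
import pytest

from tradingagents.llm_clients import create_llm_client
from tradingagents.llm_clients.google_client import GoogleClient


def test_provider_dispatch_is_case_insensitive():
    # The factory lower-cases the provider before matching.
    client = create_llm_client(provider="Google", model="gemini-2.0-flash")
    assert isinstance(client, GoogleClient)


def test_unknown_provider_raises():
    with pytest.raises(ValueError):
        create_llm_client(provider="bedrock", model="anything")


def test_thinking_budget_applies_only_to_preview_models():
    stable = GoogleClient("gemini-2.0-flash", thinking_budget=1024)
    preview = GoogleClient("gemini-2.5-flash-preview-05-20", thinking_budget=1024)
    # get_llm() forwards thinking_budget only when "preview" is in the name.
    assert not stable._is_preview_model()
    assert preview._is_preview_model()
```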
diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py
new file mode 100644
index 00000000..3c838fa9
--- /dev/null
+++ b/tradingagents/llm_clients/openai_client.py
@@ -0,0 +1,64 @@
+import os
+from typing import Any, Optional
+
+from langchain_openai import ChatOpenAI
+
+from .base_client import BaseLLMClient
+from .validators import validate_model
+
+
+class UnifiedChatOpenAI(ChatOpenAI):
+    """ChatOpenAI subclass that strips incompatible params for certain models."""
+
+    def __init__(self, **kwargs):
+        model = kwargs.get("model", "")
+        if self._is_reasoning_model(model):
+            kwargs.pop("temperature", None)
+            kwargs.pop("top_p", None)
+        super().__init__(**kwargs)
+
+    @staticmethod
+    def _is_reasoning_model(model: str) -> bool:
+        """Check if model is a reasoning model that doesn't support temperature."""
+        model_lower = model.lower()
+        return (
+            model_lower.startswith("o1")
+            or model_lower.startswith("o3")
+            or "gpt-5" in model_lower
+        )
+
+
+class OpenAIClient(BaseLLMClient):
+    """Client for OpenAI, Ollama, OpenRouter, and xAI providers."""
+
+    def __init__(
+        self,
+        model: str,
+        base_url: Optional[str] = None,
+        provider: str = "openai",
+        **kwargs,
+    ):
+        super().__init__(model, base_url, **kwargs)
+        self.provider = provider.lower()
+
+    def get_llm(self) -> Any:
+        """Return configured ChatOpenAI instance."""
+        llm_kwargs = {"model": self.model}
+
+        if self.provider == "xai":
+            llm_kwargs["base_url"] = "https://api.x.ai/v1"
+            api_key = os.environ.get("XAI_API_KEY")
+            if api_key:
+                llm_kwargs["api_key"] = api_key
+        elif self.base_url:
+            llm_kwargs["base_url"] = self.base_url
+
+        for key in ("timeout", "max_retries", "reasoning_effort", "api_key"):
+            if key in self.kwargs:
+                llm_kwargs[key] = self.kwargs[key]
+
+        return UnifiedChatOpenAI(**llm_kwargs)
+
+    def validate_model(self) -> bool:
+        """Validate model for the provider."""
+        return validate_model(self.provider, self.model)
diff --git a/tradingagents/llm_clients/validators.py b/tradingagents/llm_clients/validators.py
new file mode 100644
index 00000000..526dc37c
--- /dev/null
+++ b/tradingagents/llm_clients/validators.py
@@ -0,0 +1,69 @@
+from typing import Dict, List
+
+VALID_MODELS: Dict[str, List[str]] = {
+    "openai": [
+        "gpt-4o",
+        "gpt-4o-mini",
+        "gpt-4-turbo",
+        "gpt-4",
+        "gpt-3.5-turbo",
+        "o1",
+        "o1-mini",
+        "o1-preview",
+        "o3-mini",
+        "gpt-5-nano",
+        "gpt-5-mini",
+        "gpt-5",
+    ],
+    "anthropic": [
+        "claude-3-5-sonnet-20241022",
+        "claude-3-5-haiku-20241022",
+        "claude-3-opus-20240229",
+        "claude-3-sonnet-20240229",
+        "claude-3-haiku-20240307",
+        "claude-sonnet-4-20250514",
+        "claude-haiku-4-5-20251001",
+        "claude-opus-4-5-20251101",
+    ],
+    "google": [
+        "gemini-1.5-pro",
+        "gemini-1.5-flash",
+        "gemini-2.0-flash",
+        "gemini-2.0-flash-lite",
+        "gemini-2.5-pro-preview-05-06",
+        "gemini-2.5-flash-preview-05-20",
+        "gemini-3-pro-preview",
+        "gemini-3-flash-preview",
+    ],
+    "xai": [
+        "grok-beta",
+        "grok-2",
+        "grok-2-mini",
+        "grok-3",
+        "grok-3-mini",
+    ],
+    "ollama": [],
+    "openrouter": [],
+    "vllm": [],
+}
+
+
+def validate_model(provider: str, model: str) -> bool:
+    """Validate that a model is supported by the provider.
+
+    For ollama, openrouter, and vllm, any model is accepted.
+    For other providers, checks against VALID_MODELS.
+    """
+    provider_lower = provider.lower()
+
+    if provider_lower in ("ollama", "openrouter", "vllm"):
+        return True
+
+    if provider_lower not in VALID_MODELS:
+        return False
+
+    valid = VALID_MODELS[provider_lower]
+    if not valid:
+        return True
+
+    return model in valid
diff --git a/tradingagents/llm_clients/vllm_client.py b/tradingagents/llm_clients/vllm_client.py
new file mode 100644
index 00000000..a1ebfebf
--- /dev/null
+++ b/tradingagents/llm_clients/vllm_client.py
@@ -0,0 +1,18 @@
+from typing import Any, Optional
+
+from .base_client import BaseLLMClient
+
+
+class VLLMClient(BaseLLMClient):
+    """Client for vLLM (placeholder for future implementation)."""
+
+    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
+        super().__init__(model, base_url, **kwargs)
+
+    def get_llm(self) -> Any:
+        """Return configured vLLM instance."""
+        raise NotImplementedError("vLLM client not yet implemented")
+
+    def validate_model(self) -> bool:
+        """Validate model for vLLM."""
+        return True
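Two behaviors worth a second look in review, shown as a hedged sketch: `UnifiedChatOpenAI` silently discards sampling kwargs for reasoning models, and `validate_model` is permissive for any provider without a curated list. The `api_key="sk-test"` value is a fake placeholder so the snippet needs no real key; nothing here calls the API.

```python
from tradingagents.llm_clients.openai_client import UnifiedChatOpenAI
from tradingagents.llm_clients.validators import validate_model

# Construction-only demo: the explicit temperature kwarg is popped before
# ChatOpenAI.__init__ runs, because "o3-mini" matches _is_reasoning_model.
llm = UnifiedChatOpenAI(model="o3-mini", temperature=0.7, api_key="sk-test")

assert validate_model("openai", "gpt-4o")           # in the curated list
assert not validate_model("openai", "gpt-4o-2024")  # not in the list
assert validate_model("openrouter", "any/model")    # passthrough provider
assert not validate_model("bedrock", "claude")      # unknown provider
```

Note the drop is silent: a caller that sets `temperature` for a reasoning model gets no feedback, which pairs naturally with the warning path proposed in the TODO.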