feat: add multi-provider LLM support with factory pattern

- Add tradingagents/llm_clients/ with unified factory pattern
- Support OpenAI, Anthropic, Google, xAI, OpenRouter, Ollama, vLLM
- Replace direct LLM imports in trading_graph.py with create_llm_client()
- Handle provider-specific params (reasoning_effort, thinking_config)
This commit is contained in:
Yijia Xiao 2026-01-20 06:52:18 +00:00
parent 13b826a31d
commit 79051580b8
No known key found for this signature in database
10 changed files with 328 additions and 15 deletions

View File

@ -6,12 +6,10 @@ import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.prebuilt import ToolNode
from tradingagents.llm_clients import create_llm_client
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
@ -72,17 +70,18 @@ class TradingAgentsGraph:
)
# Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "anthropic":
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "google":
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
else:
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
deep_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["deep_think_llm"],
base_url=self.config.get("backend_url"),
)
quick_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["quick_think_llm"],
base_url=self.config.get("backend_url"),
)
self.deep_thinking_llm = deep_client.get_llm()
self.quick_thinking_llm = quick_client.get_llm()
# Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)

View File

@ -0,0 +1,24 @@
# LLM Clients - Consistency Improvements
## Issues to Fix
### 1. `validate_model()` is never called
- Add validation call in `get_llm()` with warning (not error) for unknown models
### 2. Inconsistent parameter handling
| Client | API Key Param | Special Params |
|--------|---------------|----------------|
| OpenAI | `api_key` | `reasoning_effort` |
| Anthropic | `api_key` | `thinking_config` → `thinking` |
| Google | `google_api_key` | `thinking_budget` |
**Fix:** Standardize with unified `api_key` that maps to provider-specific keys
### 3. `base_url` accepted but ignored
- `AnthropicClient`: accepts `base_url` but never uses it
- `GoogleClient`: accepts `base_url` but never uses it (correct - Google doesn't support it)
**Fix:** Remove unused `base_url` from clients that don't support it
### 4. Update validators.py with models from CLI
- Sync `VALID_MODELS` dict with CLI model options after Feature 2 is complete

View File

@ -0,0 +1,4 @@
"""Unified LLM client package.

Re-exports the abstract client interface and the provider factory so callers
can write ``from tradingagents.llm_clients import create_llm_client`` without
knowing about the per-provider modules.
"""
from .base_client import BaseLLMClient
from .factory import create_llm_client
__all__ = ["BaseLLMClient", "create_llm_client"]

View File

@ -0,0 +1,33 @@
from typing import Any, Optional
from langchain_anthropic import ChatAnthropic
from .base_client import BaseLLMClient
from .validators import validate_model
class AnthropicClient(BaseLLMClient):
    """Client for Anthropic Claude models.

    Wraps ``ChatAnthropic`` construction, forwarding the optional custom
    endpoint plus a small whitelist of passthrough kwargs.
    """
    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
        super().__init__(model, base_url, **kwargs)
    def get_llm(self) -> Any:
        """Return configured ChatAnthropic instance.

        Returns:
            A ``ChatAnthropic`` built from the stored model, endpoint, and
            passthrough kwargs (``timeout``, ``max_retries``, ``api_key``).
        """
        llm_kwargs = {
            "model": self.model,
            # Anthropic requires max_tokens; default matches prior behavior.
            "max_tokens": self.kwargs.get("max_tokens", 4096),
        }
        # Fix: forward base_url. It was accepted but silently ignored, a
        # regression from the pre-factory code which passed it to ChatAnthropic.
        if self.base_url:
            llm_kwargs["base_url"] = self.base_url
        for key in ("timeout", "max_retries", "api_key"):
            if key in self.kwargs:
                llm_kwargs[key] = self.kwargs[key]
        # Map the project's unified param name onto Anthropic's `thinking`.
        if "thinking_config" in self.kwargs:
            llm_kwargs["thinking"] = self.kwargs["thinking_config"]
        return ChatAnthropic(**llm_kwargs)
    def validate_model(self) -> bool:
        """Validate model for Anthropic against the shared VALID_MODELS table."""
        return validate_model("anthropic", self.model)

View File

@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
class BaseLLMClient(ABC):
    """Common interface shared by every provider-specific LLM client.

    Concrete subclasses implement ``get_llm`` (build the chat model) and
    ``validate_model`` (sanity-check the configured model name).
    """

    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
        """Store the model name, optional endpoint, and extra provider args."""
        self.kwargs = kwargs
        self.base_url = base_url
        self.model = model

    @abstractmethod
    def get_llm(self) -> Any:
        """Build and return the provider's configured chat model object."""

    @abstractmethod
    def validate_model(self) -> bool:
        """Report whether ``self.model`` is acceptable for this provider."""

View File

@ -0,0 +1,47 @@
from typing import Optional
from .base_client import BaseLLMClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
from .google_client import GoogleClient
from .vllm_client import VLLMClient
def create_llm_client(
    provider: str,
    model: str,
    base_url: Optional[str] = None,
    **kwargs,
) -> BaseLLMClient:
    """Instantiate the LLM client matching ``provider``.

    Args:
        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
        model: Model name/identifier
        base_url: Optional base URL for API endpoint
        **kwargs: Additional provider-specific arguments

    Returns:
        Configured BaseLLMClient instance

    Raises:
        ValueError: If provider is not supported
    """
    name = provider.lower()
    # Every OpenAI-wire-compatible provider (incl. xAI) shares one client;
    # the provider tag lets OpenAIClient apply per-provider defaults.
    if name in ("openai", "ollama", "openrouter", "xai"):
        return OpenAIClient(model, base_url, provider=name, **kwargs)
    if name == "anthropic":
        return AnthropicClient(model, base_url, **kwargs)
    if name == "google":
        return GoogleClient(model, base_url, **kwargs)
    if name == "vllm":
        return VLLMClient(model, base_url, **kwargs)
    raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@ -0,0 +1,34 @@
from typing import Any, Optional
from langchain_google_genai import ChatGoogleGenerativeAI
from .base_client import BaseLLMClient
from .validators import validate_model
class GoogleClient(BaseLLMClient):
    """Client for Google Gemini models.

    ``base_url`` is accepted for interface parity with the other clients but
    is not forwarded to the Gemini constructor.
    """

    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
        super().__init__(model, base_url, **kwargs)

    def get_llm(self) -> Any:
        """Return configured ChatGoogleGenerativeAI instance."""
        params = {"model": self.model}
        passthrough = ("timeout", "max_retries", "google_api_key")
        params.update(
            (name, self.kwargs[name]) for name in passthrough if name in self.kwargs
        )
        # Thinking budget is only forwarded for preview builds of Gemini.
        if "thinking_budget" in self.kwargs and self._is_preview_model():
            params["thinking_budget"] = self.kwargs["thinking_budget"]
        return ChatGoogleGenerativeAI(**params)

    def _is_preview_model(self) -> bool:
        """True when the model id marks a preview release (these take thinking_budget)."""
        return "preview" in self.model.lower()

    def validate_model(self) -> bool:
        """Check the model name against the known Google list."""
        return validate_model("google", self.model)

View File

@ -0,0 +1,64 @@
import os
from typing import Any, Optional
from langchain_openai import ChatOpenAI
from .base_client import BaseLLMClient
from .validators import validate_model
class UnifiedChatOpenAI(ChatOpenAI):
    """ChatOpenAI variant that drops sampling params reasoning models reject."""

    def __init__(self, **kwargs):
        # o1*/o3*/gpt-5 endpoints do not accept temperature/top_p,
        # so strip them before delegating to ChatOpenAI.
        if self._is_reasoning_model(kwargs.get("model", "")):
            for unsupported in ("temperature", "top_p"):
                kwargs.pop(unsupported, None)
        super().__init__(**kwargs)

    @staticmethod
    def _is_reasoning_model(model: str) -> bool:
        """True for OpenAI reasoning families (o1*, o3*, anything gpt-5)."""
        name = model.lower()
        return name.startswith(("o1", "o3")) or "gpt-5" in name
class OpenAIClient(BaseLLMClient):
    """Client for OpenAI, Ollama, OpenRouter, and xAI providers.

    All four speak the OpenAI wire protocol; the ``provider`` tag selects
    provider-specific endpoint/credential handling.
    """

    def __init__(
        self,
        model: str,
        base_url: Optional[str] = None,
        provider: str = "openai",
        **kwargs,
    ):
        """Store model/endpoint and normalize the provider tag to lowercase."""
        super().__init__(model, base_url, **kwargs)
        self.provider = provider.lower()

    def get_llm(self) -> Any:
        """Return configured ChatOpenAI instance."""
        params = {"model": self.model}
        if self.provider == "xai":
            # xAI uses a fixed OpenAI-compatible endpoint.
            params["base_url"] = "https://api.x.ai/v1"
            env_key = os.environ.get("XAI_API_KEY")
            if env_key:
                params["api_key"] = env_key
        elif self.base_url:
            params["base_url"] = self.base_url
        # Explicit kwargs win over environment-derived values (e.g. api_key).
        params.update(
            {
                name: self.kwargs[name]
                for name in ("timeout", "max_retries", "reasoning_effort", "api_key")
                if name in self.kwargs
            }
        )
        return UnifiedChatOpenAI(**params)

    def validate_model(self) -> bool:
        """Check the model against the provider's known-model list."""
        return validate_model(self.provider, self.model)

View File

@ -0,0 +1,69 @@
from typing import Dict, List

# Known model identifiers per provider. Providers mapped to an empty list
# (and the open-ended ones handled in validate_model) accept any model.
VALID_MODELS: Dict[str, List[str]] = {
    "openai": [
        "gpt-4o",
        "gpt-4o-mini",
        "gpt-4-turbo",
        "gpt-4",
        "gpt-3.5-turbo",
        "o1",
        "o1-mini",
        "o1-preview",
        "o3-mini",
        "gpt-5-nano",
        "gpt-5-mini",
        "gpt-5",
    ],
    "anthropic": [
        "claude-3-5-sonnet-20241022",
        "claude-3-5-haiku-20241022",
        "claude-3-opus-20240229",
        "claude-3-sonnet-20240229",
        "claude-3-haiku-20240307",
        "claude-sonnet-4-20250514",
        "claude-haiku-4-5-20251001",
        "claude-opus-4-5-20251101",
    ],
    "google": [
        "gemini-1.5-pro",
        "gemini-1.5-flash",
        "gemini-2.0-flash",
        "gemini-2.0-flash-lite",
        "gemini-2.5-pro-preview-05-06",
        "gemini-2.5-flash-preview-05-20",
        "gemini-3-pro-preview",
        "gemini-3-flash-preview",
    ],
    "xai": [
        "grok-beta",
        "grok-2",
        "grok-2-mini",
        "grok-3",
        "grok-3-mini",
    ],
    "ollama": [],
    "openrouter": [],
    "vllm": [],
}


def validate_model(provider: str, model: str) -> bool:
    """Return True when ``model`` is acceptable for ``provider``.

    ollama, openrouter, and vllm accept any model name; other providers are
    checked against VALID_MODELS (an empty list also means "accept any").
    Unknown providers are rejected.
    """
    key = provider.lower()
    if key in ("ollama", "openrouter", "vllm"):
        return True
    try:
        allowed = VALID_MODELS[key]
    except KeyError:
        return False
    return not allowed or model in allowed

View File

@ -0,0 +1,18 @@
from typing import Any, Optional
from .base_client import BaseLLMClient
class VLLMClient(BaseLLMClient):
    """Placeholder client for self-hosted vLLM endpoints.

    Construction is inherited unchanged from ``BaseLLMClient``;
    ``get_llm`` is deliberately unimplemented until the integration lands.
    """

    def get_llm(self) -> Any:
        """Not yet available — always raises NotImplementedError."""
        raise NotImplementedError("vLLM client not yet implemented")

    def validate_model(self) -> bool:
        """Accept any model name; vLLM deployments serve arbitrary models."""
        return True