feat: add multi-provider LLM support with factory pattern
- Add tradingagents/llm_clients/ with unified factory pattern - Support OpenAI, Anthropic, Google, xAI, OpenRouter, Ollama, vLLM - Replace direct LLM imports in trading_graph.py with create_llm_client() - Handle provider-specific params (reasoning_effort, thinking_config)
This commit is contained in:
parent
13b826a31d
commit
79051580b8
|
|
@ -6,12 +6,10 @@ import json
|
|||
from datetime import date
|
||||
from typing import Dict, Any, Tuple, List, Optional
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from tradingagents.llm_clients import create_llm_client
|
||||
|
||||
from tradingagents.agents import *
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
|
|
@ -72,17 +70,18 @@ class TradingAgentsGraph:
|
|||
)
|
||||
|
||||
# Initialize LLMs
|
||||
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
|
||||
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
||||
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
||||
elif self.config["llm_provider"].lower() == "anthropic":
|
||||
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
||||
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
||||
elif self.config["llm_provider"].lower() == "google":
|
||||
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
|
||||
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
|
||||
deep_client = create_llm_client(
|
||||
provider=self.config["llm_provider"],
|
||||
model=self.config["deep_think_llm"],
|
||||
base_url=self.config.get("backend_url"),
|
||||
)
|
||||
quick_client = create_llm_client(
|
||||
provider=self.config["llm_provider"],
|
||||
model=self.config["quick_think_llm"],
|
||||
base_url=self.config.get("backend_url"),
|
||||
)
|
||||
self.deep_thinking_llm = deep_client.get_llm()
|
||||
self.quick_thinking_llm = quick_client.get_llm()
|
||||
|
||||
# Initialize memories
|
||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,24 @@
|
|||
# LLM Clients - Consistency Improvements
|
||||
|
||||
## Issues to Fix
|
||||
|
||||
### 1. `validate_model()` is never called
|
||||
- Add validation call in `get_llm()` with warning (not error) for unknown models
|
||||
|
||||
### 2. Inconsistent parameter handling
|
||||
| Client | API Key Param | Special Params |
|
||||
|--------|---------------|----------------|
|
||||
| OpenAI | `api_key` | `reasoning_effort` |
|
||||
| Anthropic | `api_key` | `thinking_config` → `thinking` |
|
||||
| Google | `google_api_key` | `thinking_budget` |
|
||||
|
||||
**Fix:** Standardize with unified `api_key` that maps to provider-specific keys
|
||||
|
||||
### 3. `base_url` accepted but ignored
|
||||
- `AnthropicClient`: accepts `base_url` but never uses it
|
||||
- `GoogleClient`: accepts `base_url` but never uses it (correct - Google doesn't support it)
|
||||
|
||||
**Fix:** Remove unused `base_url` from clients that don't support it
|
||||
|
||||
### 4. Update validators.py with models from CLI
|
||||
- Sync `VALID_MODELS` dict with CLI model options after Feature 2 is complete
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
from .base_client import BaseLLMClient
|
||||
from .factory import create_llm_client
|
||||
|
||||
__all__ = ["BaseLLMClient", "create_llm_client"]
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
from .validators import validate_model
|
||||
|
||||
|
||||
class AnthropicClient(BaseLLMClient):
|
||||
"""Client for Anthropic Claude models."""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured ChatAnthropic instance."""
|
||||
llm_kwargs = {
|
||||
"model": self.model,
|
||||
"max_tokens": self.kwargs.get("max_tokens", 4096),
|
||||
}
|
||||
|
||||
for key in ("timeout", "max_retries", "api_key"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
if "thinking_config" in self.kwargs:
|
||||
llm_kwargs["thinking"] = self.kwargs["thinking_config"]
|
||||
|
||||
return ChatAnthropic(**llm_kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
"""Validate model for Anthropic."""
|
||||
return validate_model("anthropic", self.model)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class BaseLLMClient(ABC):
|
||||
"""Abstract base class for LLM clients."""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
self.kwargs = kwargs
|
||||
|
||||
@abstractmethod
|
||||
def get_llm(self) -> Any:
|
||||
"""Return the configured LLM instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_model(self) -> bool:
|
||||
"""Validate that the model is supported by this client."""
|
||||
pass
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
from typing import Optional
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
from .google_client import GoogleClient
|
||||
from .vllm_client import VLLMClient
|
||||
|
||||
|
||||
def create_llm_client(
|
||||
provider: str,
|
||||
model: str,
|
||||
base_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> BaseLLMClient:
|
||||
"""Create an LLM client for the specified provider.
|
||||
|
||||
Args:
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
|
||||
model: Model name/identifier
|
||||
base_url: Optional base URL for API endpoint
|
||||
**kwargs: Additional provider-specific arguments
|
||||
|
||||
Returns:
|
||||
Configured BaseLLMClient instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not supported
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower in ("openai", "ollama", "openrouter"):
|
||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||
|
||||
if provider_lower == "xai":
|
||||
return OpenAIClient(model, base_url, provider="xai", **kwargs)
|
||||
|
||||
if provider_lower == "anthropic":
|
||||
return AnthropicClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "google":
|
||||
return GoogleClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "vllm":
|
||||
return VLLMClient(model, base_url, **kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
from .validators import validate_model
|
||||
|
||||
|
||||
class GoogleClient(BaseLLMClient):
|
||||
"""Client for Google Gemini models."""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured ChatGoogleGenerativeAI instance."""
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
for key in ("timeout", "max_retries", "google_api_key"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
if "thinking_budget" in self.kwargs and self._is_preview_model():
|
||||
llm_kwargs["thinking_budget"] = self.kwargs["thinking_budget"]
|
||||
|
||||
return ChatGoogleGenerativeAI(**llm_kwargs)
|
||||
|
||||
def _is_preview_model(self) -> bool:
|
||||
"""Check if this is a preview model that supports thinking budget."""
|
||||
return "preview" in self.model.lower()
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
"""Validate model for Google."""
|
||||
return validate_model("google", self.model)
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
from .validators import validate_model
|
||||
|
||||
|
||||
class UnifiedChatOpenAI(ChatOpenAI):
|
||||
"""ChatOpenAI subclass that strips incompatible params for certain models."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
model = kwargs.get("model", "")
|
||||
if self._is_reasoning_model(model):
|
||||
kwargs.pop("temperature", None)
|
||||
kwargs.pop("top_p", None)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _is_reasoning_model(model: str) -> bool:
|
||||
"""Check if model is a reasoning model that doesn't support temperature."""
|
||||
model_lower = model.lower()
|
||||
return (
|
||||
model_lower.startswith("o1")
|
||||
or model_lower.startswith("o3")
|
||||
or "gpt-5" in model_lower
|
||||
)
|
||||
|
||||
|
||||
class OpenAIClient(BaseLLMClient):
|
||||
"""Client for OpenAI, Ollama, OpenRouter, and xAI providers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
base_url: Optional[str] = None,
|
||||
provider: str = "openai",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
self.provider = provider.lower()
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured ChatOpenAI instance."""
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
if self.provider == "xai":
|
||||
llm_kwargs["base_url"] = "https://api.x.ai/v1"
|
||||
api_key = os.environ.get("XAI_API_KEY")
|
||||
if api_key:
|
||||
llm_kwargs["api_key"] = api_key
|
||||
elif self.base_url:
|
||||
llm_kwargs["base_url"] = self.base_url
|
||||
|
||||
for key in ("timeout", "max_retries", "reasoning_effort", "api_key"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
return UnifiedChatOpenAI(**llm_kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
"""Validate model for the provider."""
|
||||
return validate_model(self.provider, self.model)
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
from typing import Dict, List
|
||||
|
||||
VALID_MODELS: Dict[str, List[str]] = {
|
||||
"openai": [
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4",
|
||||
"gpt-3.5-turbo",
|
||||
"o1",
|
||||
"o1-mini",
|
||||
"o1-preview",
|
||||
"o3-mini",
|
||||
"gpt-5-nano",
|
||||
"gpt-5-mini",
|
||||
"gpt-5",
|
||||
],
|
||||
"anthropic": [
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-opus-4-5-20251101",
|
||||
],
|
||||
"google": [
|
||||
"gemini-1.5-pro",
|
||||
"gemini-1.5-flash",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.5-flash-preview-05-20",
|
||||
"gemini-3-pro-preview",
|
||||
"gemini-3-flash-preview",
|
||||
],
|
||||
"xai": [
|
||||
"grok-beta",
|
||||
"grok-2",
|
||||
"grok-2-mini",
|
||||
"grok-3",
|
||||
"grok-3-mini",
|
||||
],
|
||||
"ollama": [],
|
||||
"openrouter": [],
|
||||
"vllm": [],
|
||||
}
|
||||
|
||||
|
||||
def validate_model(provider: str, model: str) -> bool:
|
||||
"""Validate that a model is supported by the provider.
|
||||
|
||||
For ollama, openrouter, and vllm, any model is accepted.
|
||||
For other providers, checks against VALID_MODELS.
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower in ("ollama", "openrouter", "vllm"):
|
||||
return True
|
||||
|
||||
if provider_lower not in VALID_MODELS:
|
||||
return False
|
||||
|
||||
valid = VALID_MODELS[provider_lower]
|
||||
if not valid:
|
||||
return True
|
||||
|
||||
return model in valid
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
|
||||
|
||||
class VLLMClient(BaseLLMClient):
|
||||
"""Client for vLLM (placeholder for future implementation)."""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured vLLM instance."""
|
||||
raise NotImplementedError("vLLM client not yet implemented")
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
"""Validate model for vLLM."""
|
||||
return True
|
||||
Loading…
Reference in New Issue