318 lines
10 KiB
Python
318 lines
10 KiB
Python
"""
|
|
LLM Factory for TradingAgents.
|
|
|
|
Provides unified interface for creating LLM instances from different providers
|
|
(OpenAI, Anthropic, Google, etc.) with consistent configuration.
|
|
"""
|
|
|
|
import os
|
|
from typing import Optional, Dict, Any
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LLMFactory:
    """Factory for creating LLM instances from different providers.

    Every method is a ``@staticmethod``; the class is used purely as a
    namespace and is never instantiated by the code in this module.
    """

    SUPPORTED_PROVIDERS = ["openai", "anthropic", "google"]

    # Environment variable that must hold the API key for each provider.
    # Single source of truth for both the builders and the validator.
    _API_KEY_ENV_VARS = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "google": "GOOGLE_API_KEY",
    }

    # Importable module name of the langchain integration for each provider.
    _PROVIDER_MODULES = {
        "openai": "langchain_openai",
        "anthropic": "langchain_anthropic",
        "google": "langchain_google_genai",
    }

    @staticmethod
    def create_llm(
        provider: str,
        model: str,
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        backend_url: Optional[str] = None,
        **kwargs
    ) -> Any:
        """
        Create an LLM instance for the specified provider.

        Args:
            provider: LLM provider ("openai", "anthropic", "google");
                matched case-insensitively.
            model: Model name (e.g., "gpt-4o", "claude-3-5-sonnet-20241022")
            temperature: Sampling temperature (0.0 to 2.0)
            max_tokens: Maximum tokens to generate
            backend_url: Custom API endpoint (for OpenAI-compatible APIs).
                Only the OpenAI builder receives it; for other providers it
                is ignored and a warning is logged.
            **kwargs: Additional provider-specific arguments

        Returns:
            LLM instance from the appropriate langchain provider

        Raises:
            ValueError: If provider is not supported or API key is missing
            ImportError: If required package is not installed

        Examples:
            >>> # OpenAI
            >>> llm = LLMFactory.create_llm("openai", "gpt-4o")

            >>> # Anthropic
            >>> llm = LLMFactory.create_llm("anthropic", "claude-3-5-sonnet-20241022")

            >>> # Google
            >>> llm = LLMFactory.create_llm("google", "gemini-pro")
        """
        provider = provider.lower()

        if provider not in LLMFactory.SUPPORTED_PROVIDERS:
            raise ValueError(
                f"Unsupported LLM provider: {provider}. "
                f"Supported providers: {', '.join(LLMFactory.SUPPORTED_PROVIDERS)}"
            )

        if provider == "openai":
            return LLMFactory._create_openai_llm(
                model, temperature, max_tokens, backend_url, **kwargs
            )

        # Only the OpenAI builder accepts a custom endpoint; make the fact
        # that it is being dropped visible instead of ignoring it silently.
        if backend_url:
            logger.warning(
                "backend_url is only used by the openai provider; ignoring it for %s",
                provider,
            )

        if provider == "anthropic":
            return LLMFactory._create_anthropic_llm(
                model, temperature, max_tokens, **kwargs
            )
        elif provider == "google":
            return LLMFactory._create_google_llm(
                model, temperature, max_tokens, **kwargs
            )

    @staticmethod
    def _require_api_key(provider: str) -> str:
        """Return the API key for ``provider`` from the environment.

        Raises:
            ValueError: If the provider's environment variable is unset or empty.
        """
        env_var = LLMFactory._API_KEY_ENV_VARS[provider]
        api_key = os.getenv(env_var)
        if not api_key:
            raise ValueError(
                f"{env_var} environment variable is required. "
                "Set it in your .env file or environment."
            )
        return api_key

    @staticmethod
    def _create_openai_llm(
        model: str,
        temperature: float,
        max_tokens: Optional[int],
        backend_url: Optional[str],
        **kwargs
    ) -> Any:
        """Create OpenAI LLM instance.

        The key is only checked for presence here; it is not added to the
        configuration passed to ``ChatOpenAI``.

        Raises:
            ImportError: If ``langchain_openai`` cannot be imported.
            ValueError: If ``OPENAI_API_KEY`` is not set.
        """
        try:
            from langchain_openai import ChatOpenAI
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "langchain-openai is required for OpenAI models. "
                "Install with: pip install langchain-openai"
            ) from err

        # Check API key
        LLMFactory._require_api_key("openai")

        # Build configuration; caller kwargs may override model/temperature.
        config: Dict[str, Any] = {
            "model": model,
            "temperature": temperature,
            **kwargs
        }

        # Explicit arguments win over anything with the same name in kwargs.
        if max_tokens:
            config["max_tokens"] = max_tokens

        if backend_url:
            config["base_url"] = backend_url

        logger.info("Creating OpenAI LLM: %s (temp=%s)", model, temperature)
        return ChatOpenAI(**config)

    @staticmethod
    def _create_anthropic_llm(
        model: str,
        temperature: float,
        max_tokens: Optional[int],
        **kwargs
    ) -> Any:
        """Create Anthropic (Claude) LLM instance.

        Raises:
            ImportError: If ``langchain_anthropic`` cannot be imported.
            ValueError: If ``ANTHROPIC_API_KEY`` is not set.
        """
        try:
            from langchain_anthropic import ChatAnthropic
        except ImportError as err:
            raise ImportError(
                "langchain-anthropic is required for Anthropic models. "
                "Install with: pip install langchain-anthropic"
            ) from err

        # Check API key
        api_key = LLMFactory._require_api_key("anthropic")

        # Build configuration; caller kwargs may override the defaults above them.
        config: Dict[str, Any] = {
            "model": model,
            "temperature": temperature,
            "anthropic_api_key": api_key,
            **kwargs
        }

        if max_tokens:
            config["max_tokens"] = max_tokens
        else:
            # Claude requires max_tokens, use reasonable default
            config["max_tokens"] = 4096

        logger.info("Creating Anthropic LLM: %s (temp=%s)", model, temperature)
        return ChatAnthropic(**config)

    @staticmethod
    def _create_google_llm(
        model: str,
        temperature: float,
        max_tokens: Optional[int],
        **kwargs
    ) -> Any:
        """Create Google (Gemini) LLM instance.

        ``max_tokens`` is forwarded under the name ``max_output_tokens``.

        Raises:
            ImportError: If ``langchain_google_genai`` cannot be imported.
            ValueError: If ``GOOGLE_API_KEY`` is not set.
        """
        try:
            from langchain_google_genai import ChatGoogleGenerativeAI
        except ImportError as err:
            raise ImportError(
                "langchain-google-genai is required for Google models. "
                "Install with: pip install langchain-google-genai"
            ) from err

        # Check API key
        api_key = LLMFactory._require_api_key("google")

        # Build configuration; caller kwargs may override the defaults above them.
        config: Dict[str, Any] = {
            "model": model,
            "temperature": temperature,
            "google_api_key": api_key,
            **kwargs
        }

        if max_tokens:
            config["max_output_tokens"] = max_tokens

        logger.info("Creating Google LLM: %s (temp=%s)", model, temperature)
        return ChatGoogleGenerativeAI(**config)

    @staticmethod
    def get_recommended_models(provider: str) -> Dict[str, str]:
        """
        Get recommended model names for a provider.

        Args:
            provider: LLM provider name (matched case-insensitively)

        Returns:
            Dictionary with model recommendations for different use cases,
            keyed by "deep_thinking", "fast_thinking", "budget" and "legacy".
            A fresh dictionary is built on every call.

        Raises:
            ValueError: If the provider has no recommendations.

        Examples:
            >>> models = LLMFactory.get_recommended_models("anthropic")
            >>> print(models["deep_thinking"])  # claude-3-5-sonnet-20241022
        """
        recommendations = {
            "openai": {
                "deep_thinking": "o1-preview",  # Best reasoning
                "fast_thinking": "gpt-4o",  # Fast, capable
                "budget": "gpt-4o-mini",  # Cost-effective
                "legacy": "gpt-4-turbo",  # Previous generation
            },
            "anthropic": {
                "deep_thinking": "claude-3-5-sonnet-20241022",  # Best overall
                "fast_thinking": "claude-3-5-sonnet-20241022",  # Same (very fast)
                "budget": "claude-3-5-haiku-20241022",  # Cost-effective
                "legacy": "claude-3-opus-20240229",  # Previous best
            },
            "google": {
                "deep_thinking": "gemini-1.5-pro",  # Best reasoning
                "fast_thinking": "gemini-1.5-flash",  # Fastest
                "budget": "gemini-1.5-flash",  # Same as fast
                "legacy": "gemini-pro",  # Previous generation
            },
        }

        provider = provider.lower()
        if provider not in recommendations:
            raise ValueError(f"Unknown provider: {provider}")

        return recommendations[provider]

    @staticmethod
    def validate_provider_setup(provider: str) -> Dict[str, Any]:
        """
        Validate that a provider is properly configured.

        This never raises for a missing package or key; problems are
        reported through the returned dictionary instead.

        Args:
            provider: Provider to validate (matched case-insensitively)

        Returns:
            Dictionary with validation results:
                "provider": the lower-cased provider name
                "valid": True only if the package imports and the key is set
                "api_key_set": whether the provider's env var is non-empty
                "package_installed": whether the langchain package imports
                "errors": list of human-readable problem descriptions

        Examples:
            >>> result = LLMFactory.validate_provider_setup("anthropic")
            >>> if result["valid"]:
            ...     print("Anthropic is configured!")
        """
        # Local import keeps this block self-contained.
        import importlib

        provider = provider.lower()

        result: Dict[str, Any] = {
            "provider": provider,
            "valid": False,
            "api_key_set": False,
            "package_installed": False,
            "errors": []
        }

        # An unknown provider used to produce valid=False with no explanation;
        # report it explicitly so callers can tell what went wrong.
        if provider not in LLMFactory.SUPPORTED_PROVIDERS:
            result["errors"].append(
                f"Unsupported LLM provider: {provider}. "
                f"Supported providers: {', '.join(LLMFactory.SUPPORTED_PROVIDERS)}"
            )
            return result

        # Check package installation
        try:
            importlib.import_module(LLMFactory._PROVIDER_MODULES[provider])
        except ImportError as e:
            result["errors"].append(f"Package not installed: {e}")
        else:
            result["package_installed"] = True

        # Check API key
        env_var = LLMFactory._API_KEY_ENV_VARS[provider]
        if os.getenv(env_var):
            result["api_key_set"] = True
        else:
            result["errors"].append(f"{env_var} not set in environment")

        # Overall validation
        result["valid"] = result["package_installed"] and result["api_key_set"]

        return result
|
|
|
|
|
|
# Convenience function
|
|
def create_llm(provider: str = "openai", model: Optional[str] = None, **kwargs) -> Any:
    """
    Convenience wrapper for LLMFactory.create_llm().

    If model is not specified, uses the recommended "deep_thinking" model
    for the provider.

    Args:
        provider: LLM provider name; defaults to "openai".
        model: Model name, or None to use the provider's recommended model.
        **kwargs: Forwarded unchanged to LLMFactory.create_llm().

    Returns:
        Whatever LLMFactory.create_llm() returns for the provider and model.

    Raises:
        ValueError: If model is None and the provider has no recommended
            models; LLMFactory.create_llm() may raise further errors.

    Examples:
        >>> llm = create_llm("anthropic")  # Uses Claude 3.5 Sonnet
        >>> llm = create_llm("openai", "gpt-4o")
    """
    if model is None:
        # Use recommended deep thinking model
        recommended = LLMFactory.get_recommended_models(provider)
        model = recommended["deep_thinking"]
        logger.info("No model specified, using recommended: %s", model)

    return LLMFactory.create_llm(provider, model, **kwargs)
|