TradingAgents/tradingagents/llm_factory.py

444 lines
16 KiB
Python

"""
LLM Factory for TradingAgents.
Provides unified interface for creating LLM instances from different providers
(OpenAI, Anthropic, Google, etc.) with consistent configuration.
"""
import os
from typing import Optional, Dict, Any, Union
import logging
logger = logging.getLogger(__name__)
# Type definitions for LLM instances
# Define LLMType as Union of supported LLM providers
try:
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
except ImportError:
# Fallback imports not available during type checking
ChatOpenAI = Any # type: ignore
ChatAnthropic = Any # type: ignore
ChatGoogleGenerativeAI = Any # type: ignore
# LLMType union for return type annotations
LLMType = Union[ChatOpenAI, ChatAnthropic, ChatGoogleGenerativeAI]
class LLMFactory:
"""Factory for creating LLM instances from different providers."""
SUPPORTED_PROVIDERS = ["openai", "anthropic", "google"]
@staticmethod
def create_llm(
provider: str,
model: str,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
backend_url: Optional[str] = None,
**kwargs
) -> LLMType:
"""
Create an LLM instance for the specified provider.
Args:
provider: LLM provider ("openai", "anthropic", "google")
model: Model name (e.g., "gpt-4o", "claude-3-5-sonnet-20241022")
temperature: Sampling temperature (0.0 to 2.0)
max_tokens: Maximum tokens to generate
backend_url: Custom API endpoint (for OpenAI-compatible APIs)
**kwargs: Additional provider-specific arguments
Returns:
LLM instance from the appropriate langchain provider
Raises:
ValueError: If provider is not supported or API key is missing
ImportError: If required package is not installed
Examples:
>>> # OpenAI
>>> llm = LLMFactory.create_llm("openai", "gpt-4o")
>>> # Anthropic
>>> llm = LLMFactory.create_llm("anthropic", "claude-3-5-sonnet-20241022")
>>> # Google
>>> llm = LLMFactory.create_llm("google", "gemini-pro")
"""
provider = provider.lower()
if provider not in LLMFactory.SUPPORTED_PROVIDERS:
error_msg = (f"Unsupported LLM provider: {provider}. "
f"Supported providers: {', '.join(LLMFactory.SUPPORTED_PROVIDERS)}")
logger.error(error_msg)
raise ValueError(error_msg)
logger.info("Creating LLM: provider=%s, model=%s, temperature=%.2f",
provider, model, temperature)
if provider == "openai":
return LLMFactory._create_openai_llm(
model, temperature, max_tokens, backend_url, **kwargs
)
elif provider == "anthropic":
return LLMFactory._create_anthropic_llm(
model, temperature, max_tokens, **kwargs
)
elif provider == "google":
return LLMFactory._create_google_llm(
model, temperature, max_tokens, **kwargs
)
else:
# This should never be reached due to provider validation above
logger.error("Unsupported provider after validation: %s", provider)
raise ValueError(f"Unsupported provider: {provider}")
@staticmethod
def _create_openai_llm(
model: str,
temperature: float,
max_tokens: Optional[int],
backend_url: Optional[str],
**kwargs
) -> LLMType:
"""
Create OpenAI LLM instance with specified configuration.
Args:
model: OpenAI model name (e.g., "gpt-4o", "gpt-4-turbo")
temperature: Sampling temperature (0.0 to 2.0)
max_tokens: Maximum tokens to generate
backend_url: Optional custom API endpoint for OpenAI-compatible APIs
**kwargs: Additional provider-specific arguments
Returns:
Configured ChatOpenAI instance
Raises:
ImportError: If langchain-openai package not installed
ValueError: If OPENAI_API_KEY not configured
"""
try:
from langchain_openai import ChatOpenAI
except ImportError as e:
logger.error("Failed to import langchain_openai: %s", str(e))
raise ImportError(
"langchain-openai is required for OpenAI models. "
"Install with: pip install langchain-openai"
)
# Check API key
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
logger.error("OPENAI_API_KEY environment variable not set")
raise ValueError(
"OPENAI_API_KEY environment variable is required. "
"Set it in your .env file or environment."
)
# Build configuration
config = {
"model": model,
"temperature": temperature,
**kwargs
}
if max_tokens:
config["max_tokens"] = max_tokens
if backend_url:
config["base_url"] = backend_url
logger.debug("Using custom backend URL for OpenAI: %s", backend_url)
logger.info("Creating OpenAI LLM: model=%s, temperature=%.2f, max_tokens=%s",
model, temperature, max_tokens)
logger.debug("OpenAI LLM config: %s", config)
return ChatOpenAI(**config)
@staticmethod
def _create_anthropic_llm(
model: str,
temperature: float,
max_tokens: Optional[int],
**kwargs
) -> LLMType:
"""
Create Anthropic (Claude) LLM instance with specified configuration.
Args:
model: Anthropic model name (e.g., "claude-3-5-sonnet-20241022")
temperature: Sampling temperature (0.0 to 1.0 for Claude)
max_tokens: Maximum tokens to generate (defaults to 4096 for Claude)
**kwargs: Additional provider-specific arguments
Returns:
Configured ChatAnthropic instance
Raises:
ImportError: If langchain-anthropic package not installed
ValueError: If ANTHROPIC_API_KEY not configured
"""
try:
from langchain_anthropic import ChatAnthropic
except ImportError as e:
logger.error("Failed to import langchain_anthropic: %s", str(e))
raise ImportError(
"langchain-anthropic is required for Anthropic models. "
"Install with: pip install langchain-anthropic"
)
# Check API key
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
logger.error("ANTHROPIC_API_KEY environment variable not set")
raise ValueError(
"ANTHROPIC_API_KEY environment variable is required. "
"Set it in your .env file or environment."
)
# Build configuration
config = {
"model": model,
"temperature": temperature,
"anthropic_api_key": api_key,
**kwargs
}
if max_tokens:
config["max_tokens"] = max_tokens
else:
# Claude requires max_tokens, use reasonable default
config["max_tokens"] = 4096
logger.debug("Using default max_tokens for Claude: 4096")
logger.info("Creating Anthropic LLM: model=%s, temperature=%.2f, max_tokens=%d",
model, temperature, config["max_tokens"])
logger.debug("Anthropic LLM config keys: %s", list(config.keys()))
return ChatAnthropic(**config)
@staticmethod
def _create_google_llm(
model: str,
temperature: float,
max_tokens: Optional[int],
**kwargs
) -> LLMType:
"""
Create Google (Gemini) LLM instance with specified configuration.
Args:
model: Google model name (e.g., "gemini-1.5-pro", "gemini-1.5-flash")
temperature: Sampling temperature (0.0 to 2.0 for Gemini)
max_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific arguments
Returns:
Configured ChatGoogleGenerativeAI instance
Raises:
ImportError: If langchain-google-genai package not installed
ValueError: If GOOGLE_API_KEY not configured
"""
try:
from langchain_google_genai import ChatGoogleGenerativeAI
except ImportError as e:
logger.error("Failed to import langchain_google_genai: %s", str(e))
raise ImportError(
"langchain-google-genai is required for Google models. "
"Install with: pip install langchain-google-genai"
)
# Check API key
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
logger.error("GOOGLE_API_KEY environment variable not set")
raise ValueError(
"GOOGLE_API_KEY environment variable is required. "
"Set it in your .env file or environment."
)
# Build configuration
config = {
"model": model,
"temperature": temperature,
"google_api_key": api_key,
**kwargs
}
if max_tokens:
config["max_output_tokens"] = max_tokens
logger.info("Creating Google LLM: model=%s, temperature=%.2f, max_tokens=%s",
model, temperature, max_tokens)
logger.debug("Google LLM config keys: %s", list(config.keys()))
return ChatGoogleGenerativeAI(**config)
@staticmethod
def get_recommended_models(provider: str) -> Dict[str, str]:
"""
Get recommended model names for a provider.
Args:
provider: LLM provider name
Returns:
Dictionary with model recommendations for different use cases
Examples:
>>> models = LLMFactory.get_recommended_models("anthropic")
>>> print(models["deep_thinking"]) # claude-3-5-sonnet-20241022
"""
recommendations = {
"openai": {
"deep_thinking": "o1-preview", # Best reasoning
"fast_thinking": "gpt-4o", # Fast, capable
"budget": "gpt-4o-mini", # Cost-effective
"legacy": "gpt-4-turbo" # Previous generation
},
"anthropic": {
"deep_thinking": "claude-3-5-sonnet-20241022", # Best overall
"fast_thinking": "claude-3-5-sonnet-20241022", # Same (very fast)
"budget": "claude-3-5-haiku-20241022", # Cost-effective
"legacy": "claude-3-opus-20240229" # Previous best
},
"google": {
"deep_thinking": "gemini-1.5-pro", # Best reasoning
"fast_thinking": "gemini-1.5-flash", # Fastest
"budget": "gemini-1.5-flash", # Same as fast
"legacy": "gemini-pro" # Previous generation
}
}
provider = provider.lower()
if provider not in recommendations:
raise ValueError(f"Unknown provider: {provider}")
return recommendations[provider]
@staticmethod
def validate_provider_setup(provider: str) -> Dict[str, Any]:
"""
Validate that a provider is properly configured.
Checks if the required package is installed and API key is configured
for the specified provider.
Args:
provider: Provider to validate (openai, anthropic, google)
Returns:
Dictionary with validation results containing:
- provider: Provider name
- valid: Overall validation status (True if ready to use)
- api_key_set: Whether API key environment variable is set
- package_installed: Whether required langchain package is installed
- errors: List of validation errors encountered
Examples:
>>> result = LLMFactory.validate_provider_setup("anthropic")
>>> if result["valid"]:
... print("Anthropic is properly configured!")
>>> else:
... for error in result["errors"]:
... print(error)
"""
provider = provider.lower()
logger.debug("Validating provider setup: %s", provider)
result = {
"provider": provider,
"valid": False,
"api_key_set": False,
"package_installed": False,
"errors": []
}
# Check package installation
try:
if provider == "openai":
import langchain_openai
result["package_installed"] = True
logger.debug("langchain_openai package found")
elif provider == "anthropic":
import langchain_anthropic
result["package_installed"] = True
logger.debug("langchain_anthropic package found")
elif provider == "google":
import langchain_google_genai
result["package_installed"] = True
logger.debug("langchain_google_genai package found")
except ImportError as e:
error_msg = f"Package not installed: {e}"
result["errors"].append(error_msg)
logger.warning("Package check failed: %s", error_msg)
# Check API key
key_env_vars = {
"openai": "OPENAI_API_KEY",
"anthropic": "ANTHROPIC_API_KEY",
"google": "GOOGLE_API_KEY"
}
if provider in key_env_vars:
env_var = key_env_vars[provider]
if os.getenv(env_var):
result["api_key_set"] = True
logger.debug("%s environment variable is set", env_var)
else:
error_msg = f"{env_var} not set in environment"
result["errors"].append(error_msg)
logger.warning("API key not found: %s", error_msg)
# Overall validation
result["valid"] = result["package_installed"] and result["api_key_set"]
logger.info("Provider validation for %s: valid=%s, errors=%d",
provider, result["valid"], len(result["errors"]))
return result
# Convenience function
def create_llm(provider: str = "openai", model: Optional[str] = None, **kwargs) -> LLMType:
"""
Convenience wrapper for LLMFactory.create_llm() with smart defaults.
If model is not specified, uses the recommended best-in-class model for
the provider (optimized for deep thinking and complex reasoning).
Args:
provider: LLM provider (default: "openai")
- "openai": Uses o1-preview as default
- "anthropic": Uses Claude 3.5 Sonnet as default
- "google": Uses Gemini 1.5 Pro as default
model: Specific model to use. If None, uses provider's recommended model
**kwargs: Additional arguments to pass to LLMFactory.create_llm()
Returns:
Configured LLM instance
Raises:
ValueError: If provider is not supported or API key is missing
ImportError: If required package not installed
Examples:
>>> llm = create_llm("anthropic") # Uses Claude 3.5 Sonnet
>>> llm = create_llm("openai", "gpt-4o") # Uses GPT-4O
>>> llm = create_llm("google", "gemini-1.5-flash") # Fast Gemini
"""
if model is None:
# Use recommended deep thinking model
logger.debug("No model specified for %s, using recommended default", provider)
try:
recommended = LLMFactory.get_recommended_models(provider)
model = recommended["deep_thinking"]
logger.info("Using recommended model for %s: %s", provider, model)
except ValueError as e:
logger.error("Failed to get recommended model: %s", str(e))
raise
return LLMFactory.create_llm(provider, model, **kwargs)