"""
|
|
LLM Factory for TradingAgents.
|
|
|
|
Provides unified interface for creating LLM instances from different providers
|
|
(OpenAI, Anthropic, Google, etc.) with consistent configuration.
|
|
"""
|
|
|
|
import os
|
|
from typing import Optional, Dict, Any, Union
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Type definitions for LLM instances.
# LLMType is defined as a Union of the supported LLM provider chat models.
try:
    from langchain_openai import ChatOpenAI
    from langchain_anthropic import ChatAnthropic
    from langchain_google_genai import ChatGoogleGenerativeAI
except ImportError:
    # Fall back to Any when the provider packages are not installed
    # (e.g. during type checking or partial installs).
    ChatOpenAI = Any  # type: ignore
    ChatAnthropic = Any  # type: ignore
    ChatGoogleGenerativeAI = Any  # type: ignore

# LLMType union for return type annotations
LLMType = Union[ChatOpenAI, ChatAnthropic, ChatGoogleGenerativeAI]
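# Note: if none of the provider packages are importable, the fallback above binds
# each name to Any, so LLMType degrades to Any and type checkers can no longer
# verify the concrete return types of this factory.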


class LLMFactory:
    """Factory for creating LLM instances from different providers."""

    SUPPORTED_PROVIDERS = ["openai", "anthropic", "google"]

    @staticmethod
    def create_llm(
        provider: str,
        model: str,
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        backend_url: Optional[str] = None,
        **kwargs
    ) -> LLMType:
        """
        Create an LLM instance for the specified provider.

        Args:
            provider: LLM provider ("openai", "anthropic", "google")
            model: Model name (e.g., "gpt-4o", "claude-3-5-sonnet-20241022")
            temperature: Sampling temperature (0.0 to 2.0)
            max_tokens: Maximum tokens to generate
            backend_url: Custom API endpoint (for OpenAI-compatible APIs)
            **kwargs: Additional provider-specific arguments

        Returns:
            LLM instance from the appropriate langchain provider

        Raises:
            ValueError: If provider is not supported or API key is missing
            ImportError: If required package is not installed

        Examples:
            >>> # OpenAI
            >>> llm = LLMFactory.create_llm("openai", "gpt-4o")

            >>> # Anthropic
            >>> llm = LLMFactory.create_llm("anthropic", "claude-3-5-sonnet-20241022")

            >>> # Google
            >>> llm = LLMFactory.create_llm("google", "gemini-pro")
        """
        provider = provider.lower()

        if provider not in LLMFactory.SUPPORTED_PROVIDERS:
            error_msg = (f"Unsupported LLM provider: {provider}. "
                         f"Supported providers: {', '.join(LLMFactory.SUPPORTED_PROVIDERS)}")
            logger.error(error_msg)
            raise ValueError(error_msg)

        logger.info("Creating LLM: provider=%s, model=%s, temperature=%.2f",
                    provider, model, temperature)

        if provider == "openai":
            return LLMFactory._create_openai_llm(
                model, temperature, max_tokens, backend_url, **kwargs
            )
        elif provider == "anthropic":
            return LLMFactory._create_anthropic_llm(
                model, temperature, max_tokens, **kwargs
            )
        elif provider == "google":
            return LLMFactory._create_google_llm(
                model, temperature, max_tokens, **kwargs
            )
        else:
            # This should never be reached due to provider validation above
            logger.error("Unsupported provider after validation: %s", provider)
            raise ValueError(f"Unsupported provider: {provider}")

    @staticmethod
    def _create_openai_llm(
        model: str,
        temperature: float,
        max_tokens: Optional[int],
        backend_url: Optional[str],
        **kwargs
    ) -> LLMType:
        """
        Create OpenAI LLM instance with specified configuration.

        Args:
            model: OpenAI model name (e.g., "gpt-4o", "gpt-4-turbo")
            temperature: Sampling temperature (0.0 to 2.0)
            max_tokens: Maximum tokens to generate
            backend_url: Optional custom API endpoint for OpenAI-compatible APIs
            **kwargs: Additional provider-specific arguments

        Returns:
            Configured ChatOpenAI instance

        Raises:
            ImportError: If langchain-openai package not installed
            ValueError: If OPENAI_API_KEY not configured
        """
        try:
            from langchain_openai import ChatOpenAI
        except ImportError as e:
            logger.error("Failed to import langchain_openai: %s", str(e))
            raise ImportError(
                "langchain-openai is required for OpenAI models. "
                "Install with: pip install langchain-openai"
            ) from e

        # Check API key (ChatOpenAI reads OPENAI_API_KEY from the environment)
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            logger.error("OPENAI_API_KEY environment variable not set")
            raise ValueError(
                "OPENAI_API_KEY environment variable is required. "
                "Set it in your .env file or environment."
            )

        # Build configuration
        config = {
            "model": model,
            "temperature": temperature,
            **kwargs
        }

        if max_tokens:
            config["max_tokens"] = max_tokens

        if backend_url:
            config["base_url"] = backend_url
            logger.debug("Using custom backend URL for OpenAI: %s", backend_url)

        logger.info("Creating OpenAI LLM: model=%s, temperature=%.2f, max_tokens=%s",
                    model, temperature, max_tokens)
        logger.debug("OpenAI LLM config: %s", config)
        return ChatOpenAI(**config)

    @staticmethod
    def _create_anthropic_llm(
        model: str,
        temperature: float,
        max_tokens: Optional[int],
        **kwargs
    ) -> LLMType:
        """
        Create Anthropic (Claude) LLM instance with specified configuration.

        Args:
            model: Anthropic model name (e.g., "claude-3-5-sonnet-20241022")
            temperature: Sampling temperature (0.0 to 1.0 for Claude)
            max_tokens: Maximum tokens to generate (defaults to 4096 for Claude)
            **kwargs: Additional provider-specific arguments

        Returns:
            Configured ChatAnthropic instance

        Raises:
            ImportError: If langchain-anthropic package not installed
            ValueError: If ANTHROPIC_API_KEY not configured
        """
        try:
            from langchain_anthropic import ChatAnthropic
        except ImportError as e:
            logger.error("Failed to import langchain_anthropic: %s", str(e))
            raise ImportError(
                "langchain-anthropic is required for Anthropic models. "
                "Install with: pip install langchain-anthropic"
            ) from e

        # Check API key
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            logger.error("ANTHROPIC_API_KEY environment variable not set")
            raise ValueError(
                "ANTHROPIC_API_KEY environment variable is required. "
                "Set it in your .env file or environment."
            )

        # Build configuration
        config = {
            "model": model,
            "temperature": temperature,
            "anthropic_api_key": api_key,
            **kwargs
        }

        if max_tokens:
            config["max_tokens"] = max_tokens
        else:
            # Claude requires max_tokens; use a reasonable default
            config["max_tokens"] = 4096
            logger.debug("Using default max_tokens for Claude: 4096")

        logger.info("Creating Anthropic LLM: model=%s, temperature=%.2f, max_tokens=%d",
                    model, temperature, config["max_tokens"])
        logger.debug("Anthropic LLM config keys: %s", list(config.keys()))
        return ChatAnthropic(**config)

    @staticmethod
    def _create_google_llm(
        model: str,
        temperature: float,
        max_tokens: Optional[int],
        **kwargs
    ) -> LLMType:
        """
        Create Google (Gemini) LLM instance with specified configuration.

        Args:
            model: Google model name (e.g., "gemini-1.5-pro", "gemini-1.5-flash")
            temperature: Sampling temperature (0.0 to 2.0 for Gemini)
            max_tokens: Maximum tokens to generate
            **kwargs: Additional provider-specific arguments

        Returns:
            Configured ChatGoogleGenerativeAI instance

        Raises:
            ImportError: If langchain-google-genai package not installed
            ValueError: If GOOGLE_API_KEY not configured
        """
        try:
            from langchain_google_genai import ChatGoogleGenerativeAI
        except ImportError as e:
            logger.error("Failed to import langchain_google_genai: %s", str(e))
            raise ImportError(
                "langchain-google-genai is required for Google models. "
                "Install with: pip install langchain-google-genai"
            ) from e

        # Check API key
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            logger.error("GOOGLE_API_KEY environment variable not set")
            raise ValueError(
                "GOOGLE_API_KEY environment variable is required. "
                "Set it in your .env file or environment."
            )

        # Build configuration
        config = {
            "model": model,
            "temperature": temperature,
            "google_api_key": api_key,
            **kwargs
        }

        if max_tokens:
            config["max_output_tokens"] = max_tokens

        logger.info("Creating Google LLM: model=%s, temperature=%.2f, max_tokens=%s",
                    model, temperature, max_tokens)
        logger.debug("Google LLM config keys: %s", list(config.keys()))
        return ChatGoogleGenerativeAI(**config)

    @staticmethod
    def get_recommended_models(provider: str) -> Dict[str, str]:
        """
        Get recommended model names for a provider.

        Args:
            provider: LLM provider name

        Returns:
            Dictionary with model recommendations for different use cases

        Examples:
            >>> models = LLMFactory.get_recommended_models("anthropic")
            >>> print(models["deep_thinking"])  # claude-3-5-sonnet-20241022
        """
        recommendations = {
            "openai": {
                "deep_thinking": "o1-preview",   # Best reasoning
                "fast_thinking": "gpt-4o",       # Fast, capable
                "budget": "gpt-4o-mini",         # Cost-effective
                "legacy": "gpt-4-turbo"          # Previous generation
            },
            "anthropic": {
                "deep_thinking": "claude-3-5-sonnet-20241022",  # Best overall
                "fast_thinking": "claude-3-5-sonnet-20241022",  # Same (very fast)
                "budget": "claude-3-5-haiku-20241022",          # Cost-effective
                "legacy": "claude-3-opus-20240229"              # Previous best
            },
            "google": {
                "deep_thinking": "gemini-1.5-pro",    # Best reasoning
                "fast_thinking": "gemini-1.5-flash",  # Fastest
                "budget": "gemini-1.5-flash",         # Same as fast
                "legacy": "gemini-pro"                # Previous generation
            }
        }

        provider = provider.lower()
        if provider not in recommendations:
            raise ValueError(f"Unknown provider: {provider}")

        return recommendations[provider]

    @staticmethod
    def validate_provider_setup(provider: str) -> Dict[str, Any]:
        """
        Validate that a provider is properly configured.

        Checks if the required package is installed and API key is configured
        for the specified provider.

        Args:
            provider: Provider to validate (openai, anthropic, google)

        Returns:
            Dictionary with validation results containing:
            - provider: Provider name
            - valid: Overall validation status (True if ready to use)
            - api_key_set: Whether API key environment variable is set
            - package_installed: Whether required langchain package is installed
            - errors: List of validation errors encountered

        Examples:
            >>> result = LLMFactory.validate_provider_setup("anthropic")
            >>> if result["valid"]:
            ...     print("Anthropic is properly configured!")
            ... else:
            ...     for error in result["errors"]:
            ...         print(error)
        """
        provider = provider.lower()
        logger.debug("Validating provider setup: %s", provider)

        result = {
            "provider": provider,
            "valid": False,
            "api_key_set": False,
            "package_installed": False,
            "errors": []
        }

        # Check package installation
        try:
            if provider == "openai":
                import langchain_openai
                result["package_installed"] = True
                logger.debug("langchain_openai package found")
            elif provider == "anthropic":
                import langchain_anthropic
                result["package_installed"] = True
                logger.debug("langchain_anthropic package found")
            elif provider == "google":
                import langchain_google_genai
                result["package_installed"] = True
                logger.debug("langchain_google_genai package found")
        except ImportError as e:
            error_msg = f"Package not installed: {e}"
            result["errors"].append(error_msg)
            logger.warning("Package check failed: %s", error_msg)

        # Check API key
        key_env_vars = {
            "openai": "OPENAI_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "google": "GOOGLE_API_KEY"
        }

        if provider in key_env_vars:
            env_var = key_env_vars[provider]
            if os.getenv(env_var):
                result["api_key_set"] = True
                logger.debug("%s environment variable is set", env_var)
            else:
                error_msg = f"{env_var} not set in environment"
                result["errors"].append(error_msg)
                logger.warning("API key not found: %s", error_msg)

        # Overall validation
        result["valid"] = result["package_installed"] and result["api_key_set"]
        logger.info("Provider validation for %s: valid=%s, errors=%d",
                    provider, result["valid"], len(result["errors"]))

        return result


# Convenience function
def create_llm(provider: str = "openai", model: Optional[str] = None, **kwargs) -> LLMType:
    """
    Convenience wrapper for LLMFactory.create_llm() with smart defaults.

    If model is not specified, uses the recommended best-in-class model for
    the provider (optimized for deep thinking and complex reasoning).

    Args:
        provider: LLM provider (default: "openai")
            - "openai": Uses o1-preview as default
            - "anthropic": Uses Claude 3.5 Sonnet as default
            - "google": Uses Gemini 1.5 Pro as default
        model: Specific model to use. If None, uses provider's recommended model
        **kwargs: Additional arguments to pass to LLMFactory.create_llm()

    Returns:
        Configured LLM instance

    Raises:
        ValueError: If provider is not supported or API key is missing
        ImportError: If required package not installed

    Examples:
        >>> llm = create_llm("anthropic")  # Uses Claude 3.5 Sonnet
        >>> llm = create_llm("openai", "gpt-4o")  # Uses GPT-4o
        >>> llm = create_llm("google", "gemini-1.5-flash")  # Fast Gemini
    """
    if model is None:
        # Use recommended deep thinking model
        logger.debug("No model specified for %s, using recommended default", provider)
        try:
            recommended = LLMFactory.get_recommended_models(provider)
            model = recommended["deep_thinking"]
            logger.info("Using recommended model for %s: %s", provider, model)
        except ValueError as e:
            logger.error("Failed to get recommended model: %s", str(e))
            raise

    return LLMFactory.create_llm(provider, model, **kwargs)
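

# Minimal usage sketch, assuming at least one provider's langchain package and
# API key are configured in the environment. It only exercises the helpers
# defined above (validate_provider_setup and the module-level create_llm).
if __name__ == "__main__":
    # Report configuration status for every supported provider.
    for provider_name in LLMFactory.SUPPORTED_PROVIDERS:
        status = LLMFactory.validate_provider_setup(provider_name)
        print(f"{provider_name}: valid={status['valid']}, errors={status['errors']}")

    # Instantiate an LLM from the first fully configured provider, if any,
    # using that provider's recommended "deep_thinking" model by default.
    ready = [p for p in LLMFactory.SUPPORTED_PROVIDERS
             if LLMFactory.validate_provider_setup(p)["valid"]]
    if ready:
        llm = create_llm(ready[0])
        print(f"Created LLM for {ready[0]}: {llm!r}")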