TradingAgents/tests/test_llm_factory.py

406 lines
17 KiB
Python

"""
Comprehensive tests for LLM Factory.
Tests provider validation, model recommendations, LLM creation,
error handling, and environment variable configuration.
"""
import os
import pytest
from unittest.mock import Mock, patch, MagicMock
from decimal import Decimal
from tradingagents.llm_factory import LLMFactory, create_llm
class TestLLMFactory:
    """Core LLMFactory behaviour: the provider registry and provider-name handling."""

    def test_supported_providers(self):
        """Test that supported providers list is correct."""
        expected = ("openai", "anthropic", "google")
        registry = LLMFactory.SUPPORTED_PROVIDERS
        for provider_name in expected:
            assert provider_name in registry
        # Exactly these providers are registered -- no extras.
        assert len(registry) == len(expected)

    def test_provider_case_insensitive(self):
        """Test that provider names are case-insensitive."""
        spellings = ("OpenAI", "OPENAI", "openai")
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}), \
                patch("langchain_openai.ChatOpenAI") as fake_chat_openai:
            fake_chat_openai.return_value = Mock()
            # Every spelling must be routed to the OpenAI client class.
            for spelling in spellings:
                LLMFactory.create_llm(spelling, "gpt-4o")
            assert fake_chat_openai.call_count == len(spellings)
class TestOpenAILLM:
    """Test OpenAI LLM creation.

    ``langchain_openai.ChatOpenAI`` is patched, so no real client is built;
    the assertions inspect the keyword arguments the factory passed to it.
    """

    @patch("langchain_openai.ChatOpenAI")
    def test_create_openai_llm_basic(self, mock_openai):
        """Test basic OpenAI LLM creation."""
        mock_openai.return_value = Mock()
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            LLMFactory.create_llm("openai", "gpt-4o")
        assert mock_openai.called
        call_kwargs = mock_openai.call_args.kwargs
        assert call_kwargs["model"] == "gpt-4o"
        # No temperature was passed, so the factory default (1.0) must be used.
        assert call_kwargs["temperature"] == 1.0

    @patch("langchain_openai.ChatOpenAI")
    def test_create_openai_llm_with_temperature(self, mock_openai):
        """Test OpenAI LLM creation with custom temperature."""
        mock_openai.return_value = Mock()
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            LLMFactory.create_llm("openai", "gpt-4o", temperature=0.7)
        assert mock_openai.call_args.kwargs["temperature"] == 0.7

    @patch("langchain_openai.ChatOpenAI")
    def test_create_openai_llm_with_max_tokens(self, mock_openai):
        """Test OpenAI LLM creation with max_tokens."""
        mock_openai.return_value = Mock()
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            LLMFactory.create_llm("openai", "gpt-4o", max_tokens=2048)
        assert mock_openai.call_args.kwargs["max_tokens"] == 2048

    @patch("langchain_openai.ChatOpenAI")
    def test_create_openai_llm_with_backend_url(self, mock_openai):
        """Test OpenAI LLM creation with custom backend URL."""
        mock_openai.return_value = Mock()
        custom_url = "https://custom.openai.proxy/v1"
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            LLMFactory.create_llm(
                "openai",
                "gpt-4o",
                backend_url=custom_url,
            )
        # The factory's ``backend_url`` argument is forwarded as ``base_url``.
        assert mock_openai.call_args.kwargs["base_url"] == custom_url

    @patch("langchain_openai.ChatOpenAI")
    def test_create_openai_llm_with_extra_kwargs(self, mock_openai):
        """Test OpenAI LLM creation with additional kwargs."""
        mock_openai.return_value = Mock()
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            LLMFactory.create_llm(
                "openai",
                "gpt-4o",
                streaming=True,
                timeout=30,
            )
        # Unrecognised keyword arguments are passed through to the client unchanged.
        call_kwargs = mock_openai.call_args.kwargs
        assert call_kwargs["streaming"] is True
        assert call_kwargs["timeout"] == 30

    def test_create_openai_llm_missing_api_key(self):
        """Test that missing API key raises ValueError."""
        # clear=True empties os.environ for the duration of the block.
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(ValueError, match="OPENAI_API_KEY"):
                LLMFactory.create_llm("openai", "gpt-4o")

    def test_create_openai_llm_missing_package(self):
        """Test that missing langchain-openai package raises ImportError."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            # A None entry in sys.modules makes `import langchain_openai` raise
            # ImportError; patch.dict restores sys.modules afterwards.
            with patch.dict("sys.modules", {"langchain_openai": None}):
                with pytest.raises(ImportError, match="langchain-openai"):
                    LLMFactory.create_llm("openai", "gpt-4o")
class TestAnthropicLLM:
    """Test Anthropic (Claude) LLM creation.

    ``langchain_anthropic.ChatAnthropic`` is patched, so no real client is
    built; the assertions inspect the keyword arguments the factory passed.
    """

    @patch("langchain_anthropic.ChatAnthropic")
    def test_create_anthropic_llm_basic(self, mock_anthropic):
        """Test basic Anthropic LLM creation."""
        mock_anthropic.return_value = Mock()
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            LLMFactory.create_llm("anthropic", "claude-3-5-sonnet-20241022")
        assert mock_anthropic.called
        call_kwargs = mock_anthropic.call_args.kwargs
        assert call_kwargs["model"] == "claude-3-5-sonnet-20241022"
        assert call_kwargs["temperature"] == 1.0
        # The key read from the environment is passed explicitly to the client.
        assert call_kwargs["anthropic_api_key"] == "test-key"

    @patch("langchain_anthropic.ChatAnthropic")
    def test_create_anthropic_llm_with_max_tokens(self, mock_anthropic):
        """Test Anthropic LLM creation with max_tokens."""
        mock_anthropic.return_value = Mock()
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            LLMFactory.create_llm(
                "anthropic", "claude-3-5-sonnet-20241022", max_tokens=8192
            )
        assert mock_anthropic.call_args.kwargs["max_tokens"] == 8192

    @patch("langchain_anthropic.ChatAnthropic")
    def test_create_anthropic_llm_default_max_tokens(self, mock_anthropic):
        """Test that Anthropic LLM gets default max_tokens if not specified."""
        mock_anthropic.return_value = Mock()
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            LLMFactory.create_llm("anthropic", "claude-3-5-sonnet-20241022")
        # Claude requires max_tokens, so the factory should default it to 4096.
        assert mock_anthropic.call_args.kwargs["max_tokens"] == 4096

    def test_create_anthropic_llm_missing_api_key(self):
        """Test that missing API key raises ValueError."""
        # clear=True empties os.environ for the duration of the block.
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(ValueError, match="ANTHROPIC_API_KEY"):
                LLMFactory.create_llm("anthropic", "claude-3-5-sonnet-20241022")

    def test_create_anthropic_llm_missing_package(self):
        """Test that missing langchain-anthropic package raises ImportError."""
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            # A None entry in sys.modules makes the import raise ImportError.
            with patch.dict("sys.modules", {"langchain_anthropic": None}):
                with pytest.raises(ImportError, match="langchain-anthropic"):
                    LLMFactory.create_llm(
                        "anthropic", "claude-3-5-sonnet-20241022"
                    )
class TestGoogleLLM:
    """Test Google (Gemini) LLM creation.

    ``langchain_google_genai.ChatGoogleGenerativeAI`` is patched, so no real
    client is built; the assertions inspect the keyword arguments passed.
    """

    @patch("langchain_google_genai.ChatGoogleGenerativeAI")
    def test_create_google_llm_basic(self, mock_google):
        """Test basic Google LLM creation."""
        mock_google.return_value = Mock()
        with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
            LLMFactory.create_llm("google", "gemini-1.5-pro")
        assert mock_google.called
        call_kwargs = mock_google.call_args.kwargs
        assert call_kwargs["model"] == "gemini-1.5-pro"
        assert call_kwargs["temperature"] == 1.0
        # The key read from the environment is passed explicitly to the client.
        assert call_kwargs["google_api_key"] == "test-key"

    @patch("langchain_google_genai.ChatGoogleGenerativeAI")
    def test_create_google_llm_with_max_tokens(self, mock_google):
        """Test Google LLM creation with max_tokens."""
        mock_google.return_value = Mock()
        with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
            LLMFactory.create_llm("google", "gemini-1.5-pro", max_tokens=4096)
        # Google uses max_output_tokens instead of max_tokens.
        assert mock_google.call_args.kwargs["max_output_tokens"] == 4096

    def test_create_google_llm_missing_api_key(self):
        """Test that missing API key raises ValueError."""
        # clear=True empties os.environ for the duration of the block.
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(ValueError, match="GOOGLE_API_KEY"):
                LLMFactory.create_llm("google", "gemini-1.5-pro")

    def test_create_google_llm_missing_package(self):
        """Test that missing langchain-google-genai package raises ImportError."""
        with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
            # A None entry in sys.modules makes the import raise ImportError.
            with patch.dict("sys.modules", {"langchain_google_genai": None}):
                with pytest.raises(ImportError, match="langchain-google-genai"):
                    LLMFactory.create_llm("google", "gemini-1.5-pro")
class TestModelRecommendations:
    """Recommended-model tables returned by ``LLMFactory.get_recommended_models``."""

    @staticmethod
    def _check_tiers(models, expected):
        """Assert that every tier named in *expected* maps to the given model."""
        for tier, model_name in expected.items():
            assert models[tier] == model_name

    def test_get_openai_recommendations(self):
        """Test getting OpenAI model recommendations."""
        models = LLMFactory.get_recommended_models("openai")
        for tier in ("deep_thinking", "fast_thinking", "budget", "legacy"):
            assert tier in models
        self._check_tiers(models, {
            "deep_thinking": "o1-preview",
            "fast_thinking": "gpt-4o",
            "budget": "gpt-4o-mini",
        })

    def test_get_anthropic_recommendations(self):
        """Test getting Anthropic model recommendations."""
        self._check_tiers(LLMFactory.get_recommended_models("anthropic"), {
            "deep_thinking": "claude-3-5-sonnet-20241022",
            "fast_thinking": "claude-3-5-sonnet-20241022",
            "budget": "claude-3-5-haiku-20241022",
        })

    def test_get_google_recommendations(self):
        """Test getting Google model recommendations."""
        self._check_tiers(LLMFactory.get_recommended_models("google"), {
            "deep_thinking": "gemini-1.5-pro",
            "fast_thinking": "gemini-1.5-flash",
            "budget": "gemini-1.5-flash",
        })

    def test_get_recommendations_case_insensitive(self):
        """Test that get_recommended_models is case-insensitive."""
        mixed_case, lower_case = (
            LLMFactory.get_recommended_models(name) for name in ("OpenAI", "openai")
        )
        assert mixed_case == lower_case

    def test_get_recommendations_unknown_provider(self):
        """Test that unknown provider raises ValueError."""
        bogus_provider = "unknown_provider"
        with pytest.raises(ValueError, match="Unknown provider"):
            LLMFactory.get_recommended_models(bogus_provider)
class TestProviderValidation:
    """Test provider validation functionality.

    Package presence is simulated through ``sys.modules`` with ``patch.dict``.
    A stand-in object (``MagicMock``) makes ``import <name>`` succeed, and a
    ``None`` entry makes it raise ImportError. ``patch.dict`` restores
    ``sys.modules`` on exit, so no fake entry leaks into other tests.

    ``patch("langchain_openai")`` cannot be used for this. ``patch`` needs a
    dotted ``module.attribute`` target and raises TypeError for a bare module
    name.
    """

    def test_validate_openai_setup_complete(self):
        """Test validating complete OpenAI setup."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            with patch.dict("sys.modules", {"langchain_openai": MagicMock()}):
                result = LLMFactory.validate_provider_setup("openai")
        assert result["provider"] == "openai"
        assert result["valid"] is True
        assert result["api_key_set"] is True
        assert result["package_installed"] is True
        assert len(result["errors"]) == 0

    def test_validate_openai_missing_key(self):
        """Test validating OpenAI setup with missing API key."""
        # clear=True empties os.environ, so no API key is visible.
        with patch.dict(os.environ, {}, clear=True):
            with patch.dict("sys.modules", {"langchain_openai": MagicMock()}):
                result = LLMFactory.validate_provider_setup("openai")
        assert result["valid"] is False
        assert result["api_key_set"] is False
        assert result["package_installed"] is True
        assert any("OPENAI_API_KEY" in error for error in result["errors"])

    def test_validate_openai_missing_package(self):
        """Test validating OpenAI setup with missing package."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            # None in sys.modules makes `import langchain_openai` raise
            # ImportError. patch.dict also removes the entry afterwards when
            # the module had never been imported, which a manual
            # copy-and-update restore does not do.
            with patch.dict("sys.modules", {"langchain_openai": None}):
                result = LLMFactory.validate_provider_setup("openai")
        assert result["valid"] is False
        assert result["package_installed"] is False
        assert result["api_key_set"] is True
        assert any("Package not installed" in error for error in result["errors"])

    def test_validate_anthropic_setup_complete(self):
        """Test validating complete Anthropic setup."""
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            with patch.dict("sys.modules", {"langchain_anthropic": MagicMock()}):
                result = LLMFactory.validate_provider_setup("anthropic")
        assert result["provider"] == "anthropic"
        assert result["valid"] is True
        assert result["api_key_set"] is True
        assert result["package_installed"] is True

    def test_validate_google_setup_complete(self):
        """Test validating complete Google setup."""
        with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
            with patch.dict(
                "sys.modules", {"langchain_google_genai": MagicMock()}
            ):
                result = LLMFactory.validate_provider_setup("google")
        assert result["provider"] == "google"
        assert result["valid"] is True
        assert result["api_key_set"] is True
        assert result["package_installed"] is True
class TestConvenienceFunction:
    """Test the convenience create_llm function."""

    @patch("langchain_openai.ChatOpenAI")
    def test_create_llm_defaults_to_openai(self, mock_openai):
        """Test that create_llm defaults to OpenAI."""
        mock_openai.return_value = Mock()
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            # No provider argument: the OpenAI client class must be used.
            create_llm()
        assert mock_openai.called

    @patch("langchain_openai.ChatOpenAI")
    def test_create_llm_auto_selects_model(self, mock_openai):
        """Test that create_llm auto-selects recommended model."""
        mock_openai.return_value = Mock()
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            create_llm("openai")
        # With no model given, the recommended deep-thinking model is expected.
        assert mock_openai.call_args.kwargs["model"] == "o1-preview"

    @patch("langchain_anthropic.ChatAnthropic")
    def test_create_llm_with_specified_model(self, mock_anthropic):
        """Test create_llm with specified model."""
        mock_anthropic.return_value = Mock()
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            create_llm("anthropic", "claude-3-5-haiku-20241022")
        # An explicit model overrides the recommendation.
        assert (
            mock_anthropic.call_args.kwargs["model"] == "claude-3-5-haiku-20241022"
        )
@pytest.mark.parametrize(("provider", "model", "env_var"), [
    ("openai", "gpt-4o", "OPENAI_API_KEY"),
    ("anthropic", "claude-3-5-sonnet-20241022", "ANTHROPIC_API_KEY"),
    ("google", "gemini-1.5-pro", "GOOGLE_API_KEY"),
])
def test_all_providers_require_api_key(provider, model, env_var):
    """Every provider refuses to build a client when its API key is unset.

    The environment is emptied for the call, and the raised ValueError must
    mention the provider-specific environment variable name.
    """
    empty_env = patch.dict(os.environ, {}, clear=True)
    expect_missing_key = pytest.raises(ValueError, match=env_var)
    with empty_env, expect_missing_key:
        LLMFactory.create_llm(provider, model)
@pytest.mark.parametrize("temperature", [0.0, 0.5, 1.0, 1.5, 2.0])
def test_temperature_values(temperature):
    """Parametrized test: various temperature values."""
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
        # Patch the class on langchain_openai, as every other OpenAI test here
        # does. The missing-package tests show the factory imports
        # langchain_openai when create_llm runs. Presumably it therefore looks
        # ChatOpenAI up there, and tradingagents.llm_factory has no
        # module-level ChatOpenAI to patch.
        with patch("langchain_openai.ChatOpenAI") as mock_openai:
            mock_openai.return_value = Mock()
            LLMFactory.create_llm("openai", "gpt-4o", temperature=temperature)
    # The requested temperature must be forwarded unchanged.
    assert mock_openai.call_args.kwargs["temperature"] == temperature