"""
|
|
Comprehensive tests for LLM Factory.
|
|
|
|
Tests provider validation, model recommendations, LLM creation,
|
|
error handling, and environment variable configuration.
|
|
"""
|
|
|
|
import os
|
|
import pytest
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
from decimal import Decimal
|
|
|
|
from tradingagents.llm_factory import LLMFactory, create_llm
|
|
|
|
|
|


class TestLLMFactory:
    """Test suite for LLMFactory class."""

    def test_supported_providers(self):
        """Test that supported providers list is correct."""
        assert "openai" in LLMFactory.SUPPORTED_PROVIDERS
        assert "anthropic" in LLMFactory.SUPPORTED_PROVIDERS
        assert "google" in LLMFactory.SUPPORTED_PROVIDERS
        assert len(LLMFactory.SUPPORTED_PROVIDERS) == 3

    def test_unsupported_provider_raises_error(self):
        """Test that unsupported provider raises ValueError."""
        with pytest.raises(ValueError, match="Unsupported LLM provider"):
            LLMFactory.create_llm("unsupported_provider", "some-model")

    def test_provider_case_insensitive(self):
        """Test that provider names are case-insensitive."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            with patch("langchain_openai.ChatOpenAI") as mock_openai:
                mock_openai.return_value = Mock()

                # These should all work
                LLMFactory.create_llm("OpenAI", "gpt-4o")
                LLMFactory.create_llm("OPENAI", "gpt-4o")
                LLMFactory.create_llm("openai", "gpt-4o")

                assert mock_openai.call_count == 3
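
    # Added sketch, not part of the original suite: the factory should hand
    # back exactly the object the (mocked) provider class constructed, which
    # the `llm = ...` usage elsewhere in this file implies.
    @patch("langchain_openai.ChatOpenAI")
    def test_create_llm_returns_provider_instance(self, mock_openai):
        """Test that create_llm returns the instance built by the provider class."""
        sentinel = Mock()
        mock_openai.return_value = sentinel

        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            llm = LLMFactory.create_llm("openai", "gpt-4o")

        assert llm is sentinel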


class TestOpenAILLM:
    """Test OpenAI LLM creation."""

    @patch("langchain_openai.ChatOpenAI")
    def test_create_openai_llm_basic(self, mock_openai):
        """Test basic OpenAI LLM creation."""
        mock_openai.return_value = Mock()

        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            llm = LLMFactory.create_llm("openai", "gpt-4o")

        assert mock_openai.called
        call_kwargs = mock_openai.call_args[1]
        assert call_kwargs["model"] == "gpt-4o"
        assert call_kwargs["temperature"] == 1.0

    @patch("langchain_openai.ChatOpenAI")
    def test_create_openai_llm_with_temperature(self, mock_openai):
        """Test OpenAI LLM creation with custom temperature."""
        mock_openai.return_value = Mock()

        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            LLMFactory.create_llm("openai", "gpt-4o", temperature=0.7)

        call_kwargs = mock_openai.call_args[1]
        assert call_kwargs["temperature"] == 0.7

    @patch("langchain_openai.ChatOpenAI")
    def test_create_openai_llm_with_max_tokens(self, mock_openai):
        """Test OpenAI LLM creation with max_tokens."""
        mock_openai.return_value = Mock()

        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            LLMFactory.create_llm("openai", "gpt-4o", max_tokens=2048)

        call_kwargs = mock_openai.call_args[1]
        assert call_kwargs["max_tokens"] == 2048

    @patch("langchain_openai.ChatOpenAI")
    def test_create_openai_llm_with_backend_url(self, mock_openai):
        """Test OpenAI LLM creation with custom backend URL."""
        mock_openai.return_value = Mock()

        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            custom_url = "https://custom.openai.proxy/v1"
            LLMFactory.create_llm(
                "openai",
                "gpt-4o",
                backend_url=custom_url
            )

        call_kwargs = mock_openai.call_args[1]
        assert call_kwargs["base_url"] == custom_url

    @patch("langchain_openai.ChatOpenAI")
    def test_create_openai_llm_with_extra_kwargs(self, mock_openai):
        """Test OpenAI LLM creation with additional kwargs."""
        mock_openai.return_value = Mock()

        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            LLMFactory.create_llm(
                "openai",
                "gpt-4o",
                streaming=True,
                timeout=30
            )

        call_kwargs = mock_openai.call_args[1]
        assert call_kwargs["streaming"] is True
        assert call_kwargs["timeout"] == 30

    def test_create_openai_llm_missing_api_key(self):
        """Test that missing API key raises ValueError."""
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(ValueError, match="OPENAI_API_KEY"):
                LLMFactory.create_llm("openai", "gpt-4o")

    def test_create_openai_llm_missing_package(self):
        """Test that missing langchain-openai package raises ImportError."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            with patch.dict("sys.modules", {"langchain_openai": None}):
                with pytest.raises(ImportError, match="langchain-openai"):
                    LLMFactory.create_llm("openai", "gpt-4o")


class TestAnthropicLLM:
    """Test Anthropic (Claude) LLM creation."""

    @patch("langchain_anthropic.ChatAnthropic")
    def test_create_anthropic_llm_basic(self, mock_anthropic):
        """Test basic Anthropic LLM creation."""
        mock_anthropic.return_value = Mock()

        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            llm = LLMFactory.create_llm("anthropic", "claude-3-5-sonnet-20241022")

        assert mock_anthropic.called
        call_kwargs = mock_anthropic.call_args[1]
        assert call_kwargs["model"] == "claude-3-5-sonnet-20241022"
        assert call_kwargs["temperature"] == 1.0
        assert call_kwargs["anthropic_api_key"] == "test-key"

    @patch("langchain_anthropic.ChatAnthropic")
    def test_create_anthropic_llm_with_max_tokens(self, mock_anthropic):
        """Test Anthropic LLM creation with max_tokens."""
        mock_anthropic.return_value = Mock()

        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            LLMFactory.create_llm("anthropic", "claude-3-5-sonnet-20241022", max_tokens=8192)

        call_kwargs = mock_anthropic.call_args[1]
        assert call_kwargs["max_tokens"] == 8192

    @patch("langchain_anthropic.ChatAnthropic")
    def test_create_anthropic_llm_default_max_tokens(self, mock_anthropic):
        """Test that Anthropic LLM gets default max_tokens if not specified."""
        mock_anthropic.return_value = Mock()

        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            LLMFactory.create_llm("anthropic", "claude-3-5-sonnet-20241022")

        call_kwargs = mock_anthropic.call_args[1]
        # Claude requires max_tokens, should default to 4096
        assert call_kwargs["max_tokens"] == 4096

    def test_create_anthropic_llm_missing_api_key(self):
        """Test that missing API key raises ValueError."""
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(ValueError, match="ANTHROPIC_API_KEY"):
                LLMFactory.create_llm("anthropic", "claude-3-5-sonnet-20241022")

    def test_create_anthropic_llm_missing_package(self):
        """Test that missing langchain-anthropic package raises ImportError."""
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            with patch.dict("sys.modules", {"langchain_anthropic": None}):
                with pytest.raises(ImportError, match="langchain-anthropic"):
                    LLMFactory.create_llm("anthropic", "claude-3-5-sonnet-20241022")


class TestGoogleLLM:
    """Test Google (Gemini) LLM creation."""

    @patch("langchain_google_genai.ChatGoogleGenerativeAI")
    def test_create_google_llm_basic(self, mock_google):
        """Test basic Google LLM creation."""
        mock_google.return_value = Mock()

        with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
            llm = LLMFactory.create_llm("google", "gemini-1.5-pro")

        assert mock_google.called
        call_kwargs = mock_google.call_args[1]
        assert call_kwargs["model"] == "gemini-1.5-pro"
        assert call_kwargs["temperature"] == 1.0
        assert call_kwargs["google_api_key"] == "test-key"

    @patch("langchain_google_genai.ChatGoogleGenerativeAI")
    def test_create_google_llm_with_max_tokens(self, mock_google):
        """Test Google LLM creation with max_tokens."""
        mock_google.return_value = Mock()

        with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
            LLMFactory.create_llm("google", "gemini-1.5-pro", max_tokens=4096)

        call_kwargs = mock_google.call_args[1]
        # Google uses max_output_tokens instead of max_tokens
        assert call_kwargs["max_output_tokens"] == 4096

    def test_create_google_llm_missing_api_key(self):
        """Test that missing API key raises ValueError."""
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(ValueError, match="GOOGLE_API_KEY"):
                LLMFactory.create_llm("google", "gemini-1.5-pro")

    def test_create_google_llm_missing_package(self):
        """Test that missing langchain-google-genai package raises ImportError."""
        with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
            with patch.dict("sys.modules", {"langchain_google_genai": None}):
                with pytest.raises(ImportError, match="langchain-google-genai"):
                    LLMFactory.create_llm("google", "gemini-1.5-pro")
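
    # Added sketch mirroring the OpenAI temperature test; it assumes a custom
    # temperature passes through to Gemini the same way the default value
    # (asserted as 1.0 in the basic test above) does.
    @patch("langchain_google_genai.ChatGoogleGenerativeAI")
    def test_create_google_llm_with_temperature(self, mock_google):
        """Test Google LLM creation with custom temperature."""
        mock_google.return_value = Mock()

        with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
            LLMFactory.create_llm("google", "gemini-1.5-pro", temperature=0.3)

        call_kwargs = mock_google.call_args[1]
        assert call_kwargs["temperature"] == 0.3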


class TestModelRecommendations:
    """Test model recommendation functionality."""

    def test_get_openai_recommendations(self):
        """Test getting OpenAI model recommendations."""
        models = LLMFactory.get_recommended_models("openai")

        assert "deep_thinking" in models
        assert "fast_thinking" in models
        assert "budget" in models
        assert "legacy" in models

        assert models["deep_thinking"] == "o1-preview"
        assert models["fast_thinking"] == "gpt-4o"
        assert models["budget"] == "gpt-4o-mini"

    def test_get_anthropic_recommendations(self):
        """Test getting Anthropic model recommendations."""
        models = LLMFactory.get_recommended_models("anthropic")

        assert models["deep_thinking"] == "claude-3-5-sonnet-20241022"
        assert models["fast_thinking"] == "claude-3-5-sonnet-20241022"
        assert models["budget"] == "claude-3-5-haiku-20241022"

    def test_get_google_recommendations(self):
        """Test getting Google model recommendations."""
        models = LLMFactory.get_recommended_models("google")

        assert models["deep_thinking"] == "gemini-1.5-pro"
        assert models["fast_thinking"] == "gemini-1.5-flash"
        assert models["budget"] == "gemini-1.5-flash"

    def test_get_recommendations_case_insensitive(self):
        """Test that get_recommended_models is case-insensitive."""
        models1 = LLMFactory.get_recommended_models("OpenAI")
        models2 = LLMFactory.get_recommended_models("openai")

        assert models1 == models2

    def test_get_recommendations_unknown_provider(self):
        """Test that unknown provider raises ValueError."""
        with pytest.raises(ValueError, match="Unknown provider"):
            LLMFactory.get_recommended_models("unknown_provider")
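
    # Added cross-provider sanity check: the per-provider tests above assert
    # these tiers one by one, so every supported provider should expose the
    # same three core keys.
    @pytest.mark.parametrize("provider", ["openai", "anthropic", "google"])
    def test_recommendations_include_core_tiers(self, provider):
        """Parametrized test: every provider exposes the core recommendation tiers."""
        models = LLMFactory.get_recommended_models(provider)
        for tier in ("deep_thinking", "fast_thinking", "budget"):
            assert tier in models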


class TestProviderValidation:
    """Test provider validation functionality."""

    def test_validate_openai_setup_complete(self):
        """Test validating complete OpenAI setup."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            # Plant a stand-in module so the import check succeeds
            # regardless of what is installed in the test environment.
            with patch.dict("sys.modules", {"langchain_openai": MagicMock()}):
                result = LLMFactory.validate_provider_setup("openai")

                assert result["provider"] == "openai"
                assert result["valid"] is True
                assert result["api_key_set"] is True
                assert result["package_installed"] is True
                assert len(result["errors"]) == 0

    def test_validate_openai_missing_key(self):
        """Test validating OpenAI setup with missing API key."""
        with patch.dict(os.environ, {}, clear=True):
            with patch.dict("sys.modules", {"langchain_openai": MagicMock()}):
                result = LLMFactory.validate_provider_setup("openai")

                assert result["valid"] is False
                assert result["api_key_set"] is False
                assert result["package_installed"] is True
                assert any("OPENAI_API_KEY" in error for error in result["errors"])

    def test_validate_openai_missing_package(self):
        """Test validating OpenAI setup with missing package."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            # A None entry in sys.modules makes `import langchain_openai`
            # raise ImportError; patch.dict restores the previous state on
            # exit, so later tests are not polluted.
            with patch.dict("sys.modules", {"langchain_openai": None}):
                result = LLMFactory.validate_provider_setup("openai")

                assert result["valid"] is False
                assert result["package_installed"] is False
                assert result["api_key_set"] is True
                assert any("Package not installed" in error for error in result["errors"])

    def test_validate_anthropic_setup_complete(self):
        """Test validating complete Anthropic setup."""
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            with patch.dict("sys.modules", {"langchain_anthropic": MagicMock()}):
                result = LLMFactory.validate_provider_setup("anthropic")

                assert result["provider"] == "anthropic"
                assert result["valid"] is True
                assert result["api_key_set"] is True
                assert result["package_installed"] is True

    def test_validate_google_setup_complete(self):
        """Test validating complete Google setup."""
        with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}):
            with patch.dict("sys.modules", {"langchain_google_genai": MagicMock()}):
                result = LLMFactory.validate_provider_setup("google")

                assert result["provider"] == "google"
                assert result["valid"] is True
                assert result["api_key_set"] is True
                assert result["package_installed"] is True
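
    # Added sketch: assumes the missing-key report is symmetric across
    # providers, as the OpenAI missing-key test above demonstrates. Only
    # result keys already exercised elsewhere in this class are asserted.
    @pytest.mark.parametrize("provider", ["anthropic", "google"])
    def test_validate_missing_key_other_providers(self, provider):
        """Parametrized test: missing API keys invalidate the other providers too."""
        with patch.dict(os.environ, {}, clear=True):
            result = LLMFactory.validate_provider_setup(provider)

        assert result["valid"] is False
        assert result["api_key_set"] is False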


class TestConvenienceFunction:
    """Test the convenience create_llm function."""

    @patch("langchain_openai.ChatOpenAI")
    def test_create_llm_defaults_to_openai(self, mock_openai):
        """Test that create_llm defaults to OpenAI."""
        mock_openai.return_value = Mock()

        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            llm = create_llm()

        assert mock_openai.called

    @patch("langchain_openai.ChatOpenAI")
    def test_create_llm_auto_selects_model(self, mock_openai):
        """Test that create_llm auto-selects recommended model."""
        mock_openai.return_value = Mock()

        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            llm = create_llm("openai")

        call_kwargs = mock_openai.call_args[1]
        # Should use recommended deep thinking model
        assert call_kwargs["model"] == "o1-preview"

    @patch("langchain_anthropic.ChatAnthropic")
    def test_create_llm_with_specified_model(self, mock_anthropic):
        """Test create_llm with specified model."""
        mock_anthropic.return_value = Mock()

        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
            llm = create_llm("anthropic", "claude-3-5-haiku-20241022")

        call_kwargs = mock_anthropic.call_args[1]
        assert call_kwargs["model"] == "claude-3-5-haiku-20241022"
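
    # Added sketch: assumes the convenience wrapper forwards extra keyword
    # arguments (here, temperature) straight through to LLMFactory.create_llm.
    @patch("langchain_openai.ChatOpenAI")
    def test_create_llm_forwards_kwargs(self, mock_openai):
        """Test that create_llm forwards extra kwargs to the factory."""
        mock_openai.return_value = Mock()

        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            create_llm("openai", "gpt-4o", temperature=0.2)

        assert mock_openai.call_args[1]["temperature"] == 0.2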


@pytest.mark.parametrize("provider,model,env_var", [
    ("openai", "gpt-4o", "OPENAI_API_KEY"),
    ("anthropic", "claude-3-5-sonnet-20241022", "ANTHROPIC_API_KEY"),
    ("google", "gemini-1.5-pro", "GOOGLE_API_KEY"),
])
def test_all_providers_require_api_key(provider, model, env_var):
    """Parametrized test: all providers require API keys."""
    with patch.dict(os.environ, {}, clear=True):
        with pytest.raises(ValueError, match=env_var):
            LLMFactory.create_llm(provider, model)


@pytest.mark.parametrize("temperature", [0.0, 0.5, 1.0, 1.5, 2.0])
def test_temperature_values(temperature):
    """Parametrized test: various temperature values."""
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
        # Patch the class where it lives: the factory imports
        # langchain_openai lazily (see the missing-package tests), so there
        # is no module-level tradingagents.llm_factory.ChatOpenAI to patch.
        with patch("langchain_openai.ChatOpenAI") as mock_openai:
            mock_openai.return_value = Mock()

            LLMFactory.create_llm("openai", "gpt-4o", temperature=temperature)

            call_kwargs = mock_openai.call_args[1]
            assert call_kwargs["temperature"] == temperature
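

# Added cross-provider sketch: it relies only on the kwarg mapping the
# per-provider tests above already assert, i.e. OpenAI and Anthropic take
# max_tokens while Google renames it to max_output_tokens.
@pytest.mark.parametrize("provider,model,patch_target,env_var,expected_key", [
    ("openai", "gpt-4o",
     "langchain_openai.ChatOpenAI", "OPENAI_API_KEY", "max_tokens"),
    ("anthropic", "claude-3-5-sonnet-20241022",
     "langchain_anthropic.ChatAnthropic", "ANTHROPIC_API_KEY", "max_tokens"),
    ("google", "gemini-1.5-pro",
     "langchain_google_genai.ChatGoogleGenerativeAI", "GOOGLE_API_KEY",
     "max_output_tokens"),
])
def test_max_tokens_maps_to_provider_kwarg(provider, model, patch_target, env_var, expected_key):
    """Parametrized test: max_tokens maps to each provider's expected kwarg."""
    with patch.dict(os.environ, {env_var: "test-key"}):
        with patch(patch_target) as mock_cls:
            mock_cls.return_value = Mock()
            LLMFactory.create_llm(provider, model, max_tokens=1024)

    assert mock_cls.call_args[1][expected_key] == 1024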