# TradingAgents/tests/test_config.py
#
# File-viewer header captured with the source: 170 lines, 5.3 KiB, Python

import os
from unittest.mock import patch
import pytest
from tradingagents.config import (
DataVendorsConfig,
TradingAgentsSettings,
get_settings,
reset_settings,
update_settings,
)
class TestDataVendorsConfig:
    """Checks on the vendor choices stored by ``DataVendorsConfig``."""

    def test_default_values(self):
        """A config built without arguments exposes the expected vendor names."""
        expected_defaults = {
            "core_stock_apis": "yfinance",
            "technical_indicators": "yfinance",
            "fundamental_data": "alpha_vantage",
            "news_data": "alpha_vantage",
        }
        vendors = DataVendorsConfig()

        for field_name, vendor_name in expected_defaults.items():
            assert getattr(vendors, field_name) == vendor_name

    def test_custom_values(self):
        """Keyword arguments given to the constructor are readable afterwards."""
        overrides = {"core_stock_apis": "local", "news_data": "openai"}
        vendors = DataVendorsConfig(**overrides)

        assert vendors.core_stock_apis == "local"
        assert vendors.news_data == "openai"
class TestTradingAgentsSettings:
    """Checks on ``TradingAgentsSettings``: defaults, validation, and API keys."""

    def setup_method(self):
        """Reset the shared settings before each test."""
        reset_settings()

    def teardown_method(self):
        """Reset the shared settings after each test."""
        reset_settings()

    def test_default_values(self):
        """A settings object built without arguments carries the expected defaults."""
        settings = TradingAgentsSettings()
        assert settings.llm_provider == "openai"
        assert settings.log_level == "INFO"
        assert settings.max_debate_rounds == 2

    def test_log_level_validation(self):
        """A lower-case level comes back upper-cased; an upper-case one is kept."""
        settings = TradingAgentsSettings(log_level="debug")
        assert settings.log_level == "DEBUG"

        settings = TradingAgentsSettings(log_level="WARNING")
        assert settings.log_level == "WARNING"

    def test_log_level_invalid(self):
        """An unrecognised level raises ``ValueError`` matching "Invalid log level"."""
        with pytest.raises(ValueError, match="Invalid log level"):
            TradingAgentsSettings(log_level="INVALID")

    def test_llm_provider_validation(self):
        """Provider names given in upper or mixed case come back lower-cased."""
        settings = TradingAgentsSettings(llm_provider="OPENAI")
        assert settings.llm_provider == "openai"

        settings = TradingAgentsSettings(llm_provider="Anthropic")
        assert settings.llm_provider == "anthropic"

    def test_llm_provider_invalid(self):
        """An unknown provider raises ``ValueError`` matching "Invalid LLM provider"."""
        with pytest.raises(ValueError, match="Invalid LLM provider"):
            TradingAgentsSettings(llm_provider="invalid_provider")

    def test_to_dict(self):
        """``to_dict`` returns a dict holding the provider and a nested vendors dict."""
        settings = TradingAgentsSettings()
        result = settings.to_dict()

        assert isinstance(result, dict)
        assert "llm_provider" in result
        assert "data_vendors" in result
        assert isinstance(result["data_vendors"], dict)

    def test_get_api_key_returns_value(self):
        """A key passed to the constructor is returned by ``get_api_key``."""
        settings = TradingAgentsSettings(openai_api_key="test-key")
        assert settings.get_api_key("openai") == "test-key"

    def test_get_api_key_returns_none_when_not_set(self):
        """With an empty environment and no stored key, ``get_api_key`` gives ``None``."""
        with patch.dict(os.environ, {}, clear=True):
            settings = TradingAgentsSettings()
            settings.openai_api_key = None
            assert settings.get_api_key("openai") is None

    def test_require_api_key_raises_when_not_set(self):
        """``require_api_key`` raises ``ValueError`` when no key is stored."""
        # Empty the environment exactly as the get_api_key "not set" test does,
        # so a key exported on the developer's machine cannot affect the result.
        # NOTE(review): assumes key lookup may consult os.environ - confirm
        # against TradingAgentsSettings.
        with patch.dict(os.environ, {}, clear=True):
            settings = TradingAgentsSettings()
            settings.brave_api_key = None
            with pytest.raises(ValueError, match="brave API key not configured"):
                settings.require_api_key("brave")

    def test_require_api_key_returns_value(self):
        """A key passed to the constructor is returned by ``require_api_key``."""
        settings = TradingAgentsSettings(tavily_api_key="tavily-test")
        assert settings.require_api_key("tavily") == "tavily-test"

    def test_max_debate_rounds_bounds(self):
        """A value of 5 is accepted; values of 0 and 20 raise ``ValueError``."""
        settings = TradingAgentsSettings(max_debate_rounds=5)
        assert settings.max_debate_rounds == 5

        with pytest.raises(ValueError):
            TradingAgentsSettings(max_debate_rounds=0)

        with pytest.raises(ValueError):
            TradingAgentsSettings(max_debate_rounds=20)
class TestConfigFunctions:
    """Checks on the module-level helpers that manage the shared settings object."""

    def setup_method(self):
        """Reset the shared settings before each test."""
        reset_settings()

    def teardown_method(self):
        """Reset the shared settings after each test."""
        reset_settings()

    def test_get_settings_returns_singleton(self):
        """Two consecutive ``get_settings`` calls hand back the same object."""
        first = get_settings()
        second = get_settings()

        assert first is second

    def test_reset_settings_clears_singleton(self):
        """After ``reset_settings``, ``get_settings`` hands back a different object."""
        before_reset = get_settings()
        reset_settings()
        after_reset = get_settings()

        assert before_reset is not after_reset

    def test_update_settings_modifies_values(self):
        """A value given to ``update_settings`` is visible on the next ``get_settings``."""
        assert get_settings().max_debate_rounds == 2

        update_settings(max_debate_rounds=5)

        assert get_settings().max_debate_rounds == 5

    def test_update_settings_preserves_other_values(self):
        """Updating one field leaves ``llm_provider`` as it was before the update."""
        provider_before = get_settings().llm_provider

        update_settings(max_debate_rounds=5)

        assert get_settings().llm_provider == provider_before
class TestDataflowConfigCompat:
    """Checks on ``get_config`` / ``set_config`` from ``tradingagents.dataflows.config``."""

    def setup_method(self):
        """Reset the shared settings before each test."""
        reset_settings()

    def teardown_method(self):
        """Reset the shared settings after each test."""
        reset_settings()

    def test_get_config_returns_dict(self):
        """``get_config`` returns a dict that includes the provider and vendor keys."""
        from tradingagents.dataflows.config import get_config

        snapshot = get_config()

        assert isinstance(snapshot, dict)
        for required_key in ("llm_provider", "data_vendors"):
            assert required_key in snapshot

    def test_set_config_updates_central_settings(self):
        """A value passed to ``set_config`` shows up in a later ``get_config``."""
        from tradingagents.dataflows.config import get_config, set_config

        set_config({"max_debate_rounds": 4})

        assert get_config()["max_debate_rounds"] == 4

    def test_get_config_returns_copy(self):
        """Mutating a returned dict does not change what ``get_config`` returns next."""
        from tradingagents.dataflows.config import get_config

        first = get_config()
        second = get_config()
        assert first == second

        # Change only the local dict, then confirm a fresh read is unaffected.
        first["llm_provider"] = "modified"
        assert get_config()["llm_provider"] != "modified"