diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 00000000..6e72a962 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,168 @@ +import pytest +import os +from unittest.mock import patch + +from tradingagents.config import ( + TradingAgentsSettings, + DataVendorsConfig, + get_settings, + reset_settings, + update_settings, +) + + +class TestDataVendorsConfig: + def test_default_values(self): + config = DataVendorsConfig() + assert config.core_stock_apis == "yfinance" + assert config.technical_indicators == "yfinance" + assert config.fundamental_data == "alpha_vantage" + assert config.news_data == "alpha_vantage" + + def test_custom_values(self): + config = DataVendorsConfig(core_stock_apis="local", news_data="openai") + assert config.core_stock_apis == "local" + assert config.news_data == "openai" + + +class TestTradingAgentsSettings: + def setup_method(self): + reset_settings() + + def teardown_method(self): + reset_settings() + + def test_default_values(self): + settings = TradingAgentsSettings() + assert settings.llm_provider == "openai" + assert settings.log_level == "INFO" + assert settings.max_debate_rounds == 2 + + def test_log_level_validation(self): + settings = TradingAgentsSettings(log_level="debug") + assert settings.log_level == "DEBUG" + + settings = TradingAgentsSettings(log_level="WARNING") + assert settings.log_level == "WARNING" + + def test_log_level_invalid(self): + with pytest.raises(ValueError, match="Invalid log level"): + TradingAgentsSettings(log_level="INVALID") + + def test_llm_provider_validation(self): + settings = TradingAgentsSettings(llm_provider="OPENAI") + assert settings.llm_provider == "openai" + + settings = TradingAgentsSettings(llm_provider="Anthropic") + assert settings.llm_provider == "anthropic" + + def test_llm_provider_invalid(self): + with pytest.raises(ValueError, match="Invalid LLM provider"): + TradingAgentsSettings(llm_provider="invalid_provider") + + def test_to_dict(self): + settings = TradingAgentsSettings() + result = settings.to_dict() + + assert isinstance(result, dict) + assert "llm_provider" in result + assert "data_vendors" in result + assert isinstance(result["data_vendors"], dict) + + def test_get_api_key_returns_value(self): + settings = TradingAgentsSettings(openai_api_key="test-key") + assert settings.get_api_key("openai") == "test-key" + + def test_get_api_key_returns_none_when_not_set(self): + with patch.dict(os.environ, {}, clear=True): + settings = TradingAgentsSettings() + settings.openai_api_key = None + assert settings.get_api_key("openai") is None + + def test_require_api_key_raises_when_not_set(self): + settings = TradingAgentsSettings() + settings.brave_api_key = None + with pytest.raises(ValueError, match="brave API key not configured"): + settings.require_api_key("brave") + + def test_require_api_key_returns_value(self): + settings = TradingAgentsSettings(tavily_api_key="tavily-test") + assert settings.require_api_key("tavily") == "tavily-test" + + def test_max_debate_rounds_bounds(self): + settings = TradingAgentsSettings(max_debate_rounds=5) + assert settings.max_debate_rounds == 5 + + with pytest.raises(ValueError): + TradingAgentsSettings(max_debate_rounds=0) + + with pytest.raises(ValueError): + TradingAgentsSettings(max_debate_rounds=20) + + +class TestConfigFunctions: + def setup_method(self): + reset_settings() + + def teardown_method(self): + reset_settings() + + def test_get_settings_returns_singleton(self): + s1 = get_settings() + s2 = get_settings() + assert s1 is s2 + + def test_reset_settings_clears_singleton(self): + s1 = get_settings() + reset_settings() + s2 = get_settings() + assert s1 is not s2 + + def test_update_settings_modifies_values(self): + original = get_settings() + assert original.max_debate_rounds == 2 + + update_settings(max_debate_rounds=5) + updated = get_settings() + assert updated.max_debate_rounds == 5 + + def test_update_settings_preserves_other_values(self): + original = get_settings() + original_provider = original.llm_provider + + update_settings(max_debate_rounds=5) + updated = get_settings() + assert updated.llm_provider == original_provider + + +class TestDataflowConfigCompat: + def setup_method(self): + reset_settings() + + def teardown_method(self): + reset_settings() + + def test_get_config_returns_dict(self): + from tradingagents.dataflows.config import get_config + + config = get_config() + assert isinstance(config, dict) + assert "llm_provider" in config + assert "data_vendors" in config + + def test_set_config_updates_central_settings(self): + from tradingagents.dataflows.config import get_config, set_config + + set_config({"max_debate_rounds": 4}) + config = get_config() + assert config["max_debate_rounds"] == 4 + + def test_get_config_returns_copy(self): + from tradingagents.dataflows.config import get_config + + c1 = get_config() + c2 = get_config() + assert c1 == c2 + c1["llm_provider"] = "modified" + c3 = get_config() + assert c3["llm_provider"] != "modified" diff --git a/tradingagents/agents/discovery/entity_extractor.py b/tradingagents/agents/discovery/entity_extractor.py index d4efd7c3..88938883 100644 --- a/tradingagents/agents/discovery/entity_extractor.py +++ b/tradingagents/agents/discovery/entity_extractor.py @@ -6,7 +6,7 @@ from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_google_genai import ChatGoogleGenerativeAI -from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.dataflows.config import get_config from tradingagents.agents.discovery.models import NewsArticle, EventCategory @@ -37,7 +37,7 @@ class ExtractionResponse(BaseModel): def _get_llm(config: Optional[dict] = None): - cfg = config or DEFAULT_CONFIG + cfg = config or get_config() provider = cfg.get("llm_provider", "openai").lower() model = cfg.get("quick_think_llm", "gpt-4o-mini") backend_url = cfg.get("backend_url", "https://api.openai.com/v1") diff --git a/tradingagents/config.py b/tradingagents/config.py new file mode 100644 index 00000000..7e434ca2 --- /dev/null +++ b/tradingagents/config.py @@ -0,0 +1,147 @@ +import os +from typing import Optional, Dict, Any, List +from pydantic import BaseModel, Field, field_validator +from pydantic_settings import BaseSettings + + +class DataVendorsConfig(BaseModel): + core_stock_apis: str = "yfinance" + technical_indicators: str = "yfinance" + fundamental_data: str = "alpha_vantage" + news_data: str = "alpha_vantage" + + +class TradingAgentsSettings(BaseSettings): + project_dir: str = Field( + default_factory=lambda: os.path.abspath(os.path.join(os.path.dirname(__file__), ".")) + ) + results_dir: str = Field(default="./results") + data_dir: str = Field(default="./data") + data_cache_dir: Optional[str] = None + + llm_provider: str = Field(default="openai") + deep_think_llm: str = Field(default="gpt-5") + quick_think_llm: str = Field(default="gpt-5-mini") + backend_url: str = Field(default="https://api.openai.com/v1") + + max_debate_rounds: int = Field(default=2, ge=1, le=10) + max_risk_discuss_rounds: int = Field(default=2, ge=1, le=10) + max_recur_limit: int = Field(default=100, ge=10, le=500) + + discovery_timeout: int = Field(default=60, ge=10) + discovery_hard_timeout: int = Field(default=120, ge=30) + discovery_cache_ttl: int = Field(default=300, ge=60) + discovery_max_results: int = Field(default=20, ge=1, le=100) + discovery_min_mentions: int = Field(default=2, ge=1) + + bulk_news_vendor_order: List[str] = Field( + default=["tavily", "brave", "alpha_vantage", "openai", "google"] + ) + bulk_news_timeout: int = Field(default=30, ge=5) + bulk_news_max_retries: int = Field(default=3, ge=1) + + log_level: str = Field(default="INFO") + log_dir: str = Field(default="./logs") + log_console_enabled: bool = Field(default=True) + log_file_enabled: bool = Field(default=True) + + openai_api_key: Optional[str] = Field(default=None) + alpha_vantage_api_key: Optional[str] = Field(default=None) + brave_api_key: Optional[str] = Field(default=None) + tavily_api_key: Optional[str] = Field(default=None) + google_api_key: Optional[str] = Field(default=None) + anthropic_api_key: Optional[str] = Field(default=None) + + data_vendors: DataVendorsConfig = Field(default_factory=DataVendorsConfig) + tool_vendors: Dict[str, Any] = Field(default_factory=dict) + + model_config = { + "env_prefix": "TRADINGAGENTS_", + "env_nested_delimiter": "__", + "extra": "ignore", + "env_file": ".env", + "env_file_encoding": "utf-8", + } + + def model_post_init(self, __context: Any) -> None: + if self.data_cache_dir is None: + self.data_cache_dir = os.path.join(self.project_dir, "dataflows/data_cache") + + if self.openai_api_key is None: + self.openai_api_key = os.getenv("OPENAI_API_KEY") + if self.alpha_vantage_api_key is None: + self.alpha_vantage_api_key = os.getenv("ALPHA_VANTAGE_API_KEY") + if self.brave_api_key is None: + self.brave_api_key = os.getenv("BRAVE_API_KEY") + if self.tavily_api_key is None: + self.tavily_api_key = os.getenv("TAVILY_API_KEY") + if self.google_api_key is None: + self.google_api_key = os.getenv("GOOGLE_API_KEY") + if self.anthropic_api_key is None: + self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") + + @field_validator("log_level") + @classmethod + def validate_log_level(cls, v: str) -> str: + valid_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} + if v.upper() not in valid_levels: + raise ValueError(f"Invalid log level: {v}. Must be one of {valid_levels}") + return v.upper() + + @field_validator("llm_provider") + @classmethod + def validate_llm_provider(cls, v: str) -> str: + valid_providers = {"openai", "anthropic", "google", "ollama", "openrouter"} + if v.lower() not in valid_providers: + raise ValueError(f"Invalid LLM provider: {v}. Must be one of {valid_providers}") + return v.lower() + + def to_dict(self) -> Dict[str, Any]: + result = self.model_dump() + result["data_vendors"] = self.data_vendors.model_dump() + return result + + def get_api_key(self, vendor: str) -> Optional[str]: + key_map = { + "openai": self.openai_api_key, + "alpha_vantage": self.alpha_vantage_api_key, + "brave": self.brave_api_key, + "tavily": self.tavily_api_key, + "google": self.google_api_key, + "anthropic": self.anthropic_api_key, + } + return key_map.get(vendor.lower()) + + def require_api_key(self, vendor: str) -> str: + key = self.get_api_key(vendor) + if not key: + env_var = f"{vendor.upper()}_API_KEY" + raise ValueError( + f"{vendor} API key not configured. " + f"Set {env_var} environment variable or TRADINGAGENTS_{env_var}." + ) + return key + + +_settings: Optional[TradingAgentsSettings] = None + + +def get_settings() -> TradingAgentsSettings: + global _settings + if _settings is None: + _settings = TradingAgentsSettings() + return _settings + + +def reset_settings() -> None: + global _settings + _settings = None + + +def update_settings(**kwargs) -> TradingAgentsSettings: + global _settings + current = get_settings() + new_values = current.model_dump() + new_values.update(kwargs) + _settings = TradingAgentsSettings(**new_values) + return _settings diff --git a/tradingagents/dataflows/alpha_vantage_common.py b/tradingagents/dataflows/alpha_vantage_common.py index 9d4ac5d9..d52259e8 100644 --- a/tradingagents/dataflows/alpha_vantage_common.py +++ b/tradingagents/dataflows/alpha_vantage_common.py @@ -11,10 +11,14 @@ logger = logging.getLogger(__name__) API_BASE_URL = "https://www.alphavantage.co/query" def get_api_key() -> str: - api_key = os.getenv("ALPHA_VANTAGE_API_KEY") - if not api_key: - raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.") - return api_key + try: + from tradingagents.config import get_settings + return get_settings().require_api_key("alpha_vantage") + except ImportError: + api_key = os.getenv("ALPHA_VANTAGE_API_KEY") + if not api_key: + raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.") + return api_key def format_datetime_for_api(date_input) -> str: if isinstance(date_input, str): diff --git a/tradingagents/dataflows/brave.py b/tradingagents/dataflows/brave.py index 5135b21e..cea64609 100644 --- a/tradingagents/dataflows/brave.py +++ b/tradingagents/dataflows/brave.py @@ -14,10 +14,14 @@ RETRY_BACKOFF = 1.0 def get_api_key() -> str: - api_key = os.getenv("BRAVE_API_KEY") - if not api_key: - raise ValueError("BRAVE_API_KEY environment variable is not set.") - return api_key + try: + from tradingagents.config import get_settings + return get_settings().require_api_key("brave") + except ImportError: + api_key = os.getenv("BRAVE_API_KEY") + if not api_key: + raise ValueError("BRAVE_API_KEY environment variable is not set.") + return api_key def _make_request_with_retry(url: str, headers: Dict, params: Dict, max_retries: int = MAX_RETRIES) -> requests.Response: diff --git a/tradingagents/dataflows/config.py b/tradingagents/dataflows/config.py index b8a8f8aa..d25ca09c 100644 --- a/tradingagents/dataflows/config.py +++ b/tradingagents/dataflows/config.py @@ -1,34 +1,35 @@ -import tradingagents.default_config as default_config from typing import Dict, Optional +from tradingagents.config import get_settings, update_settings, TradingAgentsSettings -# Use default config but allow it to be overridden _config: Optional[Dict] = None DATA_DIR: Optional[str] = None def initialize_config(): - """Initialize the configuration with default values.""" global _config, DATA_DIR if _config is None: - _config = default_config.DEFAULT_CONFIG.copy() + settings = get_settings() + _config = settings.to_dict() DATA_DIR = _config["data_dir"] def set_config(config: Dict): - """Update the configuration with custom values.""" global _config, DATA_DIR - if _config is None: - _config = default_config.DEFAULT_CONFIG.copy() - _config.update(config) + + settings = get_settings() + current_dict = settings.to_dict() + current_dict.update(config) + update_settings(**current_dict) + + _config = get_settings().to_dict() DATA_DIR = _config["data_dir"] def get_config() -> Dict: - """Get the current configuration.""" + global _config if _config is None: initialize_config() return _config.copy() -# Initialize with default config initialize_config() diff --git a/tradingagents/dataflows/tavily.py b/tradingagents/dataflows/tavily.py index 0599b325..dfb00110 100644 --- a/tradingagents/dataflows/tavily.py +++ b/tradingagents/dataflows/tavily.py @@ -18,10 +18,14 @@ RETRY_BACKOFF = 1.0 def get_api_key() -> str: - api_key = os.getenv("TAVILY_API_KEY") - if not api_key: - raise ValueError("TAVILY_API_KEY environment variable is not set.") - return api_key + try: + from tradingagents.config import get_settings + return get_settings().require_api_key("tavily") + except ImportError: + api_key = os.getenv("TAVILY_API_KEY") + if not api_key: + raise ValueError("TAVILY_API_KEY environment variable is not set.") + return api_key def _search_with_retry(client, query: str, search_depth: str, topic: str, time_range: str, max_results: int, max_retries: int = MAX_RETRIES) -> Dict[str, Any]: diff --git a/tradingagents/dataflows/trending/sector_classifier.py b/tradingagents/dataflows/trending/sector_classifier.py index a8df7533..62d2847c 100644 --- a/tradingagents/dataflows/trending/sector_classifier.py +++ b/tradingagents/dataflows/trending/sector_classifier.py @@ -205,11 +205,12 @@ _sector_cache: Dict[str, str] = {} def _llm_classify_sector(ticker: str) -> str: from langchain_openai import ChatOpenAI from langchain_core.messages import HumanMessage, SystemMessage - from tradingagents.default_config import DEFAULT_CONFIG + from tradingagents.dataflows.config import get_config - llm_name = DEFAULT_CONFIG.get("quick_think_llm", "gpt-4o-mini") - llm_provider = DEFAULT_CONFIG.get("llm_provider", "openai") - backend_url = DEFAULT_CONFIG.get("backend_url", "https://api.openai.com/v1") + config = get_config() + llm_name = config.get("quick_think_llm", "gpt-4o-mini") + llm_provider = config.get("llm_provider", "openai") + backend_url = config.get("backend_url", "https://api.openai.com/v1") llm = ChatOpenAI( model=llm_name, diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 8bc71d8a..3caee5af 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -14,7 +14,7 @@ from langchain_google_genai import ChatGoogleGenerativeAI from langgraph.prebuilt import ToolNode from tradingagents.agents import * -from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.dataflows.config import get_config from tradingagents.agents.utils.memory import FinancialSituationMemory from tradingagents.agents.utils.agent_states import ( AgentState, @@ -76,7 +76,7 @@ class TradingAgentsGraph: config: Dict[str, Any] = None, ): self.debug = debug - self.config = config or DEFAULT_CONFIG + self.config = config or get_config() set_config(self.config) diff --git a/tradingagents/logging.py b/tradingagents/logging.py index 37a86c55..6caea9c2 100644 --- a/tradingagents/logging.py +++ b/tradingagents/logging.py @@ -4,8 +4,6 @@ import os import json from datetime import datetime -LOG_LEVEL_DEFAULT = "INFO" -LOG_DIR_DEFAULT = "./logs" LOG_FILE_NAME = "tradingagents.log" LOG_MAX_BYTES = 10 * 1024 * 1024 LOG_BACKUP_COUNT = 5 @@ -31,44 +29,29 @@ class JSONFormatter(logging.Formatter): return json.dumps(log_record) -def _parse_bool(value): - if isinstance(value, bool): - return value - if isinstance(value, str): - return value.lower() in ("true", "1", "yes", "on") - return bool(value) - - -def _get_config_value(key, default): +def _get_settings(): try: - from tradingagents.default_config import DEFAULT_CONFIG - return DEFAULT_CONFIG.get(key, default) + from tradingagents.config import get_settings + return get_settings() except ImportError: - return default + return None def setup_logging(): global _logging_initialized - log_level_str = os.getenv("TRADINGAGENTS_LOG_LEVEL") - if log_level_str is None: - log_level_str = _get_config_value("log_level", LOG_LEVEL_DEFAULT) + settings = _get_settings() - log_dir = os.getenv("TRADINGAGENTS_LOG_DIR") - if log_dir is None: - log_dir = _get_config_value("log_dir", LOG_DIR_DEFAULT) - - console_enabled_env = os.getenv("TRADINGAGENTS_LOG_CONSOLE") - if console_enabled_env is not None: - console_enabled = _parse_bool(console_enabled_env) + if settings: + log_level_str = settings.log_level + log_dir = settings.log_dir + console_enabled = settings.log_console_enabled + file_enabled = settings.log_file_enabled else: - console_enabled = _get_config_value("log_console_enabled", True) - - file_enabled_env = os.getenv("TRADINGAGENTS_LOG_FILE") - if file_enabled_env is not None: - file_enabled = _parse_bool(file_enabled_env) - else: - file_enabled = _get_config_value("log_file_enabled", True) + log_level_str = os.getenv("TRADINGAGENTS_LOG_LEVEL", "INFO") + log_dir = os.getenv("TRADINGAGENTS_LOG_DIR", "./logs") + console_enabled = os.getenv("TRADINGAGENTS_LOG_CONSOLE", "true").lower() in ("true", "1", "yes", "on") + file_enabled = os.getenv("TRADINGAGENTS_LOG_FILE", "true").lower() in ("true", "1", "yes", "on") log_level = getattr(logging, log_level_str.upper(), logging.INFO)