"""Settings models and a process-wide settings accessor.

Defines ``DataVendorsConfig`` and ``TradingAgentsSettings`` (a
pydantic-settings model) plus helpers to get, reset and update a lazily
created module-level settings instance.
"""

import os
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, field_validator
from pydantic_settings import BaseSettings

# Vendor name -> un-prefixed environment variable consulted when the matching
# ``<vendor>_api_key`` field was not populated by the settings sources.
# Iteration order matters only for readability; each entry is independent.
_API_KEY_ENV_VARS: dict[str, str] = {
    "openai": "OPENAI_API_KEY",
    "alpha_vantage": "ALPHA_VANTAGE_API_KEY",
    "brave": "BRAVE_API_KEY",
    "tavily": "TAVILY_API_KEY",
    "google": "GOOGLE_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
}


def _default_project_dir() -> str:
    """Return the absolute path of the directory containing this module."""
    return os.path.abspath(os.path.join(os.path.dirname(__file__), "."))


class DataVendorsConfig(BaseModel):
    """Vendor name selected for each data category."""

    core_stock_apis: str = "yfinance"
    technical_indicators: str = "yfinance"
    fundamental_data: str = "alpha_vantage"
    news_data: str = "alpha_vantage"


class TradingAgentsSettings(BaseSettings):
    """Typed application settings.

    Values are read by pydantic-settings using the ``TRADINGAGENTS_`` prefix
    (``__`` separates nested keys) and an optional ``.env`` file; unknown
    keys are ignored. After validation, ``model_post_init`` fills in the
    data cache directory and any API keys that are still ``None``.
    """

    # --- Paths -----------------------------------------------------------
    project_dir: str = Field(default_factory=_default_project_dir)
    results_dir: str = "./results"
    data_dir: str = "./data"
    # Left as None here; derived from ``project_dir`` in ``model_post_init``.
    data_cache_dir: str | None = None

    # --- LLM selection ---------------------------------------------------
    llm_provider: str = "openai"
    deep_think_llm: str = "gpt-5"
    quick_think_llm: str = "gpt-5-mini"
    backend_url: str = "https://api.openai.com/v1"

    # --- Round / recursion limits (bounds enforced by pydantic) ----------
    max_debate_rounds: int = Field(default=2, ge=1, le=10)
    max_risk_discuss_rounds: int = Field(default=2, ge=1, le=10)
    max_recur_limit: int = Field(default=100, ge=10, le=500)

    # --- Discovery -------------------------------------------------------
    discovery_timeout: int = Field(default=60, ge=10)
    discovery_hard_timeout: int = Field(default=120, ge=30)
    discovery_cache_ttl: int = Field(default=300, ge=60)
    discovery_max_results: int = Field(default=20, ge=1, le=100)
    discovery_min_mentions: int = Field(default=2, ge=1)

    # --- Bulk news -------------------------------------------------------
    bulk_news_vendor_order: list[str] = Field(
        default=["tavily", "brave", "alpha_vantage", "openai", "google"]
    )
    bulk_news_timeout: int = Field(default=30, ge=5)
    bulk_news_max_retries: int = Field(default=3, ge=1)

    # --- Logging ---------------------------------------------------------
    log_level: str = "INFO"
    log_dir: str = "./logs"
    log_console_enabled: bool = True
    log_file_enabled: bool = True

    # --- API keys (None until supplied or resolved in model_post_init) ---
    openai_api_key: str | None = None
    alpha_vantage_api_key: str | None = None
    brave_api_key: str | None = None
    tavily_api_key: str | None = None
    google_api_key: str | None = None
    anthropic_api_key: str | None = None

    # --- Vendor routing --------------------------------------------------
    data_vendors: DataVendorsConfig = Field(default_factory=DataVendorsConfig)
    tool_vendors: dict[str, Any] = Field(default_factory=dict)

    model_config = {
        "env_prefix": "TRADINGAGENTS_",
        "env_nested_delimiter": "__",
        "extra": "ignore",
        "env_file": ".env",
        "env_file_encoding": "utf-8",
    }

    def model_post_init(self, __context: Any) -> None:
        """Fill in derived defaults once field validation has finished.

        * ``data_cache_dir`` defaults to ``<project_dir>/dataflows/data_cache``.
        * Each API key that is still ``None`` is looked up under its
          un-prefixed environment variable (e.g. ``OPENAI_API_KEY``); it
          stays ``None`` when that variable is not set either.
        """
        if self.data_cache_dir is None:
            self.data_cache_dir = os.path.join(
                self.project_dir, "dataflows/data_cache"
            )

        for vendor_name, env_name in _API_KEY_ENV_VARS.items():
            field_name = f"{vendor_name}_api_key"
            if getattr(self, field_name) is None:
                setattr(self, field_name, os.getenv(env_name))

    @field_validator("log_level")
    @classmethod
    def validate_log_level(cls, v: str) -> str:
        """Upper-case the log level, rejecting names outside the known set.

        Raises:
            ValueError: if the upper-cased value is not a recognised level.
        """
        valid_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
        normalized = v.upper()
        if normalized in valid_levels:
            return normalized
        raise ValueError(f"Invalid log level: {v}. Must be one of {valid_levels}")

    @field_validator("llm_provider")
    @classmethod
    def validate_llm_provider(cls, v: str) -> str:
        """Lower-case the provider name, rejecting unsupported providers.

        Raises:
            ValueError: if the lower-cased value is not a supported provider.
        """
        valid_providers = {"openai", "anthropic", "google", "ollama", "openrouter"}
        normalized = v.lower()
        if normalized in valid_providers:
            return normalized
        raise ValueError(
            f"Invalid LLM provider: {v}. Must be one of {valid_providers}"
        )

    def to_dict(self) -> dict[str, Any]:
        """Return the settings as a plain dict, with ``data_vendors`` as a dict."""
        return {
            **self.model_dump(),
            "data_vendors": self.data_vendors.model_dump(),
        }

    def get_api_key(self, vendor: str) -> str | None:
        """Return the stored API key for ``vendor`` (case-insensitive).

        Returns ``None`` both for vendors without a key field and for known
        vendors whose key is unset.
        """
        vendor_name = vendor.lower()
        if vendor_name not in _API_KEY_ENV_VARS:
            return None
        return getattr(self, f"{vendor_name}_api_key")

    def require_api_key(self, vendor: str) -> str:
        """Return the API key for ``vendor`` or fail with a setup hint.

        Raises:
            ValueError: if no non-empty key is available for ``vendor``.
        """
        key = self.get_api_key(vendor)
        if key:
            return key

        env_var = f"{vendor.upper()}_API_KEY"
        raise ValueError(
            f"{vendor} API key not configured. "
            f"Set {env_var} environment variable or TRADINGAGENTS_{env_var}."
        )


# Lazily created process-wide settings instance; managed by the functions below.
_settings: TradingAgentsSettings | None = None


def get_settings() -> TradingAgentsSettings:
    """Return the shared settings instance, creating it on first use."""
    global _settings
    if _settings is not None:
        return _settings
    _settings = TradingAgentsSettings()
    return _settings


def reset_settings() -> None:
    """Discard the shared instance so the next ``get_settings`` rebuilds it."""
    global _settings
    _settings = None


def update_settings(**kwargs: Any) -> TradingAgentsSettings:
    """Replace the shared instance with one that has ``kwargs`` applied.

    The current settings are dumped to a dict, overlaid with ``kwargs``
    (top-level keys only), and passed through ``TradingAgentsSettings`` again
    so that validation is re-run on the merged values.
    """
    global _settings
    merged = {**get_settings().model_dump(), **kwargs}
    _settings = TradingAgentsSettings(**merged)
    return _settings