# Source: TradingAgents/tradingagents/config.py
# (148 lines, 5.3 KiB, Python)

import os
from typing import Optional, Dict, Any, List
from pydantic import BaseModel, Field, field_validator
from pydantic_settings import BaseSettings
class DataVendorsConfig(BaseModel):
    """Vendor name selected for each data category.

    Every field is a plain string with a default; no validation of the
    vendor names is performed in this model.
    """

    # Defaults to "yfinance" for price/indicator style categories.
    core_stock_apis: str = Field(default="yfinance")
    technical_indicators: str = Field(default="yfinance")
    # Defaults to "alpha_vantage" for fundamentals and news.
    fundamental_data: str = Field(default="alpha_vantage")
    news_data: str = Field(default="alpha_vantage")
# Vendors that have a matching ``<vendor>_api_key`` field on
# TradingAgentsSettings. Shared by the env fallback and the key lookups so the
# three places cannot drift apart.
_API_KEY_VENDORS = (
    "openai",
    "alpha_vantage",
    "brave",
    "tavily",
    "google",
    "anthropic",
)


class TradingAgentsSettings(BaseSettings):
    """Central TradingAgents configuration, loaded through pydantic-settings.

    Values may be supplied as constructor keyword arguments, as environment
    variables prefixed with ``TRADINGAGENTS_`` (``__`` separates nested keys,
    e.g. ``TRADINGAGENTS_DATA_VENDORS__NEWS_DATA``), or through a ``.env``
    file. Unknown keys are ignored. API keys that are still unset after
    loading fall back to the unprefixed ``<VENDOR>_API_KEY`` environment
    variables (see ``model_post_init``).
    """

    # --- Paths ---------------------------------------------------------
    # Defaults to the directory that contains this module.
    project_dir: str = Field(
        default_factory=lambda: os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
    )
    results_dir: str = Field(default="./results")
    data_dir: str = Field(default="./data")
    # Filled in by model_post_init when left as None.
    data_cache_dir: Optional[str] = None

    # --- LLM selection -------------------------------------------------
    llm_provider: str = Field(default="openai")
    deep_think_llm: str = Field(default="gpt-5")
    quick_think_llm: str = Field(default="gpt-5-mini")
    backend_url: str = Field(default="https://api.openai.com/v1")

    # --- Round / recursion limits (bounds enforced by pydantic) ---------
    max_debate_rounds: int = Field(default=2, ge=1, le=10)
    max_risk_discuss_rounds: int = Field(default=2, ge=1, le=10)
    max_recur_limit: int = Field(default=100, ge=10, le=500)

    # --- Discovery tuning ----------------------------------------------
    # NOTE(review): timeouts/TTL are presumably in seconds -- confirm with callers.
    discovery_timeout: int = Field(default=60, ge=10)
    discovery_hard_timeout: int = Field(default=120, ge=30)
    discovery_cache_ttl: int = Field(default=300, ge=60)
    discovery_max_results: int = Field(default=20, ge=1, le=100)
    discovery_min_mentions: int = Field(default=2, ge=1)

    # --- Bulk news -----------------------------------------------------
    bulk_news_vendor_order: List[str] = Field(
        default=["tavily", "brave", "alpha_vantage", "openai", "google"]
    )
    bulk_news_timeout: int = Field(default=30, ge=5)
    bulk_news_max_retries: int = Field(default=3, ge=1)

    # --- Logging -------------------------------------------------------
    log_level: str = Field(default="INFO")
    log_dir: str = Field(default="./logs")
    log_console_enabled: bool = Field(default=True)
    log_file_enabled: bool = Field(default=True)

    # --- API keys (one per entry in _API_KEY_VENDORS) -------------------
    openai_api_key: Optional[str] = Field(default=None)
    alpha_vantage_api_key: Optional[str] = Field(default=None)
    brave_api_key: Optional[str] = Field(default=None)
    tavily_api_key: Optional[str] = Field(default=None)
    google_api_key: Optional[str] = Field(default=None)
    anthropic_api_key: Optional[str] = Field(default=None)

    # --- Vendor routing ------------------------------------------------
    data_vendors: DataVendorsConfig = Field(default_factory=DataVendorsConfig)
    tool_vendors: Dict[str, Any] = Field(default_factory=dict)

    model_config = {
        "env_prefix": "TRADINGAGENTS_",
        "env_nested_delimiter": "__",
        "extra": "ignore",
        "env_file": ".env",
        "env_file_encoding": "utf-8",
    }

    def model_post_init(self, __context: Any) -> None:
        """Fill in values that depend on other fields or on unprefixed env vars.

        * ``data_cache_dir`` defaults to ``<project_dir>/dataflows/data_cache``.
        * Each API key that is still ``None`` is read from the conventional
          unprefixed variable (``OPENAI_API_KEY``, ``ALPHA_VANTAGE_API_KEY``,
          ...); it stays ``None`` when that variable is not set either.
        """
        if self.data_cache_dir is None:
            self.data_cache_dir = os.path.join(self.project_dir, "dataflows/data_cache")

        for vendor in _API_KEY_VENDORS:
            field_name = f"{vendor}_api_key"
            if getattr(self, field_name) is None:
                # The env var is the upper-cased field name, e.g. OPENAI_API_KEY.
                setattr(self, field_name, os.getenv(field_name.upper()))

    @field_validator("log_level")
    @classmethod
    def validate_log_level(cls, v: str) -> str:
        """Return ``v`` upper-cased, rejecting anything that is not a known level.

        Raises:
            ValueError: if ``v`` (case-insensitively) is not a valid level.
        """
        # A tuple (not a set) keeps the error message order stable across runs.
        valid_levels = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
        level = v.upper()
        if level not in valid_levels:
            raise ValueError(
                f"Invalid log level: {v}. Must be one of {', '.join(valid_levels)}"
            )
        return level

    @field_validator("llm_provider")
    @classmethod
    def validate_llm_provider(cls, v: str) -> str:
        """Return ``v`` lower-cased, rejecting unsupported providers.

        Raises:
            ValueError: if ``v`` (case-insensitively) is not a known provider.
        """
        # A tuple (not a set) keeps the error message order stable across runs.
        valid_providers = ("openai", "anthropic", "google", "ollama", "openrouter")
        provider = v.lower()
        if provider not in valid_providers:
            raise ValueError(
                f"Invalid LLM provider: {v}. Must be one of {', '.join(valid_providers)}"
            )
        return provider

    def to_dict(self) -> Dict[str, Any]:
        """Return all settings as a plain dict.

        ``model_dump()`` already converts the nested ``data_vendors`` model
        into a dict, so no separate dump of that field is needed. Note that
        the result includes the API key values.
        """
        return self.model_dump()

    def get_api_key(self, vendor: str) -> Optional[str]:
        """Return the configured API key for ``vendor`` (case-insensitive).

        Returns ``None`` both when the vendor is unknown and when its key is
        not configured; use ``require_api_key`` to tell the two apart.
        """
        name = vendor.lower()
        if name not in _API_KEY_VENDORS:
            return None
        return getattr(self, f"{name}_api_key")

    def require_api_key(self, vendor: str) -> str:
        """Return the API key for ``vendor`` or raise if it is unavailable.

        Raises:
            ValueError: if ``vendor`` has no API key setting at all, or if
                its key is missing or empty.
        """
        if vendor.lower() not in _API_KEY_VENDORS:
            # Previously this fell through to the "set <VENDOR>_API_KEY" hint,
            # which can never work for a vendor that has no key field.
            raise ValueError(
                f"Unknown vendor '{vendor}': no API key setting exists for it. "
                f"Known vendors: {', '.join(_API_KEY_VENDORS)}."
            )

        key = self.get_api_key(vendor)
        if not key:
            env_var = f"{vendor.upper()}_API_KEY"
            raise ValueError(
                f"{vendor} API key not configured. "
                f"Set {env_var} environment variable or TRADINGAGENTS_{env_var}."
            )
        return key
# Process-wide cached settings instance; created on first use.
_settings: Optional[TradingAgentsSettings] = None


def get_settings() -> TradingAgentsSettings:
    """Return the cached settings object, building it on the first call.

    Subsequent calls return the same instance until ``reset_settings`` or
    ``update_settings`` replaces it.
    """
    global _settings
    if _settings is not None:
        return _settings
    _settings = TradingAgentsSettings()
    return _settings
def reset_settings() -> None:
    """Drop the cached settings so the next ``get_settings`` call rebuilds them."""
    global _settings
    _settings = None
def update_settings(**kwargs) -> TradingAgentsSettings:
    """Replace the cached settings with a copy that has ``kwargs`` applied.

    The current settings are dumped to a plain dict, overlaid with
    ``kwargs``, and passed to the ``TradingAgentsSettings`` constructor, so
    the overrides go through the model's validation. A nested value such as
    ``data_vendors`` is replaced as a whole, not merged key by key. Keyword
    names that are not settings fields are dropped silently because the
    model is configured with ``extra="ignore"``.

    Returns:
        The newly built settings object, which is also stored as the cached
        instance.
    """
    global _settings
    # Later keys win, so caller-supplied overrides replace the dumped values.
    merged = {**get_settings().model_dump(), **kwargs}
    _settings = TradingAgentsSettings(**merged)
    return _settings