refactor: consolidate configuration management with Pydantic-based settings

- Add tradingagents/config.py with TradingAgentsSettings Pydantic model
- Centralize all environment variable handling and API key management
- Add validation for log_level, llm_provider, and numeric bounds
- Update dataflows/config.py to use central config as backend
- Update logging.py to get settings from central config
- Update API key access in alpha_vantage, brave, and tavily modules
- Replace DEFAULT_CONFIG imports with get_config() calls
- Add 20 comprehensive tests for config module

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Joseph O'Brien 2025-12-03 03:20:41 -05:00
parent 0c04bdb0ee
commit 9c252fdc2c
10 changed files with 373 additions and 61 deletions

168
tests/test_config.py Normal file
View File

@ -0,0 +1,168 @@
import pytest
import os
from unittest.mock import patch
from tradingagents.config import (
TradingAgentsSettings,
DataVendorsConfig,
get_settings,
reset_settings,
update_settings,
)
class TestDataVendorsConfig:
    """Behavioral checks for DataVendorsConfig defaults and overrides."""

    def test_default_values(self):
        # A freshly constructed config must expose the documented vendor defaults.
        vendors = DataVendorsConfig()
        assert vendors.core_stock_apis == "yfinance"
        assert vendors.technical_indicators == "yfinance"
        assert vendors.fundamental_data == "alpha_vantage"
        assert vendors.news_data == "alpha_vantage"

    def test_custom_values(self):
        # Keyword overrides win over defaults, field by field.
        vendors = DataVendorsConfig(core_stock_apis="local", news_data="openai")
        assert vendors.core_stock_apis == "local"
        assert vendors.news_data == "openai"
class TestTradingAgentsSettings:
    """Exercises validation, normalization, and API-key helpers on the settings model."""

    def setup_method(self):
        # Start every test from a pristine singleton state.
        reset_settings()

    def teardown_method(self):
        reset_settings()

    def test_default_values(self):
        cfg = TradingAgentsSettings()
        assert cfg.llm_provider == "openai"
        assert cfg.log_level == "INFO"
        assert cfg.max_debate_rounds == 2

    def test_log_level_validation(self):
        # Lower-case input is normalized to upper case; valid values pass through.
        assert TradingAgentsSettings(log_level="debug").log_level == "DEBUG"
        assert TradingAgentsSettings(log_level="WARNING").log_level == "WARNING"

    def test_log_level_invalid(self):
        with pytest.raises(ValueError, match="Invalid log level"):
            TradingAgentsSettings(log_level="INVALID")

    def test_llm_provider_validation(self):
        # Provider names are case-insensitive and stored lower-cased.
        assert TradingAgentsSettings(llm_provider="OPENAI").llm_provider == "openai"
        assert TradingAgentsSettings(llm_provider="Anthropic").llm_provider == "anthropic"

    def test_llm_provider_invalid(self):
        with pytest.raises(ValueError, match="Invalid LLM provider"):
            TradingAgentsSettings(llm_provider="invalid_provider")

    def test_to_dict(self):
        dumped = TradingAgentsSettings().to_dict()
        assert isinstance(dumped, dict)
        assert "llm_provider" in dumped
        assert "data_vendors" in dumped
        assert isinstance(dumped["data_vendors"], dict)

    def test_get_api_key_returns_value(self):
        cfg = TradingAgentsSettings(openai_api_key="test-key")
        assert cfg.get_api_key("openai") == "test-key"

    def test_get_api_key_returns_none_when_not_set(self):
        # Clear the environment so the post-init fallback cannot repopulate the key.
        with patch.dict(os.environ, {}, clear=True):
            cfg = TradingAgentsSettings()
            cfg.openai_api_key = None
            assert cfg.get_api_key("openai") is None

    def test_require_api_key_raises_when_not_set(self):
        cfg = TradingAgentsSettings()
        cfg.brave_api_key = None
        with pytest.raises(ValueError, match="brave API key not configured"):
            cfg.require_api_key("brave")

    def test_require_api_key_returns_value(self):
        cfg = TradingAgentsSettings(tavily_api_key="tavily-test")
        assert cfg.require_api_key("tavily") == "tavily-test"

    def test_max_debate_rounds_bounds(self):
        # Values inside the [1, 10] bound are accepted; values outside are rejected.
        assert TradingAgentsSettings(max_debate_rounds=5).max_debate_rounds == 5
        with pytest.raises(ValueError):
            TradingAgentsSettings(max_debate_rounds=0)
        with pytest.raises(ValueError):
            TradingAgentsSettings(max_debate_rounds=20)
class TestConfigFunctions:
    """Singleton lifecycle tests for get_settings / reset_settings / update_settings."""

    def setup_method(self):
        reset_settings()

    def teardown_method(self):
        reset_settings()

    def test_get_settings_returns_singleton(self):
        # Two consecutive lookups must hand back the very same object.
        assert get_settings() is get_settings()

    def test_reset_settings_clears_singleton(self):
        before = get_settings()
        reset_settings()
        assert get_settings() is not before

    def test_update_settings_modifies_values(self):
        assert get_settings().max_debate_rounds == 2
        update_settings(max_debate_rounds=5)
        assert get_settings().max_debate_rounds == 5

    def test_update_settings_preserves_other_values(self):
        # Only the updated field changes; untouched fields survive the rebuild.
        provider_before = get_settings().llm_provider
        update_settings(max_debate_rounds=5)
        assert get_settings().llm_provider == provider_before
class TestDataflowConfigCompat:
    """Verifies the dict-based dataflows.config facade stays in sync with central settings."""

    def setup_method(self):
        reset_settings()

    def teardown_method(self):
        reset_settings()

    def test_get_config_returns_dict(self):
        from tradingagents.dataflows.config import get_config
        cfg = get_config()
        assert isinstance(cfg, dict)
        assert "llm_provider" in cfg
        assert "data_vendors" in cfg

    def test_set_config_updates_central_settings(self):
        from tradingagents.dataflows.config import get_config, set_config
        set_config({"max_debate_rounds": 4})
        assert get_config()["max_debate_rounds"] == 4

    def test_get_config_returns_copy(self):
        # Mutating one returned dict must not leak into later calls.
        from tradingagents.dataflows.config import get_config
        first = get_config()
        second = get_config()
        assert first == second
        first["llm_provider"] = "modified"
        assert get_config()["llm_provider"] != "modified"

View File

@ -6,7 +6,7 @@ from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.dataflows.config import get_config
from tradingagents.agents.discovery.models import NewsArticle, EventCategory
@ -37,7 +37,7 @@ class ExtractionResponse(BaseModel):
def _get_llm(config: Optional[dict] = None):
cfg = config or DEFAULT_CONFIG
cfg = config or get_config()
provider = cfg.get("llm_provider", "openai").lower()
model = cfg.get("quick_think_llm", "gpt-4o-mini")
backend_url = cfg.get("backend_url", "https://api.openai.com/v1")

147
tradingagents/config.py Normal file
View File

@ -0,0 +1,147 @@
import os
from typing import Optional, Dict, Any, List
from pydantic import BaseModel, Field, field_validator
from pydantic_settings import BaseSettings
class DataVendorsConfig(BaseModel):
    """Maps each data-capability category to the vendor name that serves it."""

    core_stock_apis: str = "yfinance"          # price/volume history provider
    technical_indicators: str = "yfinance"     # indicator computation provider
    fundamental_data: str = "alpha_vantage"    # financial statements / fundamentals
    news_data: str = "alpha_vantage"           # news articles provider
class TradingAgentsSettings(BaseSettings):
    """Central, environment-aware configuration for the trading agents package.

    Values load via pydantic-settings precedence: init kwargs, then
    TRADINGAGENTS_-prefixed environment variables, then a local ``.env`` file,
    then the declared defaults. ``model_post_init`` additionally falls back to
    the conventional un-prefixed vendor variables (e.g. ``OPENAI_API_KEY``)
    for API keys left unset.
    """

    # --- Filesystem layout -------------------------------------------------
    # Absolute path of the package directory (the directory containing this file).
    project_dir: str = Field(
        default_factory=lambda: os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
    )
    results_dir: str = Field(default="./results")
    data_dir: str = Field(default="./data")
    # Derived in model_post_init when left unset: <project_dir>/dataflows/data_cache.
    data_cache_dir: Optional[str] = None

    # --- LLM selection -----------------------------------------------------
    llm_provider: str = Field(default="openai")  # normalized to lower case by validator
    deep_think_llm: str = Field(default="gpt-5")
    quick_think_llm: str = Field(default="gpt-5-mini")
    backend_url: str = Field(default="https://api.openai.com/v1")

    # --- Debate / recursion bounds -----------------------------------------
    max_debate_rounds: int = Field(default=2, ge=1, le=10)
    max_risk_discuss_rounds: int = Field(default=2, ge=1, le=10)
    max_recur_limit: int = Field(default=100, ge=10, le=500)

    # --- Discovery pipeline (all durations in seconds) ---------------------
    discovery_timeout: int = Field(default=60, ge=10)
    discovery_hard_timeout: int = Field(default=120, ge=30)
    discovery_cache_ttl: int = Field(default=300, ge=60)
    discovery_max_results: int = Field(default=20, ge=1, le=100)
    discovery_min_mentions: int = Field(default=2, ge=1)

    # --- Bulk news fetching ------------------------------------------------
    # Vendors are listed in preference order.
    bulk_news_vendor_order: List[str] = Field(
        default=["tavily", "brave", "alpha_vantage", "openai", "google"]
    )
    bulk_news_timeout: int = Field(default=30, ge=5)
    bulk_news_max_retries: int = Field(default=3, ge=1)

    # --- Logging -----------------------------------------------------------
    log_level: str = Field(default="INFO")  # validated and upper-cased below
    log_dir: str = Field(default="./logs")
    log_console_enabled: bool = Field(default=True)
    log_file_enabled: bool = Field(default=True)

    # --- Vendor API keys (None means "not configured") ---------------------
    openai_api_key: Optional[str] = Field(default=None)
    alpha_vantage_api_key: Optional[str] = Field(default=None)
    brave_api_key: Optional[str] = Field(default=None)
    tavily_api_key: Optional[str] = Field(default=None)
    google_api_key: Optional[str] = Field(default=None)
    anthropic_api_key: Optional[str] = Field(default=None)

    # --- Vendor routing ----------------------------------------------------
    data_vendors: DataVendorsConfig = Field(default_factory=DataVendorsConfig)
    tool_vendors: Dict[str, Any] = Field(default_factory=dict)

    # pydantic-settings behavior: env vars use the TRADINGAGENTS_ prefix,
    # nested fields use "__" (e.g. TRADINGAGENTS_DATA_VENDORS__NEWS_DATA),
    # unknown env entries are ignored, and a local .env file is honored.
    model_config = {
        "env_prefix": "TRADINGAGENTS_",
        "env_nested_delimiter": "__",
        "extra": "ignore",
        "env_file": ".env",
        "env_file_encoding": "utf-8",
    }

    def model_post_init(self, __context: Any) -> None:
        """Fill in derived defaults and un-prefixed env-var fallbacks after validation."""
        if self.data_cache_dir is None:
            self.data_cache_dir = os.path.join(self.project_dir, "dataflows/data_cache")
        # API keys: an explicitly-set (or prefixed-env-loaded) value wins;
        # otherwise fall back to the conventional un-prefixed variable.
        if self.openai_api_key is None:
            self.openai_api_key = os.getenv("OPENAI_API_KEY")
        if self.alpha_vantage_api_key is None:
            self.alpha_vantage_api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
        if self.brave_api_key is None:
            self.brave_api_key = os.getenv("BRAVE_API_KEY")
        if self.tavily_api_key is None:
            self.tavily_api_key = os.getenv("TAVILY_API_KEY")
        if self.google_api_key is None:
            self.google_api_key = os.getenv("GOOGLE_API_KEY")
        if self.anthropic_api_key is None:
            self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")

    @field_validator("log_level")
    @classmethod
    def validate_log_level(cls, v: str) -> str:
        """Reject unknown levels; normalize accepted ones to upper case."""
        valid_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
        if v.upper() not in valid_levels:
            raise ValueError(f"Invalid log level: {v}. Must be one of {valid_levels}")
        return v.upper()

    @field_validator("llm_provider")
    @classmethod
    def validate_llm_provider(cls, v: str) -> str:
        """Reject unknown providers; normalize accepted ones to lower case."""
        valid_providers = {"openai", "anthropic", "google", "ollama", "openrouter"}
        if v.lower() not in valid_providers:
            raise ValueError(f"Invalid LLM provider: {v}. Must be one of {valid_providers}")
        return v.lower()

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot, with data_vendors as a nested dict."""
        result = self.model_dump()
        result["data_vendors"] = self.data_vendors.model_dump()
        return result

    def get_api_key(self, vendor: str) -> Optional[str]:
        """Return the configured key for *vendor* (case-insensitive), or None if unknown/unset."""
        key_map = {
            "openai": self.openai_api_key,
            "alpha_vantage": self.alpha_vantage_api_key,
            "brave": self.brave_api_key,
            "tavily": self.tavily_api_key,
            "google": self.google_api_key,
            "anthropic": self.anthropic_api_key,
        }
        return key_map.get(vendor.lower())

    def require_api_key(self, vendor: str) -> str:
        """Return the key for *vendor*, raising ValueError (naming both env vars) if missing."""
        key = self.get_api_key(vendor)
        if not key:
            env_var = f"{vendor.upper()}_API_KEY"
            raise ValueError(
                f"{vendor} API key not configured. "
                f"Set {env_var} environment variable or TRADINGAGENTS_{env_var}."
            )
        return key
# Lazily-created, process-wide settings instance.
_settings: Optional[TradingAgentsSettings] = None


def get_settings() -> TradingAgentsSettings:
    """Return the shared settings object, creating it on first access."""
    global _settings
    if _settings is None:
        _settings = TradingAgentsSettings()
    return _settings


def reset_settings() -> None:
    """Drop the cached instance so the next get_settings() call rebuilds it."""
    global _settings
    _settings = None


def update_settings(**kwargs) -> TradingAgentsSettings:
    """Rebuild the singleton from its current values merged with *kwargs*, and return it."""
    global _settings
    _settings = TradingAgentsSettings(**{**get_settings().model_dump(), **kwargs})
    return _settings

View File

@ -11,10 +11,14 @@ logger = logging.getLogger(__name__)
API_BASE_URL = "https://www.alphavantage.co/query"
def get_api_key() -> str:
    """Resolve the Alpha Vantage API key.

    Prefers the central settings module; falls back to the raw
    ALPHA_VANTAGE_API_KEY environment variable when the config package
    cannot be imported (e.g. during partial installs or bootstrap).

    Returns:
        The configured API key string.

    Raises:
        ValueError: if no key is configured anywhere.
    """
    try:
        from tradingagents.config import get_settings
        return get_settings().require_api_key("alpha_vantage")
    except ImportError:
        # Central config unavailable: read the environment directly.
        api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
        if not api_key:
            raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.")
        return api_key
def format_datetime_for_api(date_input) -> str:
if isinstance(date_input, str):

View File

@ -14,10 +14,14 @@ RETRY_BACKOFF = 1.0
def get_api_key() -> str:
    """Resolve the Brave Search API key.

    Prefers the central settings module; falls back to the raw
    BRAVE_API_KEY environment variable when the config package cannot
    be imported.

    Returns:
        The configured API key string.

    Raises:
        ValueError: if no key is configured anywhere.
    """
    try:
        from tradingagents.config import get_settings
        return get_settings().require_api_key("brave")
    except ImportError:
        # Central config unavailable: read the environment directly.
        api_key = os.getenv("BRAVE_API_KEY")
        if not api_key:
            raise ValueError("BRAVE_API_KEY environment variable is not set.")
        return api_key
def _make_request_with_retry(url: str, headers: Dict, params: Dict, max_retries: int = MAX_RETRIES) -> requests.Response:

View File

@ -1,34 +1,35 @@
import tradingagents.default_config as default_config
from typing import Dict, Optional
from tradingagents.config import get_settings, update_settings, TradingAgentsSettings
# Use default config but allow it to be overridden
_config: Optional[Dict] = None
DATA_DIR: Optional[str] = None
def initialize_config():
"""Initialize the configuration with default values."""
global _config, DATA_DIR
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
settings = get_settings()
_config = settings.to_dict()
DATA_DIR = _config["data_dir"]
def set_config(config: Dict):
"""Update the configuration with custom values."""
global _config, DATA_DIR
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
_config.update(config)
settings = get_settings()
current_dict = settings.to_dict()
current_dict.update(config)
update_settings(**current_dict)
_config = get_settings().to_dict()
DATA_DIR = _config["data_dir"]
def get_config() -> Dict:
"""Get the current configuration."""
global _config
if _config is None:
initialize_config()
return _config.copy()
# Initialize with default config
initialize_config()

View File

@ -18,10 +18,14 @@ RETRY_BACKOFF = 1.0
def get_api_key() -> str:
    """Resolve the Tavily API key.

    Prefers the central settings module; falls back to the raw
    TAVILY_API_KEY environment variable when the config package cannot
    be imported.

    Returns:
        The configured API key string.

    Raises:
        ValueError: if no key is configured anywhere.
    """
    try:
        from tradingagents.config import get_settings
        return get_settings().require_api_key("tavily")
    except ImportError:
        # Central config unavailable: read the environment directly.
        api_key = os.getenv("TAVILY_API_KEY")
        if not api_key:
            raise ValueError("TAVILY_API_KEY environment variable is not set.")
        return api_key
def _search_with_retry(client, query: str, search_depth: str, topic: str, time_range: str, max_results: int, max_retries: int = MAX_RETRIES) -> Dict[str, Any]:

View File

@ -205,11 +205,12 @@ _sector_cache: Dict[str, str] = {}
def _llm_classify_sector(ticker: str) -> str:
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.dataflows.config import get_config
llm_name = DEFAULT_CONFIG.get("quick_think_llm", "gpt-4o-mini")
llm_provider = DEFAULT_CONFIG.get("llm_provider", "openai")
backend_url = DEFAULT_CONFIG.get("backend_url", "https://api.openai.com/v1")
config = get_config()
llm_name = config.get("quick_think_llm", "gpt-4o-mini")
llm_provider = config.get("llm_provider", "openai")
backend_url = config.get("backend_url", "https://api.openai.com/v1")
llm = ChatOpenAI(
model=llm_name,

View File

@ -14,7 +14,7 @@ from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.dataflows.config import get_config
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
AgentState,
@ -76,7 +76,7 @@ class TradingAgentsGraph:
config: Dict[str, Any] = None,
):
self.debug = debug
self.config = config or DEFAULT_CONFIG
self.config = config or get_config()
set_config(self.config)

View File

@ -4,8 +4,6 @@ import os
import json
from datetime import datetime
LOG_LEVEL_DEFAULT = "INFO"
LOG_DIR_DEFAULT = "./logs"
LOG_FILE_NAME = "tradingagents.log"
LOG_MAX_BYTES = 10 * 1024 * 1024
LOG_BACKUP_COUNT = 5
@ -31,44 +29,29 @@ class JSONFormatter(logging.Formatter):
return json.dumps(log_record)
def _parse_bool(value):
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in ("true", "1", "yes", "on")
return bool(value)
def _get_config_value(key, default):
def _get_settings():
try:
from tradingagents.default_config import DEFAULT_CONFIG
return DEFAULT_CONFIG.get(key, default)
from tradingagents.config import get_settings
return get_settings()
except ImportError:
return default
return None
def setup_logging():
global _logging_initialized
log_level_str = os.getenv("TRADINGAGENTS_LOG_LEVEL")
if log_level_str is None:
log_level_str = _get_config_value("log_level", LOG_LEVEL_DEFAULT)
settings = _get_settings()
log_dir = os.getenv("TRADINGAGENTS_LOG_DIR")
if log_dir is None:
log_dir = _get_config_value("log_dir", LOG_DIR_DEFAULT)
console_enabled_env = os.getenv("TRADINGAGENTS_LOG_CONSOLE")
if console_enabled_env is not None:
console_enabled = _parse_bool(console_enabled_env)
if settings:
log_level_str = settings.log_level
log_dir = settings.log_dir
console_enabled = settings.log_console_enabled
file_enabled = settings.log_file_enabled
else:
console_enabled = _get_config_value("log_console_enabled", True)
file_enabled_env = os.getenv("TRADINGAGENTS_LOG_FILE")
if file_enabled_env is not None:
file_enabled = _parse_bool(file_enabled_env)
else:
file_enabled = _get_config_value("log_file_enabled", True)
log_level_str = os.getenv("TRADINGAGENTS_LOG_LEVEL", "INFO")
log_dir = os.getenv("TRADINGAGENTS_LOG_DIR", "./logs")
console_enabled = os.getenv("TRADINGAGENTS_LOG_CONSOLE", "true").lower() in ("true", "1", "yes", "on")
file_enabled = os.getenv("TRADINGAGENTS_LOG_FILE", "true").lower() in ("true", "1", "yes", "on")
log_level = getattr(logging, log_level_str.upper(), logging.INFO)