# Source: TradingAgents/tradingagents/config_validation.py (187 lines, 5.8 KiB, Python)

"""Configuration validation for the TradingAgents framework.
This module provides validation functions to ensure configuration
settings are correct before runtime.
"""
import os
from typing import Any
# LLM providers accepted for the ``llm_provider`` setting.
VALID_PROVIDERS = [
    "openai",
    "anthropic",
    "google",
    "xai",
    "ollama",
    "openrouter",
]

# Vendors accepted for every entry of ``data_vendors`` and ``tool_vendors``.
VALID_DATA_VENDORS = [
    "yfinance",
    "alpha_vantage",
]

# Name of the environment variable holding the API key for each provider.
# A provider with no entry here (currently "ollama") has no key to look up.
PROVIDER_API_KEYS = dict(
    openai="OPENAI_API_KEY",
    anthropic="ANTHROPIC_API_KEY",
    google="GOOGLE_API_KEY",
    xai="XAI_API_KEY",
    openrouter="OPENROUTER_API_KEY",
)
def validate_config(config: dict[str, Any]) -> list[str]:
    """Validate configuration dictionary.

    Every problem is reported as a message instead of raising, so a
    malformed value (``None``, wrong type) yields a validation error
    rather than an ``AttributeError`` or ``TypeError``.

    Args:
        config: Configuration dictionary to validate.

    Returns:
        List of validation error messages (empty if valid).

    Example:
        >>> errors = validate_config(config)
        >>> if errors:
        ...     print("Configuration errors:", errors)
    """
    errors: list[str] = []

    # Validate LLM provider. Matching is case-insensitive; a non-string
    # value (e.g. None) can never match and is reported as invalid.
    provider = config.get("llm_provider", "")
    if isinstance(provider, str):
        provider = provider.lower()
    if not isinstance(provider, str) or provider not in VALID_PROVIDERS:
        errors.append(
            f"Invalid llm_provider: '{provider}'. Must be one of {VALID_PROVIDERS}"
        )

    # Validate that both model names are present and non-empty.
    for key in ("deep_think_llm", "quick_think_llm"):
        if not config.get(key):
            errors.append(f"{key} is required")

    # Validate data vendors and tool vendors against the same allow-list.
    for setting, label in (
        ("data_vendors", "data vendor"),
        ("tool_vendors", "tool vendor"),
    ):
        # A missing or None setting is treated as "nothing configured".
        vendors = config.get(setting) or {}
        try:
            entries = list(vendors.items())
        except AttributeError:
            errors.append(f"{setting} must be a dictionary")
            continue
        for name, vendor in entries:
            if vendor not in VALID_DATA_VENDORS:
                errors.append(
                    f"Invalid {label} for {name}: '{vendor}'. "
                    f"Must be one of {VALID_DATA_VENDORS}"
                )

    # Validate numeric settings; the default is used when the key is absent.
    for key, default in (
        ("max_debate_rounds", 1),
        ("max_risk_discuss_rounds", 1),
        ("max_recur_limit", 100),
    ):
        value = config.get(key, default)
        # bool is a subclass of int, so True would otherwise be accepted as 1.
        if isinstance(value, bool) or not isinstance(value, int) or value < 1:
            errors.append(f"{key} must be a positive integer")

    return errors
def validate_api_keys(config: dict[str, Any]) -> list[str]:
    """Validate that required API keys are set for the configured provider.

    The provider name is matched case-insensitively against
    ``PROVIDER_API_KEYS``; providers without an entry there need no key.
    A non-string ``llm_provider`` or a non-mapping vendor setting is
    skipped here instead of raising.

    Args:
        config: Configuration dictionary containing llm_provider.

    Returns:
        List of validation error messages (empty if valid).

    Example:
        >>> errors = validate_api_keys(config)
        >>> if errors:
        ...     print("Missing API keys:", errors)
    """
    errors: list[str] = []

    provider = config.get("llm_provider", "")
    # Guard against None / non-string values, which have no .lower().
    provider = provider.lower() if isinstance(provider, str) else ""
    env_key = PROVIDER_API_KEYS.get(provider)
    # An unset or empty environment variable both count as missing.
    if env_key and not os.environ.get(env_key):
        errors.append(f"{env_key} not set for {provider} provider")

    # Check for Alpha Vantage key if any data or tool vendor uses it.
    uses_alpha_vantage = False
    for setting in ("data_vendors", "tool_vendors"):
        # A missing or None setting is treated as "nothing configured".
        vendors = config.get(setting) or {}
        try:
            configured = list(vendors.values())
        except AttributeError:
            # Not a mapping, so there are no vendor names to inspect.
            continue
        if any(vendor == "alpha_vantage" for vendor in configured):
            uses_alpha_vantage = True
            break

    if uses_alpha_vantage and not os.environ.get("ALPHA_VANTAGE_API_KEY"):
        errors.append("ALPHA_VANTAGE_API_KEY not set but alpha_vantage vendor is configured")

    return errors
def validate_config_full(config: dict[str, Any]) -> list[str]:
    """Run both the settings checks and the API-key checks.

    Args:
        config: Configuration dictionary to validate.

    Returns:
        Messages from ``validate_config`` followed by those from
        ``validate_api_keys`` (empty if the configuration is valid).

    Example:
        >>> errors = validate_config_full(config)
        >>> if errors:
        ...     for error in errors:
        ...         print(f"Error: {error}")
        ...     sys.exit(1)
    """
    setting_errors = validate_config(config)
    api_key_errors = validate_api_keys(config)
    return [*setting_errors, *api_key_errors]
def get_validation_report(config: dict[str, Any]) -> str:
    """Build a human-readable validation report for *config*.

    The report has a title, a summary of the main settings (showing
    "not set" for absent keys), the configured data vendors if any, and
    then either the list of validation errors or a success line.

    Args:
        config: Configuration dictionary to validate.

    Returns:
        Formatted string with validation results, joined by newlines.

    Example:
        >>> report = get_validation_report(config)
        >>> print(report)
    """
    problems = validate_config_full(config)

    # Title followed by the configuration summary.
    report = [
        "Configuration Validation Report",
        "=" * 40,
        f"\nLLM Provider: {config.get('llm_provider', 'not set')}",
        f"Deep Think LLM: {config.get('deep_think_llm', 'not set')}",
        f"Quick Think LLM: {config.get('quick_think_llm', 'not set')}",
        f"Max Debate Rounds: {config.get('max_debate_rounds', 'not set')}",
        f"Max Risk Discuss Rounds: {config.get('max_risk_discuss_rounds', 'not set')}",
    ]

    # The vendor section is omitted entirely when nothing is configured.
    vendor_map = config.get("data_vendors", {})
    if vendor_map:
        report.append("\nData Vendors:")
        report.extend(
            f" {category}: {vendor}" for category, vendor in vendor_map.items()
        )

    if not problems:
        report.append("\n✓ Configuration is valid")
    else:
        report.append("\nValidation Errors:")
        report.extend(f" - {error}" for error in problems)

    return "\n".join(report)