refactor: consolidate configuration management with Pydantic-based settings
- Add tradingagents/config.py with TradingAgentsSettings Pydantic model - Centralize all environment variable handling and API key management - Add validation for log_level, llm_provider, and numeric bounds - Update dataflows/config.py to use central config as backend - Update logging.py to get settings from central config - Update API key access in alpha_vantage, brave, and tavily modules - Replace DEFAULT_CONFIG imports with get_config() calls - Add 20 comprehensive tests for config module 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
0c04bdb0ee
commit
9c252fdc2c
|
|
@ -0,0 +1,168 @@
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from tradingagents.config import (
|
||||||
|
TradingAgentsSettings,
|
||||||
|
DataVendorsConfig,
|
||||||
|
get_settings,
|
||||||
|
reset_settings,
|
||||||
|
update_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataVendorsConfig:
|
||||||
|
def test_default_values(self):
|
||||||
|
config = DataVendorsConfig()
|
||||||
|
assert config.core_stock_apis == "yfinance"
|
||||||
|
assert config.technical_indicators == "yfinance"
|
||||||
|
assert config.fundamental_data == "alpha_vantage"
|
||||||
|
assert config.news_data == "alpha_vantage"
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
config = DataVendorsConfig(core_stock_apis="local", news_data="openai")
|
||||||
|
assert config.core_stock_apis == "local"
|
||||||
|
assert config.news_data == "openai"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTradingAgentsSettings:
|
||||||
|
def setup_method(self):
|
||||||
|
reset_settings()
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
reset_settings()
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
settings = TradingAgentsSettings()
|
||||||
|
assert settings.llm_provider == "openai"
|
||||||
|
assert settings.log_level == "INFO"
|
||||||
|
assert settings.max_debate_rounds == 2
|
||||||
|
|
||||||
|
def test_log_level_validation(self):
|
||||||
|
settings = TradingAgentsSettings(log_level="debug")
|
||||||
|
assert settings.log_level == "DEBUG"
|
||||||
|
|
||||||
|
settings = TradingAgentsSettings(log_level="WARNING")
|
||||||
|
assert settings.log_level == "WARNING"
|
||||||
|
|
||||||
|
def test_log_level_invalid(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid log level"):
|
||||||
|
TradingAgentsSettings(log_level="INVALID")
|
||||||
|
|
||||||
|
def test_llm_provider_validation(self):
|
||||||
|
settings = TradingAgentsSettings(llm_provider="OPENAI")
|
||||||
|
assert settings.llm_provider == "openai"
|
||||||
|
|
||||||
|
settings = TradingAgentsSettings(llm_provider="Anthropic")
|
||||||
|
assert settings.llm_provider == "anthropic"
|
||||||
|
|
||||||
|
def test_llm_provider_invalid(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid LLM provider"):
|
||||||
|
TradingAgentsSettings(llm_provider="invalid_provider")
|
||||||
|
|
||||||
|
def test_to_dict(self):
|
||||||
|
settings = TradingAgentsSettings()
|
||||||
|
result = settings.to_dict()
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "llm_provider" in result
|
||||||
|
assert "data_vendors" in result
|
||||||
|
assert isinstance(result["data_vendors"], dict)
|
||||||
|
|
||||||
|
def test_get_api_key_returns_value(self):
|
||||||
|
settings = TradingAgentsSettings(openai_api_key="test-key")
|
||||||
|
assert settings.get_api_key("openai") == "test-key"
|
||||||
|
|
||||||
|
def test_get_api_key_returns_none_when_not_set(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
settings = TradingAgentsSettings()
|
||||||
|
settings.openai_api_key = None
|
||||||
|
assert settings.get_api_key("openai") is None
|
||||||
|
|
||||||
|
def test_require_api_key_raises_when_not_set(self):
|
||||||
|
settings = TradingAgentsSettings()
|
||||||
|
settings.brave_api_key = None
|
||||||
|
with pytest.raises(ValueError, match="brave API key not configured"):
|
||||||
|
settings.require_api_key("brave")
|
||||||
|
|
||||||
|
def test_require_api_key_returns_value(self):
|
||||||
|
settings = TradingAgentsSettings(tavily_api_key="tavily-test")
|
||||||
|
assert settings.require_api_key("tavily") == "tavily-test"
|
||||||
|
|
||||||
|
def test_max_debate_rounds_bounds(self):
|
||||||
|
settings = TradingAgentsSettings(max_debate_rounds=5)
|
||||||
|
assert settings.max_debate_rounds == 5
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
TradingAgentsSettings(max_debate_rounds=0)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
TradingAgentsSettings(max_debate_rounds=20)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigFunctions:
|
||||||
|
def setup_method(self):
|
||||||
|
reset_settings()
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
reset_settings()
|
||||||
|
|
||||||
|
def test_get_settings_returns_singleton(self):
|
||||||
|
s1 = get_settings()
|
||||||
|
s2 = get_settings()
|
||||||
|
assert s1 is s2
|
||||||
|
|
||||||
|
def test_reset_settings_clears_singleton(self):
|
||||||
|
s1 = get_settings()
|
||||||
|
reset_settings()
|
||||||
|
s2 = get_settings()
|
||||||
|
assert s1 is not s2
|
||||||
|
|
||||||
|
def test_update_settings_modifies_values(self):
|
||||||
|
original = get_settings()
|
||||||
|
assert original.max_debate_rounds == 2
|
||||||
|
|
||||||
|
update_settings(max_debate_rounds=5)
|
||||||
|
updated = get_settings()
|
||||||
|
assert updated.max_debate_rounds == 5
|
||||||
|
|
||||||
|
def test_update_settings_preserves_other_values(self):
|
||||||
|
original = get_settings()
|
||||||
|
original_provider = original.llm_provider
|
||||||
|
|
||||||
|
update_settings(max_debate_rounds=5)
|
||||||
|
updated = get_settings()
|
||||||
|
assert updated.llm_provider == original_provider
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataflowConfigCompat:
|
||||||
|
def setup_method(self):
|
||||||
|
reset_settings()
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
reset_settings()
|
||||||
|
|
||||||
|
def test_get_config_returns_dict(self):
|
||||||
|
from tradingagents.dataflows.config import get_config
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
assert isinstance(config, dict)
|
||||||
|
assert "llm_provider" in config
|
||||||
|
assert "data_vendors" in config
|
||||||
|
|
||||||
|
def test_set_config_updates_central_settings(self):
|
||||||
|
from tradingagents.dataflows.config import get_config, set_config
|
||||||
|
|
||||||
|
set_config({"max_debate_rounds": 4})
|
||||||
|
config = get_config()
|
||||||
|
assert config["max_debate_rounds"] == 4
|
||||||
|
|
||||||
|
def test_get_config_returns_copy(self):
|
||||||
|
from tradingagents.dataflows.config import get_config
|
||||||
|
|
||||||
|
c1 = get_config()
|
||||||
|
c2 = get_config()
|
||||||
|
assert c1 == c2
|
||||||
|
c1["llm_provider"] = "modified"
|
||||||
|
c3 = get_config()
|
||||||
|
assert c3["llm_provider"] != "modified"
|
||||||
|
|
@ -6,7 +6,7 @@ from langchain_openai import ChatOpenAI
|
||||||
from langchain_anthropic import ChatAnthropic
|
from langchain_anthropic import ChatAnthropic
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.dataflows.config import get_config
|
||||||
from tradingagents.agents.discovery.models import NewsArticle, EventCategory
|
from tradingagents.agents.discovery.models import NewsArticle, EventCategory
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -37,7 +37,7 @@ class ExtractionResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
def _get_llm(config: Optional[dict] = None):
|
def _get_llm(config: Optional[dict] = None):
|
||||||
cfg = config or DEFAULT_CONFIG
|
cfg = config or get_config()
|
||||||
provider = cfg.get("llm_provider", "openai").lower()
|
provider = cfg.get("llm_provider", "openai").lower()
|
||||||
model = cfg.get("quick_think_llm", "gpt-4o-mini")
|
model = cfg.get("quick_think_llm", "gpt-4o-mini")
|
||||||
backend_url = cfg.get("backend_url", "https://api.openai.com/v1")
|
backend_url = cfg.get("backend_url", "https://api.openai.com/v1")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,147 @@
|
||||||
|
import os
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class DataVendorsConfig(BaseModel):
|
||||||
|
core_stock_apis: str = "yfinance"
|
||||||
|
technical_indicators: str = "yfinance"
|
||||||
|
fundamental_data: str = "alpha_vantage"
|
||||||
|
news_data: str = "alpha_vantage"
|
||||||
|
|
||||||
|
|
||||||
|
class TradingAgentsSettings(BaseSettings):
|
||||||
|
project_dir: str = Field(
|
||||||
|
default_factory=lambda: os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
|
||||||
|
)
|
||||||
|
results_dir: str = Field(default="./results")
|
||||||
|
data_dir: str = Field(default="./data")
|
||||||
|
data_cache_dir: Optional[str] = None
|
||||||
|
|
||||||
|
llm_provider: str = Field(default="openai")
|
||||||
|
deep_think_llm: str = Field(default="gpt-5")
|
||||||
|
quick_think_llm: str = Field(default="gpt-5-mini")
|
||||||
|
backend_url: str = Field(default="https://api.openai.com/v1")
|
||||||
|
|
||||||
|
max_debate_rounds: int = Field(default=2, ge=1, le=10)
|
||||||
|
max_risk_discuss_rounds: int = Field(default=2, ge=1, le=10)
|
||||||
|
max_recur_limit: int = Field(default=100, ge=10, le=500)
|
||||||
|
|
||||||
|
discovery_timeout: int = Field(default=60, ge=10)
|
||||||
|
discovery_hard_timeout: int = Field(default=120, ge=30)
|
||||||
|
discovery_cache_ttl: int = Field(default=300, ge=60)
|
||||||
|
discovery_max_results: int = Field(default=20, ge=1, le=100)
|
||||||
|
discovery_min_mentions: int = Field(default=2, ge=1)
|
||||||
|
|
||||||
|
bulk_news_vendor_order: List[str] = Field(
|
||||||
|
default=["tavily", "brave", "alpha_vantage", "openai", "google"]
|
||||||
|
)
|
||||||
|
bulk_news_timeout: int = Field(default=30, ge=5)
|
||||||
|
bulk_news_max_retries: int = Field(default=3, ge=1)
|
||||||
|
|
||||||
|
log_level: str = Field(default="INFO")
|
||||||
|
log_dir: str = Field(default="./logs")
|
||||||
|
log_console_enabled: bool = Field(default=True)
|
||||||
|
log_file_enabled: bool = Field(default=True)
|
||||||
|
|
||||||
|
openai_api_key: Optional[str] = Field(default=None)
|
||||||
|
alpha_vantage_api_key: Optional[str] = Field(default=None)
|
||||||
|
brave_api_key: Optional[str] = Field(default=None)
|
||||||
|
tavily_api_key: Optional[str] = Field(default=None)
|
||||||
|
google_api_key: Optional[str] = Field(default=None)
|
||||||
|
anthropic_api_key: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
data_vendors: DataVendorsConfig = Field(default_factory=DataVendorsConfig)
|
||||||
|
tool_vendors: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"env_prefix": "TRADINGAGENTS_",
|
||||||
|
"env_nested_delimiter": "__",
|
||||||
|
"extra": "ignore",
|
||||||
|
"env_file": ".env",
|
||||||
|
"env_file_encoding": "utf-8",
|
||||||
|
}
|
||||||
|
|
||||||
|
def model_post_init(self, __context: Any) -> None:
|
||||||
|
if self.data_cache_dir is None:
|
||||||
|
self.data_cache_dir = os.path.join(self.project_dir, "dataflows/data_cache")
|
||||||
|
|
||||||
|
if self.openai_api_key is None:
|
||||||
|
self.openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
if self.alpha_vantage_api_key is None:
|
||||||
|
self.alpha_vantage_api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
|
||||||
|
if self.brave_api_key is None:
|
||||||
|
self.brave_api_key = os.getenv("BRAVE_API_KEY")
|
||||||
|
if self.tavily_api_key is None:
|
||||||
|
self.tavily_api_key = os.getenv("TAVILY_API_KEY")
|
||||||
|
if self.google_api_key is None:
|
||||||
|
self.google_api_key = os.getenv("GOOGLE_API_KEY")
|
||||||
|
if self.anthropic_api_key is None:
|
||||||
|
self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||||
|
|
||||||
|
@field_validator("log_level")
|
||||||
|
@classmethod
|
||||||
|
def validate_log_level(cls, v: str) -> str:
|
||||||
|
valid_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
|
||||||
|
if v.upper() not in valid_levels:
|
||||||
|
raise ValueError(f"Invalid log level: {v}. Must be one of {valid_levels}")
|
||||||
|
return v.upper()
|
||||||
|
|
||||||
|
@field_validator("llm_provider")
|
||||||
|
@classmethod
|
||||||
|
def validate_llm_provider(cls, v: str) -> str:
|
||||||
|
valid_providers = {"openai", "anthropic", "google", "ollama", "openrouter"}
|
||||||
|
if v.lower() not in valid_providers:
|
||||||
|
raise ValueError(f"Invalid LLM provider: {v}. Must be one of {valid_providers}")
|
||||||
|
return v.lower()
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
result = self.model_dump()
|
||||||
|
result["data_vendors"] = self.data_vendors.model_dump()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_api_key(self, vendor: str) -> Optional[str]:
|
||||||
|
key_map = {
|
||||||
|
"openai": self.openai_api_key,
|
||||||
|
"alpha_vantage": self.alpha_vantage_api_key,
|
||||||
|
"brave": self.brave_api_key,
|
||||||
|
"tavily": self.tavily_api_key,
|
||||||
|
"google": self.google_api_key,
|
||||||
|
"anthropic": self.anthropic_api_key,
|
||||||
|
}
|
||||||
|
return key_map.get(vendor.lower())
|
||||||
|
|
||||||
|
def require_api_key(self, vendor: str) -> str:
|
||||||
|
key = self.get_api_key(vendor)
|
||||||
|
if not key:
|
||||||
|
env_var = f"{vendor.upper()}_API_KEY"
|
||||||
|
raise ValueError(
|
||||||
|
f"{vendor} API key not configured. "
|
||||||
|
f"Set {env_var} environment variable or TRADINGAGENTS_{env_var}."
|
||||||
|
)
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
_settings: Optional[TradingAgentsSettings] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings() -> TradingAgentsSettings:
|
||||||
|
global _settings
|
||||||
|
if _settings is None:
|
||||||
|
_settings = TradingAgentsSettings()
|
||||||
|
return _settings
|
||||||
|
|
||||||
|
|
||||||
|
def reset_settings() -> None:
|
||||||
|
global _settings
|
||||||
|
_settings = None
|
||||||
|
|
||||||
|
|
||||||
|
def update_settings(**kwargs) -> TradingAgentsSettings:
|
||||||
|
global _settings
|
||||||
|
current = get_settings()
|
||||||
|
new_values = current.model_dump()
|
||||||
|
new_values.update(kwargs)
|
||||||
|
_settings = TradingAgentsSettings(**new_values)
|
||||||
|
return _settings
|
||||||
|
|
@ -11,10 +11,14 @@ logger = logging.getLogger(__name__)
|
||||||
API_BASE_URL = "https://www.alphavantage.co/query"
|
API_BASE_URL = "https://www.alphavantage.co/query"
|
||||||
|
|
||||||
def get_api_key() -> str:
|
def get_api_key() -> str:
|
||||||
api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
|
try:
|
||||||
if not api_key:
|
from tradingagents.config import get_settings
|
||||||
raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.")
|
return get_settings().require_api_key("alpha_vantage")
|
||||||
return api_key
|
except ImportError:
|
||||||
|
api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.")
|
||||||
|
return api_key
|
||||||
|
|
||||||
def format_datetime_for_api(date_input) -> str:
|
def format_datetime_for_api(date_input) -> str:
|
||||||
if isinstance(date_input, str):
|
if isinstance(date_input, str):
|
||||||
|
|
|
||||||
|
|
@ -14,10 +14,14 @@ RETRY_BACKOFF = 1.0
|
||||||
|
|
||||||
|
|
||||||
def get_api_key() -> str:
|
def get_api_key() -> str:
|
||||||
api_key = os.getenv("BRAVE_API_KEY")
|
try:
|
||||||
if not api_key:
|
from tradingagents.config import get_settings
|
||||||
raise ValueError("BRAVE_API_KEY environment variable is not set.")
|
return get_settings().require_api_key("brave")
|
||||||
return api_key
|
except ImportError:
|
||||||
|
api_key = os.getenv("BRAVE_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("BRAVE_API_KEY environment variable is not set.")
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
def _make_request_with_retry(url: str, headers: Dict, params: Dict, max_retries: int = MAX_RETRIES) -> requests.Response:
|
def _make_request_with_retry(url: str, headers: Dict, params: Dict, max_retries: int = MAX_RETRIES) -> requests.Response:
|
||||||
|
|
|
||||||
|
|
@ -1,34 +1,35 @@
|
||||||
import tradingagents.default_config as default_config
|
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
from tradingagents.config import get_settings, update_settings, TradingAgentsSettings
|
||||||
|
|
||||||
# Use default config but allow it to be overridden
|
|
||||||
_config: Optional[Dict] = None
|
_config: Optional[Dict] = None
|
||||||
DATA_DIR: Optional[str] = None
|
DATA_DIR: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def initialize_config():
|
def initialize_config():
|
||||||
"""Initialize the configuration with default values."""
|
|
||||||
global _config, DATA_DIR
|
global _config, DATA_DIR
|
||||||
if _config is None:
|
if _config is None:
|
||||||
_config = default_config.DEFAULT_CONFIG.copy()
|
settings = get_settings()
|
||||||
|
_config = settings.to_dict()
|
||||||
DATA_DIR = _config["data_dir"]
|
DATA_DIR = _config["data_dir"]
|
||||||
|
|
||||||
|
|
||||||
def set_config(config: Dict):
|
def set_config(config: Dict):
|
||||||
"""Update the configuration with custom values."""
|
|
||||||
global _config, DATA_DIR
|
global _config, DATA_DIR
|
||||||
if _config is None:
|
|
||||||
_config = default_config.DEFAULT_CONFIG.copy()
|
settings = get_settings()
|
||||||
_config.update(config)
|
current_dict = settings.to_dict()
|
||||||
|
current_dict.update(config)
|
||||||
|
update_settings(**current_dict)
|
||||||
|
|
||||||
|
_config = get_settings().to_dict()
|
||||||
DATA_DIR = _config["data_dir"]
|
DATA_DIR = _config["data_dir"]
|
||||||
|
|
||||||
|
|
||||||
def get_config() -> Dict:
|
def get_config() -> Dict:
|
||||||
"""Get the current configuration."""
|
global _config
|
||||||
if _config is None:
|
if _config is None:
|
||||||
initialize_config()
|
initialize_config()
|
||||||
return _config.copy()
|
return _config.copy()
|
||||||
|
|
||||||
|
|
||||||
# Initialize with default config
|
|
||||||
initialize_config()
|
initialize_config()
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,14 @@ RETRY_BACKOFF = 1.0
|
||||||
|
|
||||||
|
|
||||||
def get_api_key() -> str:
|
def get_api_key() -> str:
|
||||||
api_key = os.getenv("TAVILY_API_KEY")
|
try:
|
||||||
if not api_key:
|
from tradingagents.config import get_settings
|
||||||
raise ValueError("TAVILY_API_KEY environment variable is not set.")
|
return get_settings().require_api_key("tavily")
|
||||||
return api_key
|
except ImportError:
|
||||||
|
api_key = os.getenv("TAVILY_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("TAVILY_API_KEY environment variable is not set.")
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
def _search_with_retry(client, query: str, search_depth: str, topic: str, time_range: str, max_results: int, max_retries: int = MAX_RETRIES) -> Dict[str, Any]:
|
def _search_with_retry(client, query: str, search_depth: str, topic: str, time_range: str, max_results: int, max_retries: int = MAX_RETRIES) -> Dict[str, Any]:
|
||||||
|
|
|
||||||
|
|
@ -205,11 +205,12 @@ _sector_cache: Dict[str, str] = {}
|
||||||
def _llm_classify_sector(ticker: str) -> str:
|
def _llm_classify_sector(ticker: str) -> str:
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.dataflows.config import get_config
|
||||||
|
|
||||||
llm_name = DEFAULT_CONFIG.get("quick_think_llm", "gpt-4o-mini")
|
config = get_config()
|
||||||
llm_provider = DEFAULT_CONFIG.get("llm_provider", "openai")
|
llm_name = config.get("quick_think_llm", "gpt-4o-mini")
|
||||||
backend_url = DEFAULT_CONFIG.get("backend_url", "https://api.openai.com/v1")
|
llm_provider = config.get("llm_provider", "openai")
|
||||||
|
backend_url = config.get("backend_url", "https://api.openai.com/v1")
|
||||||
|
|
||||||
llm = ChatOpenAI(
|
llm = ChatOpenAI(
|
||||||
model=llm_name,
|
model=llm_name,
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
from tradingagents.agents import *
|
from tradingagents.agents import *
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.dataflows.config import get_config
|
||||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||||
from tradingagents.agents.utils.agent_states import (
|
from tradingagents.agents.utils.agent_states import (
|
||||||
AgentState,
|
AgentState,
|
||||||
|
|
@ -76,7 +76,7 @@ class TradingAgentsGraph:
|
||||||
config: Dict[str, Any] = None,
|
config: Dict[str, Any] = None,
|
||||||
):
|
):
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.config = config or DEFAULT_CONFIG
|
self.config = config or get_config()
|
||||||
|
|
||||||
set_config(self.config)
|
set_config(self.config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,6 @@ import os
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
LOG_LEVEL_DEFAULT = "INFO"
|
|
||||||
LOG_DIR_DEFAULT = "./logs"
|
|
||||||
LOG_FILE_NAME = "tradingagents.log"
|
LOG_FILE_NAME = "tradingagents.log"
|
||||||
LOG_MAX_BYTES = 10 * 1024 * 1024
|
LOG_MAX_BYTES = 10 * 1024 * 1024
|
||||||
LOG_BACKUP_COUNT = 5
|
LOG_BACKUP_COUNT = 5
|
||||||
|
|
@ -31,44 +29,29 @@ class JSONFormatter(logging.Formatter):
|
||||||
return json.dumps(log_record)
|
return json.dumps(log_record)
|
||||||
|
|
||||||
|
|
||||||
def _parse_bool(value):
|
def _get_settings():
|
||||||
if isinstance(value, bool):
|
|
||||||
return value
|
|
||||||
if isinstance(value, str):
|
|
||||||
return value.lower() in ("true", "1", "yes", "on")
|
|
||||||
return bool(value)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_config_value(key, default):
|
|
||||||
try:
|
try:
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.config import get_settings
|
||||||
return DEFAULT_CONFIG.get(key, default)
|
return get_settings()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return default
|
return None
|
||||||
|
|
||||||
|
|
||||||
def setup_logging():
|
def setup_logging():
|
||||||
global _logging_initialized
|
global _logging_initialized
|
||||||
|
|
||||||
log_level_str = os.getenv("TRADINGAGENTS_LOG_LEVEL")
|
settings = _get_settings()
|
||||||
if log_level_str is None:
|
|
||||||
log_level_str = _get_config_value("log_level", LOG_LEVEL_DEFAULT)
|
|
||||||
|
|
||||||
log_dir = os.getenv("TRADINGAGENTS_LOG_DIR")
|
if settings:
|
||||||
if log_dir is None:
|
log_level_str = settings.log_level
|
||||||
log_dir = _get_config_value("log_dir", LOG_DIR_DEFAULT)
|
log_dir = settings.log_dir
|
||||||
|
console_enabled = settings.log_console_enabled
|
||||||
console_enabled_env = os.getenv("TRADINGAGENTS_LOG_CONSOLE")
|
file_enabled = settings.log_file_enabled
|
||||||
if console_enabled_env is not None:
|
|
||||||
console_enabled = _parse_bool(console_enabled_env)
|
|
||||||
else:
|
else:
|
||||||
console_enabled = _get_config_value("log_console_enabled", True)
|
log_level_str = os.getenv("TRADINGAGENTS_LOG_LEVEL", "INFO")
|
||||||
|
log_dir = os.getenv("TRADINGAGENTS_LOG_DIR", "./logs")
|
||||||
file_enabled_env = os.getenv("TRADINGAGENTS_LOG_FILE")
|
console_enabled = os.getenv("TRADINGAGENTS_LOG_CONSOLE", "true").lower() in ("true", "1", "yes", "on")
|
||||||
if file_enabled_env is not None:
|
file_enabled = os.getenv("TRADINGAGENTS_LOG_FILE", "true").lower() in ("true", "1", "yes", "on")
|
||||||
file_enabled = _parse_bool(file_enabled_env)
|
|
||||||
else:
|
|
||||||
file_enabled = _get_config_value("log_file_enabled", True)
|
|
||||||
|
|
||||||
log_level = getattr(logging, log_level_str.upper(), logging.INFO)
|
log_level = getattr(logging, log_level_str.upper(), logging.INFO)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue