diff --git a/tests/test_llm_routing.py b/tests/test_llm_routing.py index c4841684..a0d631d2 100644 --- a/tests/test_llm_routing.py +++ b/tests/test_llm_routing.py @@ -300,6 +300,18 @@ def test_google_client_passes_base_url_to_langchain(monkeypatch): assert captured_kwargs["base_url"] == "https://google.example/v1beta" +def test_default_config_copy_does_not_share_mutable_llm_routing_state(): + config = default_config.DEFAULT_CONFIG.copy() + + assert config["llm_routing"] is None + config["llm_routing"] = { + "default": {}, + "roles": {"portfolio_manager": {"model": "gpt-5.2"}}, + } + + assert default_config.DEFAULT_CONFIG["llm_routing"] is None + + def test_dataflow_config_returns_isolated_nested_routing(monkeypatch): monkeypatch.setattr(dataflow_config, "_config", None) dataflow_config.initialize_config() @@ -311,7 +323,7 @@ def test_dataflow_config_returns_isolated_nested_routing(monkeypatch): } assert dataflow_config.get_config()["llm_routing"]["roles"] == {} - assert default_config.DEFAULT_CONFIG["llm_routing"]["roles"] == {} + assert default_config.DEFAULT_CONFIG["llm_routing"] is None def test_log_state_writes_json_snapshot(tmp_path, monkeypatch): diff --git a/tradingagents/dataflows/config.py b/tradingagents/dataflows/config.py index 2714eb61..9b398de1 100644 --- a/tradingagents/dataflows/config.py +++ b/tradingagents/dataflows/config.py @@ -20,20 +20,22 @@ def initialize_config(): """Initialize the configuration with default values.""" global _config if _config is None: - _config = deepcopy(default_config.DEFAULT_CONFIG) + _config = default_config.normalize_llm_routing(default_config.DEFAULT_CONFIG) def set_config(config: Dict): """Update the configuration with custom values.""" global _config - _config = _deep_merge_dicts(default_config.DEFAULT_CONFIG, config) + _config = default_config.normalize_llm_routing( + _deep_merge_dicts(default_config.DEFAULT_CONFIG, config) + ) def get_config() -> Dict: """Get the current configuration.""" if _config is None: initialize_config() - return deepcopy(_config) + return default_config.normalize_llm_routing(_config) # Initialize with default config diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 2a3e7f2e..cf3bdc96 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -1,5 +1,24 @@ +from copy import deepcopy import os + +def normalize_llm_routing(config): + """Return config with llm_routing normalized to the expected dict shape.""" + normalized = deepcopy(config) + llm_routing = normalized.get("llm_routing") + + if not isinstance(llm_routing, dict): + normalized["llm_routing"] = {"default": {}, "roles": {}} + return normalized + + default_route = llm_routing.get("default") + role_routes = llm_routing.get("roles") + normalized["llm_routing"] = { + "default": deepcopy(default_route) if isinstance(default_route, dict) else {}, + "roles": deepcopy(role_routes) if isinstance(role_routes, dict) else {}, + } + return normalized + DEFAULT_CONFIG = { "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), @@ -12,10 +31,7 @@ DEFAULT_CONFIG = { "deep_think_llm": "gpt-5.2", "quick_think_llm": "gpt-5-mini", "backend_url": "https://api.openai.com/v1", - "llm_routing": { - "default": {}, - "roles": {}, - }, + "llm_routing": None, # Provider-specific thinking configuration "google_thinking_level": None, # "high", "minimal", etc. "openai_reasoning_effort": None, # "medium", "high", "low" diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index bec32a33..f453d782 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -11,7 +11,7 @@ from langgraph.prebuilt import ToolNode from tradingagents.llm_clients import create_llm_client from tradingagents.agents import * -from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.default_config import DEFAULT_CONFIG, normalize_llm_routing from tradingagents.agents.utils.memory import FinancialSituationMemory from tradingagents.agents.utils.agent_states import ( AgentState, @@ -144,7 +144,7 @@ class TradingAgentsGraph: def _build_config(self, config: Optional[Dict[str, Any]]) -> Dict[str, Any]: """Merge user config over defaults without mutating the shared defaults.""" - return self._deep_merge_dicts(DEFAULT_CONFIG, config or {}) + return normalize_llm_routing(self._deep_merge_dicts(DEFAULT_CONFIG, config or {})) def _normalize_provider(self, provider: Optional[str]) -> str: return (provider or "").lower()