diff --git a/tests/test_llm_routing.py b/tests/test_llm_routing.py index b232159c..bd7c5c28 100644 --- a/tests/test_llm_routing.py +++ b/tests/test_llm_routing.py @@ -1,6 +1,8 @@ from copy import deepcopy import json +import tradingagents.dataflows.config as dataflow_config +import tradingagents.default_config as default_config from tradingagents.graph.trading_graph import TradingAgentsGraph @@ -20,7 +22,87 @@ class DummyClient: } -def test_role_specific_llm_config_overrides_default(monkeypatch): +class DummyStateGraph: + def __init__(self, _state_type): + self.nodes = {} + + def add_node(self, name, node): + self.nodes[name] = node + + def add_edge(self, *_args, **_kwargs): + return None + + def add_conditional_edges(self, *_args, **_kwargs): + return None + + def compile(self): + return {"nodes": self.nodes} + + +def _patch_graph_setup_wiring(monkeypatch, recorded_llms): + monkeypatch.setattr("tradingagents.graph.setup.StateGraph", DummyStateGraph) + monkeypatch.setattr("tradingagents.graph.setup.create_msg_delete", lambda: "delete") + + def make_factory(node_name): + def factory(llm, *_args): + recorded_llms[node_name] = llm + return node_name + + return factory + + monkeypatch.setattr( + "tradingagents.graph.setup.create_market_analyst", + make_factory("Market Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_social_media_analyst", + make_factory("Social Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_news_analyst", + make_factory("News Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_fundamentals_analyst", + make_factory("Fundamentals Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_bull_researcher", + make_factory("Bull Researcher"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_bear_researcher", + make_factory("Bear Researcher"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_research_manager", + make_factory("Research Manager"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_trader", + make_factory("Trader"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_aggressive_debator", + make_factory("Aggressive Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_neutral_debator", + make_factory("Neutral Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_conservative_debator", + make_factory("Conservative Analyst"), + ) + monkeypatch.setattr( + "tradingagents.graph.setup.create_portfolio_manager", + make_factory("Portfolio Manager"), + ) + + +def test_role_specific_llm_config_overrides_actual_graph_wiring(monkeypatch): + recorded_llms = {} + monkeypatch.setattr( "tradingagents.graph.trading_graph.create_llm_client", lambda provider, model, base_url=None, **kwargs: DummyClient( @@ -31,10 +113,7 @@ def test_role_specific_llm_config_overrides_default(monkeypatch): "tradingagents.graph.trading_graph.FinancialSituationMemory", lambda *args, **kwargs: object(), ) - monkeypatch.setattr( - "tradingagents.graph.trading_graph.GraphSetup.setup_graph", - lambda self, selected_analysts: {"selected_analysts": selected_analysts}, - ) + _patch_graph_setup_wiring(monkeypatch, recorded_llms) config = { "llm_routing": { @@ -48,12 +127,68 @@ def test_role_specific_llm_config_overrides_default(monkeypatch): } } - graph = TradingAgentsGraph( + TradingAgentsGraph( selected_analysts=["market"], config=deepcopy(config), ) - assert graph.graph_setup.portfolio_manager_llm["model"] == "gpt-5.2" + assert recorded_llms["Market Analyst"]["model"] == "gpt-5-mini" + assert recorded_llms["Portfolio Manager"]["model"] == "gpt-5.2" + assert "News Analyst" not in recorded_llms + + +def test_unused_role_routes_do_not_instantiate_clients(monkeypatch): + created_clients = [] + + def fake_create_llm_client(provider, model, base_url=None, **kwargs): + created_clients.append((provider, model)) + if provider == "bad-provider": + raise AssertionError("unused role route should not be instantiated") + return DummyClient(provider, model, base_url, **kwargs) + + monkeypatch.setattr( + "tradingagents.graph.trading_graph.create_llm_client", + fake_create_llm_client, + ) + monkeypatch.setattr( + "tradingagents.graph.trading_graph.FinancialSituationMemory", + lambda *args, **kwargs: object(), + ) + monkeypatch.setattr( + "tradingagents.graph.trading_graph.GraphSetup.setup_graph", + lambda self, selected_analysts: {"selected_analysts": selected_analysts}, + ) + + TradingAgentsGraph( + selected_analysts=["market"], + config={ + "llm_routing": { + "default": {"provider": "openai", "model": "gpt-5-mini"}, + "roles": { + "news": { + "provider": "bad-provider", + "model": "unused-model", + } + }, + } + }, + ) + + assert ("bad-provider", "unused-model") not in created_clients + + +def test_dataflow_config_returns_isolated_nested_routing(monkeypatch): + monkeypatch.setattr(dataflow_config, "_config", None) + dataflow_config.initialize_config() + + config = dataflow_config.get_config() + config["llm_routing"]["roles"]["portfolio_manager"] = { + "provider": "openai", + "model": "gpt-5.2", + } + + assert dataflow_config.get_config()["llm_routing"]["roles"] == {} + assert default_config.DEFAULT_CONFIG["llm_routing"]["roles"] == {} def test_log_state_writes_json_snapshot(tmp_path, monkeypatch): @@ -99,4 +234,7 @@ def test_log_state_writes_json_snapshot(tmp_path, monkeypatch): / "full_states_log_2026-03-24.json" ) assert output_path.exists() - assert json.loads(output_path.read_text())["2026-03-24"]["company_of_interest"] == "Apple" + assert ( + json.loads(output_path.read_text())["2026-03-24"]["company_of_interest"] + == "Apple" + ) diff --git a/tradingagents/dataflows/config.py b/tradingagents/dataflows/config.py index 5819494a..2714eb61 100644 --- a/tradingagents/dataflows/config.py +++ b/tradingagents/dataflows/config.py @@ -1,3 +1,4 @@ +from copy import deepcopy import tradingagents.default_config as default_config from typing import Dict, Optional @@ -5,26 +6,34 @@ from typing import Dict, Optional _config: Optional[Dict] = None +def _deep_merge_dicts(base: Dict, override: Dict) -> Dict: + merged = deepcopy(base) + for key, value in override.items(): + if isinstance(value, dict) and isinstance(merged.get(key), dict): + merged[key] = _deep_merge_dicts(merged[key], value) + else: + merged[key] = deepcopy(value) + return merged + + def initialize_config(): """Initialize the configuration with default values.""" global _config if _config is None: - _config = default_config.DEFAULT_CONFIG.copy() + _config = deepcopy(default_config.DEFAULT_CONFIG) def set_config(config: Dict): """Update the configuration with custom values.""" global _config - if _config is None: - _config = default_config.DEFAULT_CONFIG.copy() - _config.update(config) + _config = _deep_merge_dicts(default_config.DEFAULT_CONFIG, config) def get_config() -> Dict: """Get the current configuration.""" if _config is None: initialize_config() - return _config.copy() + return deepcopy(_config) # Initialize with default config diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index d6c77409..0c9f75d9 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -43,6 +43,16 @@ from .signal_processing import SignalProcessor class TradingAgentsGraph: """Main class that orchestrates the trading agents framework.""" + ALWAYS_ON_ROLES = { + "bull_researcher", + "bear_researcher", + "research_manager", + "trader", + "aggressive_analyst", + "neutral_analyst", + "conservative_analyst", + "portfolio_manager", + } QUICK_THINKING_ROLES = { "market", "social", @@ -90,7 +100,7 @@ class TradingAgentsGraph: self.quick_thinking_llm = self._create_legacy_llm("quick") self.deep_thinking_llm = self._create_legacy_llm("deep") - self.role_llms = self._create_role_llms() + self.role_llms = self._create_role_llms(selected_analysts) # Initialize memories self.bull_memory = FinancialSituationMemory("bull_memory", self.config) @@ -164,15 +174,20 @@ class TradingAgentsGraph: ) return client.get_llm() - def _create_role_llms(self) -> Dict[str, Any]: + def _create_role_llms(self, selected_analysts: List[str]) -> Dict[str, Any]: role_llms = {} - for role in self.QUICK_THINKING_ROLES | self.DEEP_THINKING_ROLES: + for role in self._get_required_roles(selected_analysts): thinker_depth = "deep" if role in self.DEEP_THINKING_ROLES else "quick" - role_llms[role] = self._create_routed_llm(role, thinker_depth) + llm_config = self._resolve_llm_config(role, thinker_depth) + if self._uses_legacy_llm(llm_config, thinker_depth): + continue + role_llms[role] = self._create_llm_from_config(llm_config) return role_llms - def _create_routed_llm(self, role: str, thinker_depth: str): - llm_config = self._resolve_llm_config(role, thinker_depth) + def _get_required_roles(self, selected_analysts: List[str]) -> set[str]: + return self.ALWAYS_ON_ROLES | set(selected_analysts) + + def _create_llm_from_config(self, llm_config: Dict[str, Any]): llm_kwargs = self._get_provider_kwargs(llm_config["provider"]) if self.callbacks: llm_kwargs["callbacks"] = self.callbacks @@ -185,6 +200,14 @@ class TradingAgentsGraph: ) return client.get_llm() + def _uses_legacy_llm(self, llm_config: Dict[str, Any], thinker_depth: str) -> bool: + model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm" + return ( + llm_config["provider"] == self.config["llm_provider"] + and llm_config["model"] == self.config[model_key] + and llm_config.get("base_url") == self.config.get("backend_url") + ) + def _resolve_llm_config( self, role: str,