From 5c483b429348d92655041580525a2219baca12c1 Mon Sep 17 00:00:00 2001
From: Garrick
Date: Tue, 24 Mar 2026 13:12:55 -0700
Subject: [PATCH] fix: normalize routed provider comparisons

---
 tests/test_llm_routing.py            | 36 ++++++++++++++++++++++++++++
 tradingagents/graph/trading_graph.py | 11 ++++++---
 2 files changed, 44 insertions(+), 3 deletions(-)

diff --git a/tests/test_llm_routing.py b/tests/test_llm_routing.py
index bd7c5c28..ca4903cd 100644
--- a/tests/test_llm_routing.py
+++ b/tests/test_llm_routing.py
@@ -177,6 +177,42 @@ def test_unused_role_routes_do_not_instantiate_clients(monkeypatch):
     assert ("bad-provider", "unused-model") not in created_clients
 
 
+def test_provider_normalization_avoids_duplicate_legacy_client_creation(monkeypatch):
+    created_clients = []
+
+    def fake_create_llm_client(provider, model, base_url=None, **kwargs):
+        created_clients.append((provider, model))
+        return DummyClient(provider, model, base_url, **kwargs)
+
+    monkeypatch.setattr(
+        "tradingagents.graph.trading_graph.create_llm_client",
+        fake_create_llm_client,
+    )
+    monkeypatch.setattr(
+        "tradingagents.graph.trading_graph.FinancialSituationMemory",
+        lambda *args, **kwargs: object(),
+    )
+    monkeypatch.setattr(
+        "tradingagents.graph.trading_graph.GraphSetup.setup_graph",
+        lambda self, selected_analysts: {"selected_analysts": selected_analysts},
+    )
+
+    TradingAgentsGraph(
+        selected_analysts=["market"],
+        config={
+            "llm_provider": "OpenAI",
+            "quick_think_llm": "gpt-5-mini",
+            "llm_routing": {
+                "roles": {
+                    "market": {"provider": "openai", "model": "gpt-5-mini"},
+                },
+            },
+        },
+    )
+
+    assert created_clients.count(("openai", "gpt-5-mini")) == 1
+
+
 def test_dataflow_config_returns_isolated_nested_routing(monkeypatch):
     monkeypatch.setattr(dataflow_config, "_config", None)
     dataflow_config.initialize_config()
diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py
index 0c9f75d9..ff19604b 100644
--- a/tradingagents/graph/trading_graph.py
+++ b/tradingagents/graph/trading_graph.py
@@ -146,6 +146,9 @@ class TradingAgentsGraph:
         """Merge user config over defaults without mutating the shared defaults."""
         return self._deep_merge_dicts(DEFAULT_CONFIG, config or {})
 
+    def _normalize_provider(self, provider: Optional[str]) -> str:
+        return (provider or "").lower()
+
     def _deep_merge_dicts(
         self,
         base: Dict[str, Any],
@@ -161,7 +164,7 @@ class TradingAgentsGraph:
 
     def _create_legacy_llm(self, thinker_depth: str):
         model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
-        provider = self.config["llm_provider"]
+        provider = self._normalize_provider(self.config["llm_provider"])
         llm_kwargs = self._get_provider_kwargs(provider)
         if self.callbacks:
             llm_kwargs["callbacks"] = self.callbacks
@@ -203,7 +206,7 @@
     def _uses_legacy_llm(self, llm_config: Dict[str, Any], thinker_depth: str) -> bool:
         model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
         return (
-            llm_config["provider"] == self.config["llm_provider"]
+            llm_config["provider"] == self._normalize_provider(self.config["llm_provider"])
             and llm_config["model"] == self.config[model_key]
             and llm_config.get("base_url") == self.config.get("backend_url")
         )
@@ -218,7 +221,9 @@
         route = role_routes.get(role) or routing.get("default") or {}
         model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
 
-        provider = route.get("provider", self.config["llm_provider"]).lower()
+        provider = self._normalize_provider(
+            route.get("provider", self.config["llm_provider"])
+        )
 
         return {
             "provider": provider,