diff --git a/tests/test_llm_routing.py b/tests/test_llm_routing.py index a0d631d2..5f0c4575 100644 --- a/tests/test_llm_routing.py +++ b/tests/test_llm_routing.py @@ -182,6 +182,40 @@ def test_role_specific_route_inherits_default_provider_base_url_and_model(monkey assert recorded_llms["Portfolio Manager"]["model"] == "claude-sonnet-4-6" +def test_mixed_provider_route_does_not_inherit_legacy_backend_url(monkeypatch): + recorded_llms = {} + + monkeypatch.setattr( + "tradingagents.graph.trading_graph.create_llm_client", + lambda provider, model, base_url=None, **kwargs: DummyClient( + provider, model, base_url, **kwargs + ), + ) + monkeypatch.setattr( + "tradingagents.graph.trading_graph.FinancialSituationMemory", + lambda *args, **kwargs: object(), + ) + _patch_graph_setup_wiring(monkeypatch, recorded_llms) + + TradingAgentsGraph( + selected_analysts=["market"], + config={ + "llm_provider": "openai", + "backend_url": "https://api.openai.com/v1", + "llm_routing": { + "default": { + "provider": "anthropic", + "model": "claude-sonnet-4-6", + } + }, + }, + ) + + assert recorded_llms["Market Analyst"]["provider"] == "anthropic" + assert recorded_llms["Market Analyst"]["model"] == "claude-sonnet-4-6" + assert recorded_llms["Market Analyst"]["base_url"] is None + + def test_unused_role_routes_do_not_instantiate_clients(monkeypatch): created_clients = [] diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index f453d782..06476200 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -227,8 +227,9 @@ class TradingAgentsGraph: routing = self.config.get("llm_routing") or {} role_routes = routing.get("roles") or {} model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm" + legacy_provider = self._normalize_provider(self.config["llm_provider"]) legacy_route = { - "provider": self._normalize_provider(self.config["llm_provider"]), + "provider": legacy_provider, "model": self.config[model_key], "base_url": self.config.get("backend_url"), } @@ -237,6 +238,9 @@ class TradingAgentsGraph: route = self._deep_merge_dicts(legacy_route, default_route) route = self._deep_merge_dicts(route, role_route) route["provider"] = self._normalize_provider(route.get("provider")) + explicit_routed_base_url = "base_url" in default_route or "base_url" in role_route + if route["provider"] != legacy_provider and not explicit_routed_base_url: + route["base_url"] = None return route def _get_provider_kwargs(self, provider: Optional[str] = None) -> Dict[str, Any]: