From cd145f14d071e4ea1315950a10a18cd7b07712aa Mon Sep 17 00:00:00 2001 From: Garrick Date: Tue, 24 Mar 2026 13:26:23 -0700 Subject: [PATCH] fix: layer llm route inheritance --- tests/test_llm_routing.py | 10 ++++++---- tradingagents/graph/trading_graph.py | 22 ++++++++++------------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/test_llm_routing.py b/tests/test_llm_routing.py index c503041a..54dc1b8b 100644 --- a/tests/test_llm_routing.py +++ b/tests/test_llm_routing.py @@ -137,7 +137,7 @@ def test_role_specific_llm_config_overrides_actual_graph_wiring(monkeypatch): assert "News Analyst" not in recorded_llms -def test_role_specific_route_inherits_default_provider_and_base_url(monkeypatch): +def test_role_specific_route_inherits_default_provider_base_url_and_model(monkeypatch): recorded_llms = {} monkeypatch.setattr( @@ -161,10 +161,11 @@ def test_role_specific_route_inherits_default_provider_and_base_url(monkeypatch) "default": { "provider": "anthropic", "base_url": "https://anthropic.example/v1", + "model": "claude-sonnet-4-6", }, "roles": { "portfolio_manager": { - "model": "claude-opus-4-6", + "base_url": "https://anthropic-pm.example/v1", } }, }, @@ -173,9 +174,10 @@ def test_role_specific_route_inherits_default_provider_and_base_url(monkeypatch) assert recorded_llms["Market Analyst"]["provider"] == "anthropic" assert recorded_llms["Market Analyst"]["base_url"] == "https://anthropic.example/v1" + assert recorded_llms["Market Analyst"]["model"] == "claude-sonnet-4-6" assert recorded_llms["Portfolio Manager"]["provider"] == "anthropic" - assert recorded_llms["Portfolio Manager"]["base_url"] == "https://anthropic.example/v1" - assert recorded_llms["Portfolio Manager"]["model"] == "claude-opus-4-6" + assert recorded_llms["Portfolio Manager"]["base_url"] == "https://anthropic-pm.example/v1" + assert recorded_llms["Portfolio Manager"]["model"] == "claude-sonnet-4-6" def test_unused_role_routes_do_not_instantiate_clients(monkeypatch): diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 5a789e3d..bec32a33 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -226,20 +226,18 @@ class TradingAgentsGraph: ) -> Dict[str, Any]: routing = self.config.get("llm_routing") or {} role_routes = routing.get("roles") or {} + model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm" + legacy_route = { + "provider": self._normalize_provider(self.config["llm_provider"]), + "model": self.config[model_key], + "base_url": self.config.get("backend_url"), + } default_route = routing.get("default") or {} role_route = role_routes.get(role) or {} - route = self._deep_merge_dicts(default_route, role_route) - - model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm" - provider = self._normalize_provider( - route.get("provider", self.config["llm_provider"]) - ) - - return { - "provider": provider, - "model": route.get("model", self.config[model_key]), - "base_url": route.get("base_url", self.config.get("backend_url")), - } + route = self._deep_merge_dicts(legacy_route, default_route) + route = self._deep_merge_dicts(route, role_route) + route["provider"] = self._normalize_provider(route.get("provider")) + return route def _get_provider_kwargs(self, provider: Optional[str] = None) -> Dict[str, Any]: """Get provider-specific kwargs for LLM client creation."""