diff --git a/tests/test_llm_routing.py b/tests/test_llm_routing.py
index ca4903cd..c503041a 100644
--- a/tests/test_llm_routing.py
+++ b/tests/test_llm_routing.py
@@ -137,6 +137,47 @@ def test_role_specific_llm_config_overrides_actual_graph_wiring(monkeypatch):
     assert "News Analyst" not in recorded_llms
 
 
+def test_role_specific_route_inherits_default_provider_and_base_url(monkeypatch):
+    recorded_llms = {}
+
+    monkeypatch.setattr(
+        "tradingagents.graph.trading_graph.create_llm_client",
+        lambda provider, model, base_url=None, **kwargs: DummyClient(
+            provider, model, base_url, **kwargs
+        ),
+    )
+    monkeypatch.setattr(
+        "tradingagents.graph.trading_graph.FinancialSituationMemory",
+        lambda *args, **kwargs: object(),
+    )
+    _patch_graph_setup_wiring(monkeypatch, recorded_llms)
+
+    TradingAgentsGraph(
+        selected_analysts=["market"],
+        config={
+            "llm_provider": "openai",
+            "backend_url": "https://legacy.example/v1",
+            "llm_routing": {
+                "default": {
+                    "provider": "anthropic",
+                    "base_url": "https://anthropic.example/v1",
+                },
+                "roles": {
+                    "portfolio_manager": {
+                        "model": "claude-opus-4-6",
+                    }
+                },
+            },
+        },
+    )
+
+    assert recorded_llms["Market Analyst"]["provider"] == "anthropic"
+    assert recorded_llms["Market Analyst"]["base_url"] == "https://anthropic.example/v1"
+    assert recorded_llms["Portfolio Manager"]["provider"] == "anthropic"
+    assert recorded_llms["Portfolio Manager"]["base_url"] == "https://anthropic.example/v1"
+    assert recorded_llms["Portfolio Manager"]["model"] == "claude-opus-4-6"
+
+
 def test_unused_role_routes_do_not_instantiate_clients(monkeypatch):
     created_clients = []
 
diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py
index ff19604b..5a789e3d 100644
--- a/tradingagents/graph/trading_graph.py
+++ b/tradingagents/graph/trading_graph.py
@@ -179,12 +179,20 @@ class TradingAgentsGraph:
 
     def _create_role_llms(self, selected_analysts: List[str]) -> Dict[str, Any]:
         role_llms = {}
+        llm_cache = {}
         for role in self._get_required_roles(selected_analysts):
             thinker_depth = "deep" if role in self.DEEP_THINKING_ROLES else "quick"
             llm_config = self._resolve_llm_config(role, thinker_depth)
             if self._uses_legacy_llm(llm_config, thinker_depth):
                 continue
-            role_llms[role] = self._create_llm_from_config(llm_config)
+            cache_key = (
+                llm_config["provider"],
+                llm_config["model"],
+                llm_config.get("base_url"),
+            )
+            if cache_key not in llm_cache:
+                llm_cache[cache_key] = self._create_llm_from_config(llm_config)
+            role_llms[role] = llm_cache[cache_key]
         return role_llms
 
     def _get_required_roles(self, selected_analysts: List[str]) -> set[str]:
@@ -218,7 +226,9 @@
     ) -> Dict[str, Any]:
         routing = self.config.get("llm_routing") or {}
         role_routes = routing.get("roles") or {}
-        route = role_routes.get(role) or routing.get("default") or {}
+        default_route = routing.get("default") or {}
+        role_route = role_routes.get(role) or {}
+        route = self._deep_merge_dicts(default_route, role_route)
         model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
         provider = self._normalize_provider(