fix: inherit default llm routing fields
This commit is contained in:
parent
5c483b4293
commit
507af6f6e2
|
|
@ -137,6 +137,47 @@ def test_role_specific_llm_config_overrides_actual_graph_wiring(monkeypatch):
|
|||
assert "News Analyst" not in recorded_llms
|
||||
|
||||
|
||||
def test_role_specific_route_inherits_default_provider_and_base_url(monkeypatch):
    """A role route that only sets a model inherits provider/base_url from the default route."""
    recorded_llms = {}

    # Stub out real client construction and memory so graph wiring is observable
    # without touching any external service.
    monkeypatch.setattr(
        "tradingagents.graph.trading_graph.create_llm_client",
        lambda provider, model, base_url=None, **kwargs: DummyClient(
            provider, model, base_url, **kwargs
        ),
    )
    monkeypatch.setattr(
        "tradingagents.graph.trading_graph.FinancialSituationMemory",
        lambda *args, **kwargs: object(),
    )
    _patch_graph_setup_wiring(monkeypatch, recorded_llms)

    # Routing: a default route plus a role route that overrides only the model.
    routing_config = {
        "default": {
            "provider": "anthropic",
            "base_url": "https://anthropic.example/v1",
        },
        "roles": {
            "portfolio_manager": {
                "model": "claude-opus-4-6",
            }
        },
    }
    TradingAgentsGraph(
        selected_analysts=["market"],
        config={
            "llm_provider": "openai",
            "backend_url": "https://legacy.example/v1",
            "llm_routing": routing_config,
        },
    )

    market = recorded_llms["Market Analyst"]
    manager = recorded_llms["Portfolio Manager"]
    # Both roles inherit the default provider/base_url; legacy top-level
    # llm_provider/backend_url must not leak through.
    assert market["provider"] == "anthropic"
    assert market["base_url"] == "https://anthropic.example/v1"
    assert manager["provider"] == "anthropic"
    assert manager["base_url"] == "https://anthropic.example/v1"
    # The role-specific model override is preserved on top of the default route.
    assert manager["model"] == "claude-opus-4-6"
|
||||
|
||||
|
||||
def test_unused_role_routes_do_not_instantiate_clients(monkeypatch):
|
||||
created_clients = []
|
||||
|
||||
|
|
|
|||
|
|
@ -179,12 +179,20 @@ class TradingAgentsGraph:
|
|||
|
||||
def _create_role_llms(self, selected_analysts: List[str]) -> Dict[str, Any]:
    """Build the LLM client for each role required by the selected analysts.

    Clients are deduplicated through ``llm_cache``: roles whose resolved
    config shares the same (provider, model, base_url) triple receive the
    same client instance, so duplicate routes do not instantiate redundant
    clients.

    Args:
        selected_analysts: Analyst identifiers chosen for this graph run.

    Returns:
        Mapping of role name to its (possibly shared) LLM client. Roles
        still served by the legacy configuration are omitted.
    """
    role_llms = {}
    llm_cache = {}
    for role in self._get_required_roles(selected_analysts):
        thinker_depth = "deep" if role in self.DEEP_THINKING_ROLES else "quick"
        llm_config = self._resolve_llm_config(role, thinker_depth)
        # Legacy-configured roles are wired elsewhere; skip them here.
        if self._uses_legacy_llm(llm_config, thinker_depth):
            continue
        # NOTE: do not instantiate a client before consulting the cache —
        # an unconditional create here would be immediately overwritten by
        # the cached assignment below and defeat client sharing.
        cache_key = (
            llm_config["provider"],
            llm_config["model"],
            llm_config.get("base_url"),
        )
        if cache_key not in llm_cache:
            llm_cache[cache_key] = self._create_llm_from_config(llm_config)
        role_llms[role] = llm_cache[cache_key]
    return role_llms
|
||||
|
||||
def _get_required_roles(self, selected_analysts: List[str]) -> set[str]:
|
||||
|
|
@ -218,7 +226,9 @@ class TradingAgentsGraph:
|
|||
) -> Dict[str, Any]:
|
||||
routing = self.config.get("llm_routing") or {}
|
||||
role_routes = routing.get("roles") or {}
|
||||
route = role_routes.get(role) or routing.get("default") or {}
|
||||
default_route = routing.get("default") or {}
|
||||
role_route = role_routes.get(role) or {}
|
||||
route = self._deep_merge_dicts(default_route, role_route)
|
||||
|
||||
model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
|
||||
provider = self._normalize_provider(
|
||||
|
|
|
|||
Loading…
Reference in New Issue