fix: normalize routed provider comparisons
This commit is contained in:
parent
f04a1fafd1
commit
5c483b4293
|
|
@ -177,6 +177,42 @@ def test_unused_role_routes_do_not_instantiate_clients(monkeypatch):
|
|||
assert ("bad-provider", "unused-model") not in created_clients
|
||||
|
||||
|
||||
def test_provider_normalization_avoids_duplicate_legacy_client_creation(monkeypatch):
|
||||
created_clients = []
|
||||
|
||||
def fake_create_llm_client(provider, model, base_url=None, **kwargs):
|
||||
created_clients.append((provider, model))
|
||||
return DummyClient(provider, model, base_url, **kwargs)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.trading_graph.create_llm_client",
|
||||
fake_create_llm_client,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.trading_graph.FinancialSituationMemory",
|
||||
lambda *args, **kwargs: object(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.trading_graph.GraphSetup.setup_graph",
|
||||
lambda self, selected_analysts: {"selected_analysts": selected_analysts},
|
||||
)
|
||||
|
||||
TradingAgentsGraph(
|
||||
selected_analysts=["market"],
|
||||
config={
|
||||
"llm_provider": "OpenAI",
|
||||
"quick_think_llm": "gpt-5-mini",
|
||||
"llm_routing": {
|
||||
"roles": {
|
||||
"market": {"provider": "openai", "model": "gpt-5-mini"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert created_clients.count(("openai", "gpt-5-mini")) == 1
|
||||
|
||||
|
||||
def test_dataflow_config_returns_isolated_nested_routing(monkeypatch):
|
||||
monkeypatch.setattr(dataflow_config, "_config", None)
|
||||
dataflow_config.initialize_config()
|
||||
|
|
|
|||
|
|
@ -146,6 +146,9 @@ class TradingAgentsGraph:
|
|||
"""Merge user config over defaults without mutating the shared defaults."""
|
||||
return self._deep_merge_dicts(DEFAULT_CONFIG, config or {})
|
||||
|
||||
def _normalize_provider(self, provider: Optional[str]) -> str:
|
||||
return (provider or "").lower()
|
||||
|
||||
def _deep_merge_dicts(
|
||||
self,
|
||||
base: Dict[str, Any],
|
||||
|
|
@ -161,7 +164,7 @@ class TradingAgentsGraph:
|
|||
|
||||
def _create_legacy_llm(self, thinker_depth: str):
|
||||
model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
|
||||
provider = self.config["llm_provider"]
|
||||
provider = self._normalize_provider(self.config["llm_provider"])
|
||||
llm_kwargs = self._get_provider_kwargs(provider)
|
||||
if self.callbacks:
|
||||
llm_kwargs["callbacks"] = self.callbacks
|
||||
|
|
@ -203,7 +206,7 @@ class TradingAgentsGraph:
|
|||
def _uses_legacy_llm(self, llm_config: Dict[str, Any], thinker_depth: str) -> bool:
|
||||
model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
|
||||
return (
|
||||
llm_config["provider"] == self.config["llm_provider"]
|
||||
llm_config["provider"] == self._normalize_provider(self.config["llm_provider"])
|
||||
and llm_config["model"] == self.config[model_key]
|
||||
and llm_config.get("base_url") == self.config.get("backend_url")
|
||||
)
|
||||
|
|
@ -218,7 +221,9 @@ class TradingAgentsGraph:
|
|||
route = role_routes.get(role) or routing.get("default") or {}
|
||||
|
||||
model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
|
||||
provider = route.get("provider", self.config["llm_provider"]).lower()
|
||||
provider = self._normalize_provider(
|
||||
route.get("provider", self.config["llm_provider"])
|
||||
)
|
||||
|
||||
return {
|
||||
"provider": provider,
|
||||
|
|
|
|||
Loading…
Reference in New Issue