fix: normalize routed provider comparisons

This commit is contained in:
Garrick 2026-03-24 13:12:55 -07:00
parent f04a1fafd1
commit 5c483b4293
2 changed files with 44 additions and 3 deletions

View File

@@ -177,6 +177,42 @@ def test_unused_role_routes_do_not_instantiate_clients(monkeypatch):
assert ("bad-provider", "unused-model") not in created_clients
def test_provider_normalization_avoids_duplicate_legacy_client_creation(monkeypatch):
    """A routed role whose provider differs from the legacy provider only by
    case (``"openai"`` vs ``"OpenAI"``) must reuse the shared legacy client
    instead of instantiating a second one for the same backend/model pair.
    """
    created_clients = []

    # Record every (provider, model) pair for which a client is constructed.
    def fake_create_llm_client(provider, model, base_url=None, **kwargs):
        created_clients.append((provider, model))
        return DummyClient(provider, model, base_url, **kwargs)

    monkeypatch.setattr(
        "tradingagents.graph.trading_graph.create_llm_client",
        fake_create_llm_client,
    )
    # Stub out memory and graph wiring so the constructor stays cheap.
    monkeypatch.setattr(
        "tradingagents.graph.trading_graph.FinancialSituationMemory",
        lambda *args, **kwargs: object(),
    )
    monkeypatch.setattr(
        "tradingagents.graph.trading_graph.GraphSetup.setup_graph",
        lambda self, selected_analysts: {"selected_analysts": selected_analysts},
    )

    TradingAgentsGraph(
        selected_analysts=["market"],
        config={
            # Mixed-case legacy provider on purpose: the route below uses
            # lowercase "openai" and must still be recognized as the same.
            "llm_provider": "OpenAI",
            "quick_think_llm": "gpt-5-mini",
            "llm_routing": {
                "roles": {
                    "market": {"provider": "openai", "model": "gpt-5-mini"},
                },
            },
        },
    )

    # Exactly one client for the shared provider/model — no duplicate.
    assert created_clients.count(("openai", "gpt-5-mini")) == 1
def test_dataflow_config_returns_isolated_nested_routing(monkeypatch):
monkeypatch.setattr(dataflow_config, "_config", None)
dataflow_config.initialize_config()

View File

@@ -146,6 +146,9 @@ class TradingAgentsGraph:
"""Merge user config over defaults without mutating the shared defaults."""
return self._deep_merge_dicts(DEFAULT_CONFIG, config or {})
def _normalize_provider(self, provider: Optional[str]) -> str:
return (provider or "").lower()
def _deep_merge_dicts(
self,
base: Dict[str, Any],
@@ -161,7 +164,7 @@
def _create_legacy_llm(self, thinker_depth: str):
model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
provider = self.config["llm_provider"]
provider = self._normalize_provider(self.config["llm_provider"])
llm_kwargs = self._get_provider_kwargs(provider)
if self.callbacks:
llm_kwargs["callbacks"] = self.callbacks
@@ -203,7 +206,7 @@
def _uses_legacy_llm(self, llm_config: Dict[str, Any], thinker_depth: str) -> bool:
    """Return True when *llm_config* matches the legacy top-level LLM settings.

    The provider is compared case-insensitively via ``_normalize_provider``
    so a routed role declared as ``"openai"`` matches a legacy provider of
    ``"OpenAI"`` and the shared legacy client is reused.

    NOTE(review): the rendered diff stacked both the pre-change and
    post-change comparison lines here, which is invalid Python; this keeps
    only the normalized (post-commit) comparison.
    """
    model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
    return (
        llm_config["provider"] == self._normalize_provider(self.config["llm_provider"])
        and llm_config["model"] == self.config[model_key]
        and llm_config.get("base_url") == self.config.get("backend_url")
    )
@ -218,7 +221,9 @@ class TradingAgentsGraph:
route = role_routes.get(role) or routing.get("default") or {}
model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
provider = route.get("provider", self.config["llm_provider"]).lower()
provider = self._normalize_provider(
route.get("provider", self.config["llm_provider"])
)
return {
"provider": provider,