fix: narrow llm routing client setup

This commit is contained in:
Garrick 2026-03-24 13:06:32 -07:00
parent dfcd669d28
commit f04a1fafd1
3 changed files with 189 additions and 19 deletions

View File

@ -1,6 +1,8 @@
from copy import deepcopy
import json
import tradingagents.dataflows.config as dataflow_config
import tradingagents.default_config as default_config
from tradingagents.graph.trading_graph import TradingAgentsGraph
@ -20,7 +22,87 @@ class DummyClient:
}
def test_role_specific_llm_config_overrides_default(monkeypatch):
class DummyStateGraph:
def __init__(self, _state_type):
self.nodes = {}
def add_node(self, name, node):
self.nodes[name] = node
def add_edge(self, *_args, **_kwargs):
return None
def add_conditional_edges(self, *_args, **_kwargs):
return None
def compile(self):
return {"nodes": self.nodes}
def _patch_graph_setup_wiring(monkeypatch, recorded_llms):
monkeypatch.setattr("tradingagents.graph.setup.StateGraph", DummyStateGraph)
monkeypatch.setattr("tradingagents.graph.setup.create_msg_delete", lambda: "delete")
def make_factory(node_name):
def factory(llm, *_args):
recorded_llms[node_name] = llm
return node_name
return factory
monkeypatch.setattr(
"tradingagents.graph.setup.create_market_analyst",
make_factory("Market Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_social_media_analyst",
make_factory("Social Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_news_analyst",
make_factory("News Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_fundamentals_analyst",
make_factory("Fundamentals Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_bull_researcher",
make_factory("Bull Researcher"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_bear_researcher",
make_factory("Bear Researcher"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_research_manager",
make_factory("Research Manager"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_trader",
make_factory("Trader"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_aggressive_debator",
make_factory("Aggressive Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_neutral_debator",
make_factory("Neutral Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_conservative_debator",
make_factory("Conservative Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_portfolio_manager",
make_factory("Portfolio Manager"),
)
def test_role_specific_llm_config_overrides_actual_graph_wiring(monkeypatch):
recorded_llms = {}
monkeypatch.setattr(
"tradingagents.graph.trading_graph.create_llm_client",
lambda provider, model, base_url=None, **kwargs: DummyClient(
@ -31,10 +113,7 @@ def test_role_specific_llm_config_overrides_default(monkeypatch):
"tradingagents.graph.trading_graph.FinancialSituationMemory",
lambda *args, **kwargs: object(),
)
monkeypatch.setattr(
"tradingagents.graph.trading_graph.GraphSetup.setup_graph",
lambda self, selected_analysts: {"selected_analysts": selected_analysts},
)
_patch_graph_setup_wiring(monkeypatch, recorded_llms)
config = {
"llm_routing": {
@ -48,12 +127,68 @@ def test_role_specific_llm_config_overrides_default(monkeypatch):
}
}
graph = TradingAgentsGraph(
TradingAgentsGraph(
selected_analysts=["market"],
config=deepcopy(config),
)
assert graph.graph_setup.portfolio_manager_llm["model"] == "gpt-5.2"
assert recorded_llms["Market Analyst"]["model"] == "gpt-5-mini"
assert recorded_llms["Portfolio Manager"]["model"] == "gpt-5.2"
assert "News Analyst" not in recorded_llms
def test_unused_role_routes_do_not_instantiate_clients(monkeypatch):
created_clients = []
def fake_create_llm_client(provider, model, base_url=None, **kwargs):
created_clients.append((provider, model))
if provider == "bad-provider":
raise AssertionError("unused role route should not be instantiated")
return DummyClient(provider, model, base_url, **kwargs)
monkeypatch.setattr(
"tradingagents.graph.trading_graph.create_llm_client",
fake_create_llm_client,
)
monkeypatch.setattr(
"tradingagents.graph.trading_graph.FinancialSituationMemory",
lambda *args, **kwargs: object(),
)
monkeypatch.setattr(
"tradingagents.graph.trading_graph.GraphSetup.setup_graph",
lambda self, selected_analysts: {"selected_analysts": selected_analysts},
)
TradingAgentsGraph(
selected_analysts=["market"],
config={
"llm_routing": {
"default": {"provider": "openai", "model": "gpt-5-mini"},
"roles": {
"news": {
"provider": "bad-provider",
"model": "unused-model",
}
},
}
},
)
assert ("bad-provider", "unused-model") not in created_clients
def test_dataflow_config_returns_isolated_nested_routing(monkeypatch):
monkeypatch.setattr(dataflow_config, "_config", None)
dataflow_config.initialize_config()
config = dataflow_config.get_config()
config["llm_routing"]["roles"]["portfolio_manager"] = {
"provider": "openai",
"model": "gpt-5.2",
}
assert dataflow_config.get_config()["llm_routing"]["roles"] == {}
assert default_config.DEFAULT_CONFIG["llm_routing"]["roles"] == {}
def test_log_state_writes_json_snapshot(tmp_path, monkeypatch):
@ -99,4 +234,7 @@ def test_log_state_writes_json_snapshot(tmp_path, monkeypatch):
/ "full_states_log_2026-03-24.json"
)
assert output_path.exists()
assert json.loads(output_path.read_text())["2026-03-24"]["company_of_interest"] == "Apple"
assert (
json.loads(output_path.read_text())["2026-03-24"]["company_of_interest"]
== "Apple"
)

View File

@ -1,3 +1,4 @@
from copy import deepcopy
import tradingagents.default_config as default_config
from typing import Dict, Optional
@ -5,26 +6,34 @@ from typing import Dict, Optional
_config: Optional[Dict] = None
def _deep_merge_dicts(base: Dict, override: Dict) -> Dict:
merged = deepcopy(base)
for key, value in override.items():
if isinstance(value, dict) and isinstance(merged.get(key), dict):
merged[key] = _deep_merge_dicts(merged[key], value)
else:
merged[key] = deepcopy(value)
return merged
def initialize_config():
"""Initialize the configuration with default values."""
global _config
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
_config = deepcopy(default_config.DEFAULT_CONFIG)
def set_config(config: Dict):
"""Update the configuration with custom values."""
global _config
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
_config.update(config)
_config = _deep_merge_dicts(default_config.DEFAULT_CONFIG, config)
def get_config() -> Dict:
"""Get the current configuration."""
if _config is None:
initialize_config()
return _config.copy()
return deepcopy(_config)
# Initialize with default config

View File

@ -43,6 +43,16 @@ from .signal_processing import SignalProcessor
class TradingAgentsGraph:
"""Main class that orchestrates the trading agents framework."""
ALWAYS_ON_ROLES = {
"bull_researcher",
"bear_researcher",
"research_manager",
"trader",
"aggressive_analyst",
"neutral_analyst",
"conservative_analyst",
"portfolio_manager",
}
QUICK_THINKING_ROLES = {
"market",
"social",
@ -90,7 +100,7 @@ class TradingAgentsGraph:
self.quick_thinking_llm = self._create_legacy_llm("quick")
self.deep_thinking_llm = self._create_legacy_llm("deep")
self.role_llms = self._create_role_llms()
self.role_llms = self._create_role_llms(selected_analysts)
# Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
@ -164,15 +174,20 @@ class TradingAgentsGraph:
)
return client.get_llm()
def _create_role_llms(self) -> Dict[str, Any]:
def _create_role_llms(self, selected_analysts: List[str]) -> Dict[str, Any]:
role_llms = {}
for role in self.QUICK_THINKING_ROLES | self.DEEP_THINKING_ROLES:
for role in self._get_required_roles(selected_analysts):
thinker_depth = "deep" if role in self.DEEP_THINKING_ROLES else "quick"
role_llms[role] = self._create_routed_llm(role, thinker_depth)
llm_config = self._resolve_llm_config(role, thinker_depth)
if self._uses_legacy_llm(llm_config, thinker_depth):
continue
role_llms[role] = self._create_llm_from_config(llm_config)
return role_llms
def _create_routed_llm(self, role: str, thinker_depth: str):
llm_config = self._resolve_llm_config(role, thinker_depth)
def _get_required_roles(self, selected_analysts: List[str]) -> set[str]:
return self.ALWAYS_ON_ROLES | set(selected_analysts)
def _create_llm_from_config(self, llm_config: Dict[str, Any]):
llm_kwargs = self._get_provider_kwargs(llm_config["provider"])
if self.callbacks:
llm_kwargs["callbacks"] = self.callbacks
@ -185,6 +200,14 @@ class TradingAgentsGraph:
)
return client.get_llm()
def _uses_legacy_llm(self, llm_config: Dict[str, Any], thinker_depth: str) -> bool:
model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
return (
llm_config["provider"] == self.config["llm_provider"]
and llm_config["model"] == self.config[model_key]
and llm_config.get("base_url") == self.config.get("backend_url")
)
def _resolve_llm_config(
self,
role: str,