fix: narrow llm routing client setup
This commit is contained in:
parent
dfcd669d28
commit
f04a1fafd1
|
|
@ -1,6 +1,8 @@
|
|||
from copy import deepcopy
|
||||
import json
|
||||
|
||||
import tradingagents.dataflows.config as dataflow_config
|
||||
import tradingagents.default_config as default_config
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
|
||||
|
|
@ -20,7 +22,87 @@ class DummyClient:
|
|||
}
|
||||
|
||||
|
||||
def test_role_specific_llm_config_overrides_default(monkeypatch):
|
||||
class DummyStateGraph:
|
||||
def __init__(self, _state_type):
|
||||
self.nodes = {}
|
||||
|
||||
def add_node(self, name, node):
|
||||
self.nodes[name] = node
|
||||
|
||||
def add_edge(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
def add_conditional_edges(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
def compile(self):
|
||||
return {"nodes": self.nodes}
|
||||
|
||||
|
||||
def _patch_graph_setup_wiring(monkeypatch, recorded_llms):
|
||||
monkeypatch.setattr("tradingagents.graph.setup.StateGraph", DummyStateGraph)
|
||||
monkeypatch.setattr("tradingagents.graph.setup.create_msg_delete", lambda: "delete")
|
||||
|
||||
def make_factory(node_name):
|
||||
def factory(llm, *_args):
|
||||
recorded_llms[node_name] = llm
|
||||
return node_name
|
||||
|
||||
return factory
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.setup.create_market_analyst",
|
||||
make_factory("Market Analyst"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.setup.create_social_media_analyst",
|
||||
make_factory("Social Analyst"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.setup.create_news_analyst",
|
||||
make_factory("News Analyst"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.setup.create_fundamentals_analyst",
|
||||
make_factory("Fundamentals Analyst"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.setup.create_bull_researcher",
|
||||
make_factory("Bull Researcher"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.setup.create_bear_researcher",
|
||||
make_factory("Bear Researcher"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.setup.create_research_manager",
|
||||
make_factory("Research Manager"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.setup.create_trader",
|
||||
make_factory("Trader"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.setup.create_aggressive_debator",
|
||||
make_factory("Aggressive Analyst"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.setup.create_neutral_debator",
|
||||
make_factory("Neutral Analyst"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.setup.create_conservative_debator",
|
||||
make_factory("Conservative Analyst"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.setup.create_portfolio_manager",
|
||||
make_factory("Portfolio Manager"),
|
||||
)
|
||||
|
||||
|
||||
def test_role_specific_llm_config_overrides_actual_graph_wiring(monkeypatch):
|
||||
recorded_llms = {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.trading_graph.create_llm_client",
|
||||
lambda provider, model, base_url=None, **kwargs: DummyClient(
|
||||
|
|
@ -31,10 +113,7 @@ def test_role_specific_llm_config_overrides_default(monkeypatch):
|
|||
"tradingagents.graph.trading_graph.FinancialSituationMemory",
|
||||
lambda *args, **kwargs: object(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.trading_graph.GraphSetup.setup_graph",
|
||||
lambda self, selected_analysts: {"selected_analysts": selected_analysts},
|
||||
)
|
||||
_patch_graph_setup_wiring(monkeypatch, recorded_llms)
|
||||
|
||||
config = {
|
||||
"llm_routing": {
|
||||
|
|
@ -48,12 +127,68 @@ def test_role_specific_llm_config_overrides_default(monkeypatch):
|
|||
}
|
||||
}
|
||||
|
||||
graph = TradingAgentsGraph(
|
||||
TradingAgentsGraph(
|
||||
selected_analysts=["market"],
|
||||
config=deepcopy(config),
|
||||
)
|
||||
|
||||
assert graph.graph_setup.portfolio_manager_llm["model"] == "gpt-5.2"
|
||||
assert recorded_llms["Market Analyst"]["model"] == "gpt-5-mini"
|
||||
assert recorded_llms["Portfolio Manager"]["model"] == "gpt-5.2"
|
||||
assert "News Analyst" not in recorded_llms
|
||||
|
||||
|
||||
def test_unused_role_routes_do_not_instantiate_clients(monkeypatch):
|
||||
created_clients = []
|
||||
|
||||
def fake_create_llm_client(provider, model, base_url=None, **kwargs):
|
||||
created_clients.append((provider, model))
|
||||
if provider == "bad-provider":
|
||||
raise AssertionError("unused role route should not be instantiated")
|
||||
return DummyClient(provider, model, base_url, **kwargs)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.trading_graph.create_llm_client",
|
||||
fake_create_llm_client,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.trading_graph.FinancialSituationMemory",
|
||||
lambda *args, **kwargs: object(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"tradingagents.graph.trading_graph.GraphSetup.setup_graph",
|
||||
lambda self, selected_analysts: {"selected_analysts": selected_analysts},
|
||||
)
|
||||
|
||||
TradingAgentsGraph(
|
||||
selected_analysts=["market"],
|
||||
config={
|
||||
"llm_routing": {
|
||||
"default": {"provider": "openai", "model": "gpt-5-mini"},
|
||||
"roles": {
|
||||
"news": {
|
||||
"provider": "bad-provider",
|
||||
"model": "unused-model",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert ("bad-provider", "unused-model") not in created_clients
|
||||
|
||||
|
||||
def test_dataflow_config_returns_isolated_nested_routing(monkeypatch):
|
||||
monkeypatch.setattr(dataflow_config, "_config", None)
|
||||
dataflow_config.initialize_config()
|
||||
|
||||
config = dataflow_config.get_config()
|
||||
config["llm_routing"]["roles"]["portfolio_manager"] = {
|
||||
"provider": "openai",
|
||||
"model": "gpt-5.2",
|
||||
}
|
||||
|
||||
assert dataflow_config.get_config()["llm_routing"]["roles"] == {}
|
||||
assert default_config.DEFAULT_CONFIG["llm_routing"]["roles"] == {}
|
||||
|
||||
|
||||
def test_log_state_writes_json_snapshot(tmp_path, monkeypatch):
|
||||
|
|
@ -99,4 +234,7 @@ def test_log_state_writes_json_snapshot(tmp_path, monkeypatch):
|
|||
/ "full_states_log_2026-03-24.json"
|
||||
)
|
||||
assert output_path.exists()
|
||||
assert json.loads(output_path.read_text())["2026-03-24"]["company_of_interest"] == "Apple"
|
||||
assert (
|
||||
json.loads(output_path.read_text())["2026-03-24"]["company_of_interest"]
|
||||
== "Apple"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from copy import deepcopy
|
||||
import tradingagents.default_config as default_config
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
|
@ -5,26 +6,34 @@ from typing import Dict, Optional
|
|||
_config: Optional[Dict] = None
|
||||
|
||||
|
||||
def _deep_merge_dicts(base: Dict, override: Dict) -> Dict:
|
||||
merged = deepcopy(base)
|
||||
for key, value in override.items():
|
||||
if isinstance(value, dict) and isinstance(merged.get(key), dict):
|
||||
merged[key] = _deep_merge_dicts(merged[key], value)
|
||||
else:
|
||||
merged[key] = deepcopy(value)
|
||||
return merged
|
||||
|
||||
|
||||
def initialize_config():
|
||||
"""Initialize the configuration with default values."""
|
||||
global _config
|
||||
if _config is None:
|
||||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
_config = deepcopy(default_config.DEFAULT_CONFIG)
|
||||
|
||||
|
||||
def set_config(config: Dict):
|
||||
"""Update the configuration with custom values."""
|
||||
global _config
|
||||
if _config is None:
|
||||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
_config.update(config)
|
||||
_config = _deep_merge_dicts(default_config.DEFAULT_CONFIG, config)
|
||||
|
||||
|
||||
def get_config() -> Dict:
|
||||
"""Get the current configuration."""
|
||||
if _config is None:
|
||||
initialize_config()
|
||||
return _config.copy()
|
||||
return deepcopy(_config)
|
||||
|
||||
|
||||
# Initialize with default config
|
||||
|
|
|
|||
|
|
@ -43,6 +43,16 @@ from .signal_processing import SignalProcessor
|
|||
class TradingAgentsGraph:
|
||||
"""Main class that orchestrates the trading agents framework."""
|
||||
|
||||
ALWAYS_ON_ROLES = {
|
||||
"bull_researcher",
|
||||
"bear_researcher",
|
||||
"research_manager",
|
||||
"trader",
|
||||
"aggressive_analyst",
|
||||
"neutral_analyst",
|
||||
"conservative_analyst",
|
||||
"portfolio_manager",
|
||||
}
|
||||
QUICK_THINKING_ROLES = {
|
||||
"market",
|
||||
"social",
|
||||
|
|
@ -90,7 +100,7 @@ class TradingAgentsGraph:
|
|||
|
||||
self.quick_thinking_llm = self._create_legacy_llm("quick")
|
||||
self.deep_thinking_llm = self._create_legacy_llm("deep")
|
||||
self.role_llms = self._create_role_llms()
|
||||
self.role_llms = self._create_role_llms(selected_analysts)
|
||||
|
||||
# Initialize memories
|
||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||
|
|
@ -164,15 +174,20 @@ class TradingAgentsGraph:
|
|||
)
|
||||
return client.get_llm()
|
||||
|
||||
def _create_role_llms(self) -> Dict[str, Any]:
|
||||
def _create_role_llms(self, selected_analysts: List[str]) -> Dict[str, Any]:
|
||||
role_llms = {}
|
||||
for role in self.QUICK_THINKING_ROLES | self.DEEP_THINKING_ROLES:
|
||||
for role in self._get_required_roles(selected_analysts):
|
||||
thinker_depth = "deep" if role in self.DEEP_THINKING_ROLES else "quick"
|
||||
role_llms[role] = self._create_routed_llm(role, thinker_depth)
|
||||
llm_config = self._resolve_llm_config(role, thinker_depth)
|
||||
if self._uses_legacy_llm(llm_config, thinker_depth):
|
||||
continue
|
||||
role_llms[role] = self._create_llm_from_config(llm_config)
|
||||
return role_llms
|
||||
|
||||
def _create_routed_llm(self, role: str, thinker_depth: str):
|
||||
llm_config = self._resolve_llm_config(role, thinker_depth)
|
||||
def _get_required_roles(self, selected_analysts: List[str]) -> set[str]:
|
||||
return self.ALWAYS_ON_ROLES | set(selected_analysts)
|
||||
|
||||
def _create_llm_from_config(self, llm_config: Dict[str, Any]):
|
||||
llm_kwargs = self._get_provider_kwargs(llm_config["provider"])
|
||||
if self.callbacks:
|
||||
llm_kwargs["callbacks"] = self.callbacks
|
||||
|
|
@ -185,6 +200,14 @@ class TradingAgentsGraph:
|
|||
)
|
||||
return client.get_llm()
|
||||
|
||||
def _uses_legacy_llm(self, llm_config: Dict[str, Any], thinker_depth: str) -> bool:
|
||||
model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
|
||||
return (
|
||||
llm_config["provider"] == self.config["llm_provider"]
|
||||
and llm_config["model"] == self.config[model_key]
|
||||
and llm_config.get("base_url") == self.config.get("backend_url")
|
||||
)
|
||||
|
||||
def _resolve_llm_config(
|
||||
self,
|
||||
role: str,
|
||||
|
|
|
|||
Loading…
Reference in New Issue