fix: narrow llm routing client setup
This commit is contained in:
parent
dfcd669d28
commit
f04a1fafd1
|
|
@ -1,6 +1,8 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
import tradingagents.dataflows.config as dataflow_config
|
||||||
|
import tradingagents.default_config as default_config
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -20,7 +22,87 @@ class DummyClient:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_role_specific_llm_config_overrides_default(monkeypatch):
|
class DummyStateGraph:
|
||||||
|
def __init__(self, _state_type):
|
||||||
|
self.nodes = {}
|
||||||
|
|
||||||
|
def add_node(self, name, node):
|
||||||
|
self.nodes[name] = node
|
||||||
|
|
||||||
|
def add_edge(self, *_args, **_kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def add_conditional_edges(self, *_args, **_kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def compile(self):
|
||||||
|
return {"nodes": self.nodes}
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_graph_setup_wiring(monkeypatch, recorded_llms):
|
||||||
|
monkeypatch.setattr("tradingagents.graph.setup.StateGraph", DummyStateGraph)
|
||||||
|
monkeypatch.setattr("tradingagents.graph.setup.create_msg_delete", lambda: "delete")
|
||||||
|
|
||||||
|
def make_factory(node_name):
|
||||||
|
def factory(llm, *_args):
|
||||||
|
recorded_llms[node_name] = llm
|
||||||
|
return node_name
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.setup.create_market_analyst",
|
||||||
|
make_factory("Market Analyst"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.setup.create_social_media_analyst",
|
||||||
|
make_factory("Social Analyst"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.setup.create_news_analyst",
|
||||||
|
make_factory("News Analyst"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.setup.create_fundamentals_analyst",
|
||||||
|
make_factory("Fundamentals Analyst"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.setup.create_bull_researcher",
|
||||||
|
make_factory("Bull Researcher"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.setup.create_bear_researcher",
|
||||||
|
make_factory("Bear Researcher"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.setup.create_research_manager",
|
||||||
|
make_factory("Research Manager"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.setup.create_trader",
|
||||||
|
make_factory("Trader"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.setup.create_aggressive_debator",
|
||||||
|
make_factory("Aggressive Analyst"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.setup.create_neutral_debator",
|
||||||
|
make_factory("Neutral Analyst"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.setup.create_conservative_debator",
|
||||||
|
make_factory("Conservative Analyst"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.setup.create_portfolio_manager",
|
||||||
|
make_factory("Portfolio Manager"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_role_specific_llm_config_overrides_actual_graph_wiring(monkeypatch):
|
||||||
|
recorded_llms = {}
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"tradingagents.graph.trading_graph.create_llm_client",
|
"tradingagents.graph.trading_graph.create_llm_client",
|
||||||
lambda provider, model, base_url=None, **kwargs: DummyClient(
|
lambda provider, model, base_url=None, **kwargs: DummyClient(
|
||||||
|
|
@ -31,10 +113,7 @@ def test_role_specific_llm_config_overrides_default(monkeypatch):
|
||||||
"tradingagents.graph.trading_graph.FinancialSituationMemory",
|
"tradingagents.graph.trading_graph.FinancialSituationMemory",
|
||||||
lambda *args, **kwargs: object(),
|
lambda *args, **kwargs: object(),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
_patch_graph_setup_wiring(monkeypatch, recorded_llms)
|
||||||
"tradingagents.graph.trading_graph.GraphSetup.setup_graph",
|
|
||||||
lambda self, selected_analysts: {"selected_analysts": selected_analysts},
|
|
||||||
)
|
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"llm_routing": {
|
"llm_routing": {
|
||||||
|
|
@ -48,12 +127,68 @@ def test_role_specific_llm_config_overrides_default(monkeypatch):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
graph = TradingAgentsGraph(
|
TradingAgentsGraph(
|
||||||
selected_analysts=["market"],
|
selected_analysts=["market"],
|
||||||
config=deepcopy(config),
|
config=deepcopy(config),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert graph.graph_setup.portfolio_manager_llm["model"] == "gpt-5.2"
|
assert recorded_llms["Market Analyst"]["model"] == "gpt-5-mini"
|
||||||
|
assert recorded_llms["Portfolio Manager"]["model"] == "gpt-5.2"
|
||||||
|
assert "News Analyst" not in recorded_llms
|
||||||
|
|
||||||
|
|
||||||
|
def test_unused_role_routes_do_not_instantiate_clients(monkeypatch):
|
||||||
|
created_clients = []
|
||||||
|
|
||||||
|
def fake_create_llm_client(provider, model, base_url=None, **kwargs):
|
||||||
|
created_clients.append((provider, model))
|
||||||
|
if provider == "bad-provider":
|
||||||
|
raise AssertionError("unused role route should not be instantiated")
|
||||||
|
return DummyClient(provider, model, base_url, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.trading_graph.create_llm_client",
|
||||||
|
fake_create_llm_client,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.trading_graph.FinancialSituationMemory",
|
||||||
|
lambda *args, **kwargs: object(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tradingagents.graph.trading_graph.GraphSetup.setup_graph",
|
||||||
|
lambda self, selected_analysts: {"selected_analysts": selected_analysts},
|
||||||
|
)
|
||||||
|
|
||||||
|
TradingAgentsGraph(
|
||||||
|
selected_analysts=["market"],
|
||||||
|
config={
|
||||||
|
"llm_routing": {
|
||||||
|
"default": {"provider": "openai", "model": "gpt-5-mini"},
|
||||||
|
"roles": {
|
||||||
|
"news": {
|
||||||
|
"provider": "bad-provider",
|
||||||
|
"model": "unused-model",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ("bad-provider", "unused-model") not in created_clients
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataflow_config_returns_isolated_nested_routing(monkeypatch):
|
||||||
|
monkeypatch.setattr(dataflow_config, "_config", None)
|
||||||
|
dataflow_config.initialize_config()
|
||||||
|
|
||||||
|
config = dataflow_config.get_config()
|
||||||
|
config["llm_routing"]["roles"]["portfolio_manager"] = {
|
||||||
|
"provider": "openai",
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert dataflow_config.get_config()["llm_routing"]["roles"] == {}
|
||||||
|
assert default_config.DEFAULT_CONFIG["llm_routing"]["roles"] == {}
|
||||||
|
|
||||||
|
|
||||||
def test_log_state_writes_json_snapshot(tmp_path, monkeypatch):
|
def test_log_state_writes_json_snapshot(tmp_path, monkeypatch):
|
||||||
|
|
@ -99,4 +234,7 @@ def test_log_state_writes_json_snapshot(tmp_path, monkeypatch):
|
||||||
/ "full_states_log_2026-03-24.json"
|
/ "full_states_log_2026-03-24.json"
|
||||||
)
|
)
|
||||||
assert output_path.exists()
|
assert output_path.exists()
|
||||||
assert json.loads(output_path.read_text())["2026-03-24"]["company_of_interest"] == "Apple"
|
assert (
|
||||||
|
json.loads(output_path.read_text())["2026-03-24"]["company_of_interest"]
|
||||||
|
== "Apple"
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from copy import deepcopy
|
||||||
import tradingagents.default_config as default_config
|
import tradingagents.default_config as default_config
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
@ -5,26 +6,34 @@ from typing import Dict, Optional
|
||||||
_config: Optional[Dict] = None
|
_config: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _deep_merge_dicts(base: Dict, override: Dict) -> Dict:
|
||||||
|
merged = deepcopy(base)
|
||||||
|
for key, value in override.items():
|
||||||
|
if isinstance(value, dict) and isinstance(merged.get(key), dict):
|
||||||
|
merged[key] = _deep_merge_dicts(merged[key], value)
|
||||||
|
else:
|
||||||
|
merged[key] = deepcopy(value)
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
def initialize_config():
|
def initialize_config():
|
||||||
"""Initialize the configuration with default values."""
|
"""Initialize the configuration with default values."""
|
||||||
global _config
|
global _config
|
||||||
if _config is None:
|
if _config is None:
|
||||||
_config = default_config.DEFAULT_CONFIG.copy()
|
_config = deepcopy(default_config.DEFAULT_CONFIG)
|
||||||
|
|
||||||
|
|
||||||
def set_config(config: Dict):
|
def set_config(config: Dict):
|
||||||
"""Update the configuration with custom values."""
|
"""Update the configuration with custom values."""
|
||||||
global _config
|
global _config
|
||||||
if _config is None:
|
_config = _deep_merge_dicts(default_config.DEFAULT_CONFIG, config)
|
||||||
_config = default_config.DEFAULT_CONFIG.copy()
|
|
||||||
_config.update(config)
|
|
||||||
|
|
||||||
|
|
||||||
def get_config() -> Dict:
|
def get_config() -> Dict:
|
||||||
"""Get the current configuration."""
|
"""Get the current configuration."""
|
||||||
if _config is None:
|
if _config is None:
|
||||||
initialize_config()
|
initialize_config()
|
||||||
return _config.copy()
|
return deepcopy(_config)
|
||||||
|
|
||||||
|
|
||||||
# Initialize with default config
|
# Initialize with default config
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,16 @@ from .signal_processing import SignalProcessor
|
||||||
class TradingAgentsGraph:
|
class TradingAgentsGraph:
|
||||||
"""Main class that orchestrates the trading agents framework."""
|
"""Main class that orchestrates the trading agents framework."""
|
||||||
|
|
||||||
|
ALWAYS_ON_ROLES = {
|
||||||
|
"bull_researcher",
|
||||||
|
"bear_researcher",
|
||||||
|
"research_manager",
|
||||||
|
"trader",
|
||||||
|
"aggressive_analyst",
|
||||||
|
"neutral_analyst",
|
||||||
|
"conservative_analyst",
|
||||||
|
"portfolio_manager",
|
||||||
|
}
|
||||||
QUICK_THINKING_ROLES = {
|
QUICK_THINKING_ROLES = {
|
||||||
"market",
|
"market",
|
||||||
"social",
|
"social",
|
||||||
|
|
@ -90,7 +100,7 @@ class TradingAgentsGraph:
|
||||||
|
|
||||||
self.quick_thinking_llm = self._create_legacy_llm("quick")
|
self.quick_thinking_llm = self._create_legacy_llm("quick")
|
||||||
self.deep_thinking_llm = self._create_legacy_llm("deep")
|
self.deep_thinking_llm = self._create_legacy_llm("deep")
|
||||||
self.role_llms = self._create_role_llms()
|
self.role_llms = self._create_role_llms(selected_analysts)
|
||||||
|
|
||||||
# Initialize memories
|
# Initialize memories
|
||||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||||
|
|
@ -164,15 +174,20 @@ class TradingAgentsGraph:
|
||||||
)
|
)
|
||||||
return client.get_llm()
|
return client.get_llm()
|
||||||
|
|
||||||
def _create_role_llms(self) -> Dict[str, Any]:
|
def _create_role_llms(self, selected_analysts: List[str]) -> Dict[str, Any]:
|
||||||
role_llms = {}
|
role_llms = {}
|
||||||
for role in self.QUICK_THINKING_ROLES | self.DEEP_THINKING_ROLES:
|
for role in self._get_required_roles(selected_analysts):
|
||||||
thinker_depth = "deep" if role in self.DEEP_THINKING_ROLES else "quick"
|
thinker_depth = "deep" if role in self.DEEP_THINKING_ROLES else "quick"
|
||||||
role_llms[role] = self._create_routed_llm(role, thinker_depth)
|
llm_config = self._resolve_llm_config(role, thinker_depth)
|
||||||
|
if self._uses_legacy_llm(llm_config, thinker_depth):
|
||||||
|
continue
|
||||||
|
role_llms[role] = self._create_llm_from_config(llm_config)
|
||||||
return role_llms
|
return role_llms
|
||||||
|
|
||||||
def _create_routed_llm(self, role: str, thinker_depth: str):
|
def _get_required_roles(self, selected_analysts: List[str]) -> set[str]:
|
||||||
llm_config = self._resolve_llm_config(role, thinker_depth)
|
return self.ALWAYS_ON_ROLES | set(selected_analysts)
|
||||||
|
|
||||||
|
def _create_llm_from_config(self, llm_config: Dict[str, Any]):
|
||||||
llm_kwargs = self._get_provider_kwargs(llm_config["provider"])
|
llm_kwargs = self._get_provider_kwargs(llm_config["provider"])
|
||||||
if self.callbacks:
|
if self.callbacks:
|
||||||
llm_kwargs["callbacks"] = self.callbacks
|
llm_kwargs["callbacks"] = self.callbacks
|
||||||
|
|
@ -185,6 +200,14 @@ class TradingAgentsGraph:
|
||||||
)
|
)
|
||||||
return client.get_llm()
|
return client.get_llm()
|
||||||
|
|
||||||
|
def _uses_legacy_llm(self, llm_config: Dict[str, Any], thinker_depth: str) -> bool:
|
||||||
|
model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
|
||||||
|
return (
|
||||||
|
llm_config["provider"] == self.config["llm_provider"]
|
||||||
|
and llm_config["model"] == self.config[model_key]
|
||||||
|
and llm_config.get("base_url") == self.config.get("backend_url")
|
||||||
|
)
|
||||||
|
|
||||||
def _resolve_llm_config(
|
def _resolve_llm_config(
|
||||||
self,
|
self,
|
||||||
role: str,
|
role: str,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue