feat: add role-based llm routing

This commit is contained in:
Garrick 2026-03-24 12:49:11 -07:00
parent 4dbcdcc8a8
commit 8a4a1d1faa
6 changed files with 240 additions and 43 deletions

View File

@ -3,6 +3,7 @@ import datetime
import typer import typer
from pathlib import Path from pathlib import Path
from functools import wraps from functools import wraps
from copy import deepcopy
from rich.console import Console from rich.console import Console
from dotenv import load_dotenv from dotenv import load_dotenv
@ -920,13 +921,19 @@ def run_analysis():
selections = get_user_selections() selections = get_user_selections()
# Create config with selected research depth # Create config with selected research depth
config = DEFAULT_CONFIG.copy() config = deepcopy(DEFAULT_CONFIG)
config["max_debate_rounds"] = selections["research_depth"] config["max_debate_rounds"] = selections["research_depth"]
config["max_risk_discuss_rounds"] = selections["research_depth"] config["max_risk_discuss_rounds"] = selections["research_depth"]
config["quick_think_llm"] = selections["shallow_thinker"] config["quick_think_llm"] = selections["shallow_thinker"]
config["deep_think_llm"] = selections["deep_thinker"] config["deep_think_llm"] = selections["deep_thinker"]
config["backend_url"] = selections["backend_url"] config["backend_url"] = selections["backend_url"]
config["llm_provider"] = selections["llm_provider"].lower() config["llm_provider"] = selections["llm_provider"].lower()
config["llm_routing"] = build_llm_routing_config(
provider=selections["llm_provider"],
shallow_model=selections["shallow_thinker"],
deep_model=selections["deep_thinker"],
backend_url=selections["backend_url"],
)
# Provider-specific thinking configuration # Provider-specific thinking configuration
config["google_thinking_level"] = selections.get("google_thinking_level") config["google_thinking_level"] = selections.get("google_thinking_level")
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")

View File

@ -1,5 +1,5 @@
import questionary import questionary
from typing import List, Optional, Tuple, Dict from typing import Any, Dict, List, Optional, Tuple
from rich.console import Console from rich.console import Console
@ -300,6 +300,34 @@ def select_llm_provider() -> tuple[str, str]:
return display_name, url return display_name, url
def build_llm_routing_config(
    provider: str,
    shallow_model: str,
    deep_model: str,
    backend_url: Optional[str] = None,
    deep_roles: Sequence[str] = ("research_manager", "portfolio_manager"),
) -> Dict[str, Any]:
    """Build the ``llm_routing`` config section from the CLI selections.

    Args:
        provider: Provider display name; lowercased in the emitted routes.
        shallow_model: Model used for the ``default`` (quick-thinking) route.
        deep_model: Model used for each role listed in ``deep_roles``.
        backend_url: Optional base URL; added to every route when truthy.
        deep_roles: Roles that receive the deep-thinking route. Defaults to
            the research and portfolio managers, matching prior behavior.

    Returns:
        A dict with a ``default`` route plus per-role overrides under
        ``roles``; every route is a fresh dict (no shared state).
    """

    def _make_route(model: str) -> Dict[str, Any]:
        # Build a fresh dict per call so callers can mutate one route
        # without affecting the others.
        route: Dict[str, Any] = {
            "provider": provider.lower(),
            "model": model,
        }
        if backend_url:
            route["base_url"] = backend_url
        return route

    return {
        "default": _make_route(shallow_model),
        "roles": {role: _make_route(deep_model) for role in deep_roles},
    }
def ask_openai_reasoning_effort() -> str: def ask_openai_reasoning_effort() -> str:
"""Ask for OpenAI reasoning effort level.""" """Ask for OpenAI reasoning effort level."""
choices = [ choices = [

55
tests/test_llm_routing.py Normal file
View File

@ -0,0 +1,55 @@
from copy import deepcopy
from tradingagents.graph.trading_graph import TradingAgentsGraph
class DummyClient:
    """Stand-in LLM client that simply records its constructor arguments."""

    def __init__(self, provider, model, base_url=None, **kwargs):
        self.provider = provider
        self.model = model
        self.base_url = base_url
        self.kwargs = kwargs

    def get_llm(self):
        """Return the captured arguments as a plain dict."""
        payload = dict(
            provider=self.provider,
            model=self.model,
            base_url=self.base_url,
            kwargs=self.kwargs,
        )
        return payload
def test_role_specific_llm_config_overrides_default(monkeypatch):
    """A per-role route in llm_routing must override the default route."""

    def _fake_create_llm_client(provider, model, base_url=None, **kwargs):
        # Record the arguments instead of building a real LLM client.
        return DummyClient(provider, model, base_url, **kwargs)

    monkeypatch.setattr(
        "tradingagents.graph.trading_graph.create_llm_client",
        _fake_create_llm_client,
    )
    monkeypatch.setattr(
        "tradingagents.graph.trading_graph.FinancialSituationMemory",
        lambda *args, **kwargs: object(),
    )
    monkeypatch.setattr(
        "tradingagents.graph.trading_graph.GraphSetup.setup_graph",
        lambda self, selected_analysts: {"selected_analysts": selected_analysts},
    )

    routing = {
        "default": {"provider": "openai", "model": "gpt-5-mini"},
        "roles": {
            "portfolio_manager": {
                "provider": "openai",
                "model": "gpt-5.2",
            }
        },
    }
    graph = TradingAgentsGraph(
        selected_analysts=["market"],
        config=deepcopy({"llm_routing": routing}),
    )

    assert graph.graph_setup.portfolio_manager_llm["model"] == "gpt-5.2"

View File

@ -12,6 +12,10 @@ DEFAULT_CONFIG = {
"deep_think_llm": "gpt-5.2", "deep_think_llm": "gpt-5.2",
"quick_think_llm": "gpt-5-mini", "quick_think_llm": "gpt-5-mini",
"backend_url": "https://api.openai.com/v1", "backend_url": "https://api.openai.com/v1",
"llm_routing": {
"default": {},
"roles": {},
},
# Provider-specific thinking configuration # Provider-specific thinking configuration
"google_thinking_level": None, # "high", "minimal", etc. "google_thinking_level": None, # "high", "minimal", etc.
"openai_reasoning_effort": None, # "medium", "high", "low" "openai_reasoning_effort": None, # "medium", "high", "low"

View File

@ -25,10 +25,12 @@ class GraphSetup:
invest_judge_memory, invest_judge_memory,
portfolio_manager_memory, portfolio_manager_memory,
conditional_logic: ConditionalLogic, conditional_logic: ConditionalLogic,
role_llms: Dict[str, Any] | None = None,
): ):
"""Initialize with required components.""" """Initialize with required components."""
self.quick_thinking_llm = quick_thinking_llm self.quick_thinking_llm = quick_thinking_llm
self.deep_thinking_llm = deep_thinking_llm self.deep_thinking_llm = deep_thinking_llm
self.role_llms = role_llms or {}
self.tool_nodes = tool_nodes self.tool_nodes = tool_nodes
self.bull_memory = bull_memory self.bull_memory = bull_memory
self.bear_memory = bear_memory self.bear_memory = bear_memory
@ -36,6 +38,37 @@ class GraphSetup:
self.invest_judge_memory = invest_judge_memory self.invest_judge_memory = invest_judge_memory
self.portfolio_manager_memory = portfolio_manager_memory self.portfolio_manager_memory = portfolio_manager_memory
self.conditional_logic = conditional_logic self.conditional_logic = conditional_logic
self.market_analyst_llm = self._get_role_llm("market", self.quick_thinking_llm)
self.social_analyst_llm = self._get_role_llm("social", self.quick_thinking_llm)
self.news_analyst_llm = self._get_role_llm("news", self.quick_thinking_llm)
self.fundamentals_analyst_llm = self._get_role_llm(
"fundamentals", self.quick_thinking_llm
)
self.bull_researcher_llm = self._get_role_llm(
"bull_researcher", self.quick_thinking_llm
)
self.bear_researcher_llm = self._get_role_llm(
"bear_researcher", self.quick_thinking_llm
)
self.research_manager_llm = self._get_role_llm(
"research_manager", self.deep_thinking_llm
)
self.trader_llm = self._get_role_llm("trader", self.quick_thinking_llm)
self.aggressive_analyst_llm = self._get_role_llm(
"aggressive_analyst", self.quick_thinking_llm
)
self.neutral_analyst_llm = self._get_role_llm(
"neutral_analyst", self.quick_thinking_llm
)
self.conservative_analyst_llm = self._get_role_llm(
"conservative_analyst", self.quick_thinking_llm
)
self.portfolio_manager_llm = self._get_role_llm(
"portfolio_manager", self.deep_thinking_llm
)
def _get_role_llm(self, role: str, fallback_llm: ChatOpenAI):
    """Return the routed LLM configured for *role*, or *fallback_llm* if none."""
    if role in self.role_llms:
        return self.role_llms[role]
    return fallback_llm
def setup_graph( def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"] self, selected_analysts=["market", "social", "news", "fundamentals"]
@ -59,50 +92,52 @@ class GraphSetup:
if "market" in selected_analysts: if "market" in selected_analysts:
analyst_nodes["market"] = create_market_analyst( analyst_nodes["market"] = create_market_analyst(
self.quick_thinking_llm self.market_analyst_llm
) )
delete_nodes["market"] = create_msg_delete() delete_nodes["market"] = create_msg_delete()
tool_nodes["market"] = self.tool_nodes["market"] tool_nodes["market"] = self.tool_nodes["market"]
if "social" in selected_analysts: if "social" in selected_analysts:
analyst_nodes["social"] = create_social_media_analyst( analyst_nodes["social"] = create_social_media_analyst(
self.quick_thinking_llm self.social_analyst_llm
) )
delete_nodes["social"] = create_msg_delete() delete_nodes["social"] = create_msg_delete()
tool_nodes["social"] = self.tool_nodes["social"] tool_nodes["social"] = self.tool_nodes["social"]
if "news" in selected_analysts: if "news" in selected_analysts:
analyst_nodes["news"] = create_news_analyst( analyst_nodes["news"] = create_news_analyst(
self.quick_thinking_llm self.news_analyst_llm
) )
delete_nodes["news"] = create_msg_delete() delete_nodes["news"] = create_msg_delete()
tool_nodes["news"] = self.tool_nodes["news"] tool_nodes["news"] = self.tool_nodes["news"]
if "fundamentals" in selected_analysts: if "fundamentals" in selected_analysts:
analyst_nodes["fundamentals"] = create_fundamentals_analyst( analyst_nodes["fundamentals"] = create_fundamentals_analyst(
self.quick_thinking_llm self.fundamentals_analyst_llm
) )
delete_nodes["fundamentals"] = create_msg_delete() delete_nodes["fundamentals"] = create_msg_delete()
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
# Create researcher and manager nodes # Create researcher and manager nodes
bull_researcher_node = create_bull_researcher( bull_researcher_node = create_bull_researcher(
self.quick_thinking_llm, self.bull_memory self.bull_researcher_llm, self.bull_memory
) )
bear_researcher_node = create_bear_researcher( bear_researcher_node = create_bear_researcher(
self.quick_thinking_llm, self.bear_memory self.bear_researcher_llm, self.bear_memory
) )
research_manager_node = create_research_manager( research_manager_node = create_research_manager(
self.deep_thinking_llm, self.invest_judge_memory self.research_manager_llm, self.invest_judge_memory
) )
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) trader_node = create_trader(self.trader_llm, self.trader_memory)
# Create risk analysis nodes # Create risk analysis nodes
aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm) aggressive_analyst = create_aggressive_debator(self.aggressive_analyst_llm)
neutral_analyst = create_neutral_debator(self.quick_thinking_llm) neutral_analyst = create_neutral_debator(self.neutral_analyst_llm)
conservative_analyst = create_conservative_debator(self.quick_thinking_llm) conservative_analyst = create_conservative_debator(
self.conservative_analyst_llm
)
portfolio_manager_node = create_portfolio_manager( portfolio_manager_node = create_portfolio_manager(
self.deep_thinking_llm, self.portfolio_manager_memory self.portfolio_manager_llm, self.portfolio_manager_memory
) )
# Create workflow # Create workflow

View File

@ -1,10 +1,8 @@
# TradingAgents/graph/trading_graph.py # TradingAgents/graph/trading_graph.py
import os import os
from pathlib import Path from copy import deepcopy
import json from typing import Dict, Any, List, Optional
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
from langgraph.prebuilt import ToolNode from langgraph.prebuilt import ToolNode
@ -43,6 +41,23 @@ from .signal_processing import SignalProcessor
class TradingAgentsGraph: class TradingAgentsGraph:
"""Main class that orchestrates the trading agents framework.""" """Main class that orchestrates the trading agents framework."""
# Roles routed to the quick-thinking model by default; a per-role entry in
# config["llm_routing"]["roles"] can still override any of these
# (see _create_role_llms / _resolve_llm_config).
QUICK_THINKING_ROLES = {
"market",
"social",
"news",
"fundamentals",
"bull_researcher",
"bear_researcher",
"trader",
"aggressive_analyst",
"neutral_analyst",
"conservative_analyst",
}
# Roles routed to the deep-thinking model by default (see _create_role_llms).
DEEP_THINKING_ROLES = {
"research_manager",
"portfolio_manager",
}
def __init__( def __init__(
self, self,
selected_analysts=["market", "social", "news", "fundamentals"], selected_analysts=["market", "social", "news", "fundamentals"],
@ -59,7 +74,7 @@ class TradingAgentsGraph:
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats) callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
""" """
self.debug = debug self.debug = debug
self.config = config or DEFAULT_CONFIG self.config = self._build_config(config)
self.callbacks = callbacks or [] self.callbacks = callbacks or []
# Update the interface's config # Update the interface's config
@ -71,28 +86,9 @@ class TradingAgentsGraph:
exist_ok=True, exist_ok=True,
) )
# Initialize LLMs with provider-specific thinking configuration self.quick_thinking_llm = self._create_legacy_llm("quick")
llm_kwargs = self._get_provider_kwargs() self.deep_thinking_llm = self._create_legacy_llm("deep")
self.role_llms = self._create_role_llms()
# Add callbacks to kwargs if provided (passed to LLM constructor)
if self.callbacks:
llm_kwargs["callbacks"] = self.callbacks
deep_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["deep_think_llm"],
base_url=self.config.get("backend_url"),
**llm_kwargs,
)
quick_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["quick_think_llm"],
base_url=self.config.get("backend_url"),
**llm_kwargs,
)
self.deep_thinking_llm = deep_client.get_llm()
self.quick_thinking_llm = quick_client.get_llm()
# Initialize memories # Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config) self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
@ -119,6 +115,7 @@ class TradingAgentsGraph:
self.invest_judge_memory, self.invest_judge_memory,
self.portfolio_manager_memory, self.portfolio_manager_memory,
self.conditional_logic, self.conditional_logic,
role_llms=self.role_llms,
) )
self.propagator = Propagator() self.propagator = Propagator()
@ -133,10 +130,81 @@ class TradingAgentsGraph:
# Set up the graph # Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts) self.graph = self.graph_setup.setup_graph(selected_analysts)
def _build_config(self, config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Layer the user *config* over DEFAULT_CONFIG without mutating the defaults."""
    override = config if config is not None else {}
    return self._deep_merge_dicts(DEFAULT_CONFIG, override)
def _deep_merge_dicts(
    self,
    base: Dict[str, Any],
    override: Dict[str, Any],
) -> Dict[str, Any]:
    """Return a new dict with *override* recursively layered over *base*.

    Neither input is mutated; non-dict values from *override* win outright,
    while nested dicts are merged key by key.
    """
    merged = deepcopy(base)
    for key, incoming in override.items():
        nested = isinstance(incoming, dict) and isinstance(merged.get(key), dict)
        merged[key] = (
            self._deep_merge_dicts(merged[key], incoming)
            if nested
            else deepcopy(incoming)
        )
    return merged
def _create_legacy_llm(self, thinker_depth: str):
    """Build the quick/deep LLM from the flat legacy config keys.

    *thinker_depth* is "deep" for deep_think_llm, anything else for
    quick_think_llm.
    """
    if thinker_depth == "deep":
        model = self.config["deep_think_llm"]
    else:
        model = self.config["quick_think_llm"]
    provider = self.config["llm_provider"]
    extra_kwargs = self._get_provider_kwargs(provider)
    if self.callbacks:
        # Callbacks are passed straight to the LLM constructor.
        extra_kwargs["callbacks"] = self.callbacks
    return create_llm_client(
        provider=provider,
        model=model,
        base_url=self.config.get("backend_url"),
        **extra_kwargs,
    ).get_llm()
def _create_role_llms(self) -> Dict[str, Any]:
    """Instantiate one routed LLM for every known agent role."""
    all_roles = self.QUICK_THINKING_ROLES | self.DEEP_THINKING_ROLES
    return {
        role: self._create_routed_llm(
            role,
            "deep" if role in self.DEEP_THINKING_ROLES else "quick",
        )
        for role in all_roles
    }
def _create_routed_llm(self, role: str, thinker_depth: str):
    """Create the LLM selected by the routing table for *role*."""
    resolved = self._resolve_llm_config(role, thinker_depth)
    extra_kwargs = self._get_provider_kwargs(resolved["provider"])
    if self.callbacks:
        # Callbacks are passed straight to the LLM constructor.
        extra_kwargs["callbacks"] = self.callbacks
    client = create_llm_client(
        provider=resolved["provider"],
        model=resolved["model"],
        base_url=resolved.get("base_url"),
        **extra_kwargs,
    )
    return client.get_llm()
def _resolve_llm_config(
    self,
    role: str,
    thinker_depth: str,
) -> Dict[str, Any]:
    """Resolve the provider/model/base_url to use for *role*.

    Precedence: per-role route > routing default > flat legacy config keys
    ("llm_provider" and the quick/deep model for *thinker_depth*).

    Unlike a plain ``dict.get`` with a default, an explicit ``None`` or empty
    "provider"/"model" value in a route is treated as unset — previously a
    ``"provider": None`` entry crashed on ``None.lower()`` and a
    ``"model": None`` entry propagated ``None`` to the client.
    """
    routing = self.config.get("llm_routing") or {}
    role_routes = routing.get("roles") or {}
    route = role_routes.get(role) or routing.get("default") or {}
    model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
    provider = (route.get("provider") or self.config["llm_provider"]).lower()
    return {
        "provider": provider,
        "model": route.get("model") or self.config[model_key],
        # NOTE: an explicit base_url of None is preserved (not replaced by
        # backend_url) so a route can opt into the provider's default URL.
        "base_url": route.get("base_url", self.config.get("backend_url")),
    }
def _get_provider_kwargs(self, provider: Optional[str] = None) -> Dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation.""" """Get provider-specific kwargs for LLM client creation."""
kwargs = {} kwargs = {}
provider = self.config.get("llm_provider", "").lower() provider = (provider or self.config.get("llm_provider", "")).lower()
if provider == "google": if provider == "google":
thinking_level = self.config.get("google_thinking_level") thinking_level = self.config.get("google_thinking_level")