feat: add role-based llm routing

This commit is contained in:
Garrick 2026-03-24 12:49:11 -07:00
parent 4dbcdcc8a8
commit 8a4a1d1faa
6 changed files with 240 additions and 43 deletions

View File

@ -3,6 +3,7 @@ import datetime
import typer
from pathlib import Path
from functools import wraps
from copy import deepcopy
from rich.console import Console
from dotenv import load_dotenv
@ -920,13 +921,19 @@ def run_analysis():
selections = get_user_selections()
# Create config with selected research depth
config = DEFAULT_CONFIG.copy()
config = deepcopy(DEFAULT_CONFIG)
config["max_debate_rounds"] = selections["research_depth"]
config["max_risk_discuss_rounds"] = selections["research_depth"]
config["quick_think_llm"] = selections["shallow_thinker"]
config["deep_think_llm"] = selections["deep_thinker"]
config["backend_url"] = selections["backend_url"]
config["llm_provider"] = selections["llm_provider"].lower()
config["llm_routing"] = build_llm_routing_config(
provider=selections["llm_provider"],
shallow_model=selections["shallow_thinker"],
deep_model=selections["deep_thinker"],
backend_url=selections["backend_url"],
)
# Provider-specific thinking configuration
config["google_thinking_level"] = selections.get("google_thinking_level")
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")

View File

@ -1,5 +1,5 @@
import questionary
from typing import List, Optional, Tuple, Dict
from typing import Any, Dict, List, Optional, Tuple
from rich.console import Console
@ -300,6 +300,34 @@ def select_llm_provider() -> tuple[str, str]:
return display_name, url
def build_llm_routing_config(
    provider: str,
    shallow_model: str,
    deep_model: str,
    backend_url: Optional[str] = None,
) -> Dict[str, Any]:
    """Build the role-based LLM routing config from the CLI selections.

    The "default" route carries the shallow (quick-thinking) model; the two
    manager roles are routed to the deep-thinking model. Each route gets its
    own dict so later mutation of one role cannot leak into another.
    """
    normalized_provider = provider.lower()

    def make_route(model: str) -> Dict[str, Any]:
        # One independent route dict per call; base_url only when supplied.
        route: Dict[str, Any] = {
            "provider": normalized_provider,
            "model": model,
        }
        if backend_url:
            route["base_url"] = backend_url
        return route

    return {
        "default": make_route(shallow_model),
        "roles": {
            "research_manager": make_route(deep_model),
            "portfolio_manager": make_route(deep_model),
        },
    }
def ask_openai_reasoning_effort() -> str:
"""Ask for OpenAI reasoning effort level."""
choices = [

55
tests/test_llm_routing.py Normal file
View File

@ -0,0 +1,55 @@
from copy import deepcopy
from tradingagents.graph.trading_graph import TradingAgentsGraph
class DummyClient:
    """Minimal stand-in for an LLM client, recording its constructor args."""

    def __init__(self, provider, model, base_url=None, **kwargs):
        # Keep everything the factory passed so the test can inspect routing.
        self.provider = provider
        self.model = model
        self.base_url = base_url
        self.kwargs = kwargs

    def get_llm(self):
        """Return a plain dict describing the LLM this client would build."""
        return dict(
            provider=self.provider,
            model=self.model,
            base_url=self.base_url,
            kwargs=self.kwargs,
        )
def test_role_specific_llm_config_overrides_default(monkeypatch):
    """A per-role entry in llm_routing must win over the default route."""
    # Stub out every heavyweight collaborator: the LLM client factory, the
    # memory store, and graph wiring — none of them matter for routing.
    stubs = {
        "tradingagents.graph.trading_graph.create_llm_client": (
            lambda provider, model, base_url=None, **kwargs: DummyClient(
                provider, model, base_url, **kwargs
            )
        ),
        "tradingagents.graph.trading_graph.FinancialSituationMemory": (
            lambda *args, **kwargs: object()
        ),
        "tradingagents.graph.trading_graph.GraphSetup.setup_graph": (
            lambda self, selected_analysts: {"selected_analysts": selected_analysts}
        ),
    }
    for target, replacement in stubs.items():
        monkeypatch.setattr(target, replacement)

    routing_config = {
        "llm_routing": {
            "default": {"provider": "openai", "model": "gpt-5-mini"},
            "roles": {
                "portfolio_manager": {
                    "provider": "openai",
                    "model": "gpt-5.2",
                }
            },
        }
    }
    graph = TradingAgentsGraph(
        selected_analysts=["market"],
        config=deepcopy(routing_config),
    )
    # The portfolio manager should get the role-specific deep model.
    assert graph.graph_setup.portfolio_manager_llm["model"] == "gpt-5.2"

View File

@ -12,6 +12,10 @@ DEFAULT_CONFIG = {
"deep_think_llm": "gpt-5.2",
"quick_think_llm": "gpt-5-mini",
"backend_url": "https://api.openai.com/v1",
"llm_routing": {
"default": {},
"roles": {},
},
# Provider-specific thinking configuration
"google_thinking_level": None, # "high", "minimal", etc.
"openai_reasoning_effort": None, # "medium", "high", "low"

View File

@ -25,10 +25,12 @@ class GraphSetup:
invest_judge_memory,
portfolio_manager_memory,
conditional_logic: ConditionalLogic,
role_llms: Dict[str, Any] | None = None,
):
"""Initialize with required components."""
self.quick_thinking_llm = quick_thinking_llm
self.deep_thinking_llm = deep_thinking_llm
self.role_llms = role_llms or {}
self.tool_nodes = tool_nodes
self.bull_memory = bull_memory
self.bear_memory = bear_memory
@ -36,6 +38,37 @@ class GraphSetup:
self.invest_judge_memory = invest_judge_memory
self.portfolio_manager_memory = portfolio_manager_memory
self.conditional_logic = conditional_logic
self.market_analyst_llm = self._get_role_llm("market", self.quick_thinking_llm)
self.social_analyst_llm = self._get_role_llm("social", self.quick_thinking_llm)
self.news_analyst_llm = self._get_role_llm("news", self.quick_thinking_llm)
self.fundamentals_analyst_llm = self._get_role_llm(
"fundamentals", self.quick_thinking_llm
)
self.bull_researcher_llm = self._get_role_llm(
"bull_researcher", self.quick_thinking_llm
)
self.bear_researcher_llm = self._get_role_llm(
"bear_researcher", self.quick_thinking_llm
)
self.research_manager_llm = self._get_role_llm(
"research_manager", self.deep_thinking_llm
)
self.trader_llm = self._get_role_llm("trader", self.quick_thinking_llm)
self.aggressive_analyst_llm = self._get_role_llm(
"aggressive_analyst", self.quick_thinking_llm
)
self.neutral_analyst_llm = self._get_role_llm(
"neutral_analyst", self.quick_thinking_llm
)
self.conservative_analyst_llm = self._get_role_llm(
"conservative_analyst", self.quick_thinking_llm
)
self.portfolio_manager_llm = self._get_role_llm(
"portfolio_manager", self.deep_thinking_llm
)
def _get_role_llm(self, role: str, fallback_llm: ChatOpenAI):
    """Return the routed LLM for *role*, or *fallback_llm* when no route exists."""
    if role in self.role_llms:
        return self.role_llms[role]
    return fallback_llm
def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"]
@ -59,50 +92,52 @@ class GraphSetup:
if "market" in selected_analysts:
analyst_nodes["market"] = create_market_analyst(
self.quick_thinking_llm
self.market_analyst_llm
)
delete_nodes["market"] = create_msg_delete()
tool_nodes["market"] = self.tool_nodes["market"]
if "social" in selected_analysts:
analyst_nodes["social"] = create_social_media_analyst(
self.quick_thinking_llm
self.social_analyst_llm
)
delete_nodes["social"] = create_msg_delete()
tool_nodes["social"] = self.tool_nodes["social"]
if "news" in selected_analysts:
analyst_nodes["news"] = create_news_analyst(
self.quick_thinking_llm
self.news_analyst_llm
)
delete_nodes["news"] = create_msg_delete()
tool_nodes["news"] = self.tool_nodes["news"]
if "fundamentals" in selected_analysts:
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
self.quick_thinking_llm
self.fundamentals_analyst_llm
)
delete_nodes["fundamentals"] = create_msg_delete()
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
# Create researcher and manager nodes
bull_researcher_node = create_bull_researcher(
self.quick_thinking_llm, self.bull_memory
self.bull_researcher_llm, self.bull_memory
)
bear_researcher_node = create_bear_researcher(
self.quick_thinking_llm, self.bear_memory
self.bear_researcher_llm, self.bear_memory
)
research_manager_node = create_research_manager(
self.deep_thinking_llm, self.invest_judge_memory
self.research_manager_llm, self.invest_judge_memory
)
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
trader_node = create_trader(self.trader_llm, self.trader_memory)
# Create risk analysis nodes
aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm)
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
conservative_analyst = create_conservative_debator(self.quick_thinking_llm)
aggressive_analyst = create_aggressive_debator(self.aggressive_analyst_llm)
neutral_analyst = create_neutral_debator(self.neutral_analyst_llm)
conservative_analyst = create_conservative_debator(
self.conservative_analyst_llm
)
portfolio_manager_node = create_portfolio_manager(
self.deep_thinking_llm, self.portfolio_manager_memory
self.portfolio_manager_llm, self.portfolio_manager_memory
)
# Create workflow

View File

@ -1,10 +1,8 @@
# TradingAgents/graph/trading_graph.py
import os
from pathlib import Path
import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
from copy import deepcopy
from typing import Dict, Any, List, Optional
from langgraph.prebuilt import ToolNode
@ -43,6 +41,23 @@ from .signal_processing import SignalProcessor
class TradingAgentsGraph:
"""Main class that orchestrates the trading agents framework."""
# Roles that default to the quick-thinking (cheaper/faster) model when no
# explicit route is configured in llm_routing.
QUICK_THINKING_ROLES = {
"market",
"social",
"news",
"fundamentals",
"bull_researcher",
"bear_researcher",
"trader",
"aggressive_analyst",
"neutral_analyst",
"conservative_analyst",
}
# Roles that default to the deep-thinking (more capable) model.
DEEP_THINKING_ROLES = {
"research_manager",
"portfolio_manager",
}
def __init__(
self,
selected_analysts=["market", "social", "news", "fundamentals"],
@ -59,7 +74,7 @@ class TradingAgentsGraph:
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
"""
self.debug = debug
self.config = config or DEFAULT_CONFIG
self.config = self._build_config(config)
self.callbacks = callbacks or []
# Update the interface's config
@ -71,28 +86,9 @@ class TradingAgentsGraph:
exist_ok=True,
)
# Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs()
# Add callbacks to kwargs if provided (passed to LLM constructor)
if self.callbacks:
llm_kwargs["callbacks"] = self.callbacks
deep_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["deep_think_llm"],
base_url=self.config.get("backend_url"),
**llm_kwargs,
)
quick_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["quick_think_llm"],
base_url=self.config.get("backend_url"),
**llm_kwargs,
)
self.deep_thinking_llm = deep_client.get_llm()
self.quick_thinking_llm = quick_client.get_llm()
self.quick_thinking_llm = self._create_legacy_llm("quick")
self.deep_thinking_llm = self._create_legacy_llm("deep")
self.role_llms = self._create_role_llms()
# Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
@ -119,6 +115,7 @@ class TradingAgentsGraph:
self.invest_judge_memory,
self.portfolio_manager_memory,
self.conditional_logic,
role_llms=self.role_llms,
)
self.propagator = Propagator()
@ -133,10 +130,81 @@ class TradingAgentsGraph:
# Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts)
def _get_provider_kwargs(self) -> Dict[str, Any]:
def _build_config(self, config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge user config over defaults without mutating the shared defaults."""
    override = config if config is not None else {}
    return self._deep_merge_dicts(DEFAULT_CONFIG, override)
def _deep_merge_dicts(
    self,
    base: Dict[str, Any],
    override: Dict[str, Any],
) -> Dict[str, Any]:
    """Return a deep copy of *base* with *override* recursively merged on top.

    Nested dicts are merged key-by-key; any other value (including a dict
    replacing a non-dict) overwrites the base entry with a deep copy, so the
    result shares no mutable state with either input.
    """
    merged = deepcopy(base)
    for key, incoming in override.items():
        existing = merged.get(key)
        if isinstance(incoming, dict) and isinstance(existing, dict):
            merged[key] = self._deep_merge_dicts(existing, incoming)
        else:
            merged[key] = deepcopy(incoming)
    return merged
def _create_legacy_llm(self, thinker_depth: str):
    """Build the classic quick/deep LLM from the flat top-level config keys.

    Kept for backward compatibility alongside role-based routing:
    "deep" reads deep_think_llm, anything else reads quick_think_llm.
    """
    if thinker_depth == "deep":
        model_key = "deep_think_llm"
    else:
        model_key = "quick_think_llm"
    provider = self.config["llm_provider"]
    kwargs = self._get_provider_kwargs(provider)
    if self.callbacks:
        # Callbacks are passed through to the LLM constructor for tracking.
        kwargs["callbacks"] = self.callbacks
    client = create_llm_client(
        provider=provider,
        model=self.config[model_key],
        base_url=self.config.get("backend_url"),
        **kwargs,
    )
    return client.get_llm()
def _create_role_llms(self) -> Dict[str, Any]:
    """Instantiate one routed LLM per known agent role."""
    return {
        role: self._create_routed_llm(
            role,
            "deep" if role in self.DEEP_THINKING_ROLES else "quick",
        )
        for role in self.QUICK_THINKING_ROLES | self.DEEP_THINKING_ROLES
    }
def _create_routed_llm(self, role: str, thinker_depth: str):
    """Create the LLM for *role* using the resolved routing configuration."""
    resolved = self._resolve_llm_config(role, thinker_depth)
    extra_kwargs = self._get_provider_kwargs(resolved["provider"])
    if self.callbacks:
        # Forward tracking callbacks into the LLM constructor.
        extra_kwargs["callbacks"] = self.callbacks
    return create_llm_client(
        provider=resolved["provider"],
        model=resolved["model"],
        base_url=resolved.get("base_url"),
        **extra_kwargs,
    ).get_llm()
def _resolve_llm_config(
    self,
    role: str,
    thinker_depth: str,
) -> Dict[str, Any]:
    """Pick provider/model/base_url for *role* from the llm_routing config.

    Precedence: role-specific route, then the routing "default" route, then
    the flat legacy config keys (llm_provider, quick/deep_think_llm,
    backend_url). The provider is normalized to lowercase.
    """
    routing = self.config.get("llm_routing") or {}
    per_role = routing.get("roles") or {}
    route = per_role.get(role) or routing.get("default") or {}
    if thinker_depth == "deep":
        fallback_model = self.config["deep_think_llm"]
    else:
        fallback_model = self.config["quick_think_llm"]
    return {
        "provider": route.get("provider", self.config["llm_provider"]).lower(),
        "model": route.get("model", fallback_model),
        "base_url": route.get("base_url", self.config.get("backend_url")),
    }
def _get_provider_kwargs(self, provider: Optional[str] = None) -> Dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation."""
kwargs = {}
provider = self.config.get("llm_provider", "").lower()
provider = (provider or self.config.get("llm_provider", "")).lower()
if provider == "google":
thinking_level = self.config.get("google_thinking_level")