feat: add role-based llm routing
This commit is contained in:
parent
4dbcdcc8a8
commit
8a4a1d1faa
|
|
@ -3,6 +3,7 @@ import datetime
|
|||
import typer
|
||||
from pathlib import Path
|
||||
from functools import wraps
|
||||
from copy import deepcopy
|
||||
from rich.console import Console
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
|
@ -920,13 +921,19 @@ def run_analysis():
|
|||
selections = get_user_selections()
|
||||
|
||||
# Create config with selected research depth
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config = deepcopy(DEFAULT_CONFIG)
|
||||
config["max_debate_rounds"] = selections["research_depth"]
|
||||
config["max_risk_discuss_rounds"] = selections["research_depth"]
|
||||
config["quick_think_llm"] = selections["shallow_thinker"]
|
||||
config["deep_think_llm"] = selections["deep_thinker"]
|
||||
config["backend_url"] = selections["backend_url"]
|
||||
config["llm_provider"] = selections["llm_provider"].lower()
|
||||
config["llm_routing"] = build_llm_routing_config(
|
||||
provider=selections["llm_provider"],
|
||||
shallow_model=selections["shallow_thinker"],
|
||||
deep_model=selections["deep_thinker"],
|
||||
backend_url=selections["backend_url"],
|
||||
)
|
||||
# Provider-specific thinking configuration
|
||||
config["google_thinking_level"] = selections.get("google_thinking_level")
|
||||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||
|
|
|
|||
30
cli/utils.py
30
cli/utils.py
|
|
@ -1,5 +1,5 @@
|
|||
import questionary
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
|
|
@ -300,6 +300,34 @@ def select_llm_provider() -> tuple[str, str]:
|
|||
return display_name, url
|
||||
|
||||
|
||||
def build_llm_routing_config(
    provider: str,
    shallow_model: str,
    deep_model: str,
    backend_url: Optional[str] = None,
) -> Dict[str, Any]:
    """Build the ``llm_routing`` mapping from a provider and two model names.

    Args:
        provider: Provider name; stored lower-cased in every route.
        shallow_model: Model used for the ``"default"`` route.
        deep_model: Model used for the ``"research_manager"`` and
            ``"portfolio_manager"`` role routes.
        backend_url: When truthy, added to every route as ``"base_url"``;
            otherwise the key is omitted.

    Returns:
        A dict with a ``"default"`` route and a ``"roles"`` mapping. Each role
        gets its own route dict, so they can be edited independently.
    """
    provider_key = provider.lower()
    # Only carry a base_url when one was actually supplied.
    shared: Dict[str, Any] = {"base_url": backend_url} if backend_url else {}

    def make_route(model: str) -> Dict[str, Any]:
        # A fresh dict per call keeps the routes from aliasing one another.
        return {"provider": provider_key, "model": model, **shared}

    return {
        "default": make_route(shallow_model),
        "roles": {
            role: make_route(deep_model)
            for role in ("research_manager", "portfolio_manager")
        },
    }
|
||||
|
||||
|
||||
def ask_openai_reasoning_effort() -> str:
|
||||
"""Ask for OpenAI reasoning effort level."""
|
||||
choices = [
|
||||
|
|
|
|||
|
|
@ -0,0 +1,55 @@
|
|||
from copy import deepcopy
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
|
||||
class DummyClient:
    """Test double that records the arguments it was constructed with."""

    def __init__(self, provider, model, base_url=None, **kwargs):
        """Store the provider, model, optional base URL and any extra kwargs."""
        self.provider, self.model, self.base_url = provider, model, base_url
        self.kwargs = kwargs

    def get_llm(self):
        """Return the recorded constructor arguments as a plain dict."""
        return dict(
            provider=self.provider,
            model=self.model,
            base_url=self.base_url,
            kwargs=self.kwargs,
        )
|
||||
|
||||
|
||||
def test_role_specific_llm_config_overrides_default(monkeypatch):
    """A role route wins for that role; other roles keep the default route."""
    # Swap the client factory for DummyClient so get_llm() returns a plain
    # dict describing the provider/model/base_url that was requested.
    monkeypatch.setattr(
        "tradingagents.graph.trading_graph.create_llm_client",
        lambda provider, model, base_url=None, **kwargs: DummyClient(
            provider, model, base_url, **kwargs
        ),
    )
    # Memories are replaced by bare objects.
    monkeypatch.setattr(
        "tradingagents.graph.trading_graph.FinancialSituationMemory",
        lambda *args, **kwargs: object(),
    )
    # Graph construction is replaced by a stub that echoes the analysts.
    monkeypatch.setattr(
        "tradingagents.graph.trading_graph.GraphSetup.setup_graph",
        lambda self, selected_analysts: {"selected_analysts": selected_analysts},
    )

    config = {
        "llm_routing": {
            "default": {"provider": "openai", "model": "gpt-5-mini"},
            "roles": {
                "portfolio_manager": {
                    "provider": "openai",
                    "model": "gpt-5.2",
                }
            },
        }
    }

    graph = TradingAgentsGraph(
        selected_analysts=["market"],
        config=deepcopy(config),
    )

    # The overridden role uses its own route ...
    assert graph.graph_setup.portfolio_manager_llm["model"] == "gpt-5.2"
    # ... while a role without an entry under "roles" still resolves to the
    # default route, proving the override did not leak to every role.
    assert graph.graph_setup.market_analyst_llm["model"] == "gpt-5-mini"
|
||||
|
|
@ -12,6 +12,10 @@ DEFAULT_CONFIG = {
|
|||
"deep_think_llm": "gpt-5.2",
|
||||
"quick_think_llm": "gpt-5-mini",
|
||||
"backend_url": "https://api.openai.com/v1",
|
||||
"llm_routing": {
|
||||
"default": {},
|
||||
"roles": {},
|
||||
},
|
||||
# Provider-specific thinking configuration
|
||||
"google_thinking_level": None, # "high", "minimal", etc.
|
||||
"openai_reasoning_effort": None, # "medium", "high", "low"
|
||||
|
|
|
|||
|
|
@ -25,10 +25,12 @@ class GraphSetup:
|
|||
invest_judge_memory,
|
||||
portfolio_manager_memory,
|
||||
conditional_logic: ConditionalLogic,
|
||||
role_llms: Dict[str, Any] | None = None,
|
||||
):
|
||||
"""Initialize with required components."""
|
||||
self.quick_thinking_llm = quick_thinking_llm
|
||||
self.deep_thinking_llm = deep_thinking_llm
|
||||
self.role_llms = role_llms or {}
|
||||
self.tool_nodes = tool_nodes
|
||||
self.bull_memory = bull_memory
|
||||
self.bear_memory = bear_memory
|
||||
|
|
@ -36,6 +38,37 @@ class GraphSetup:
|
|||
self.invest_judge_memory = invest_judge_memory
|
||||
self.portfolio_manager_memory = portfolio_manager_memory
|
||||
self.conditional_logic = conditional_logic
|
||||
self.market_analyst_llm = self._get_role_llm("market", self.quick_thinking_llm)
|
||||
self.social_analyst_llm = self._get_role_llm("social", self.quick_thinking_llm)
|
||||
self.news_analyst_llm = self._get_role_llm("news", self.quick_thinking_llm)
|
||||
self.fundamentals_analyst_llm = self._get_role_llm(
|
||||
"fundamentals", self.quick_thinking_llm
|
||||
)
|
||||
self.bull_researcher_llm = self._get_role_llm(
|
||||
"bull_researcher", self.quick_thinking_llm
|
||||
)
|
||||
self.bear_researcher_llm = self._get_role_llm(
|
||||
"bear_researcher", self.quick_thinking_llm
|
||||
)
|
||||
self.research_manager_llm = self._get_role_llm(
|
||||
"research_manager", self.deep_thinking_llm
|
||||
)
|
||||
self.trader_llm = self._get_role_llm("trader", self.quick_thinking_llm)
|
||||
self.aggressive_analyst_llm = self._get_role_llm(
|
||||
"aggressive_analyst", self.quick_thinking_llm
|
||||
)
|
||||
self.neutral_analyst_llm = self._get_role_llm(
|
||||
"neutral_analyst", self.quick_thinking_llm
|
||||
)
|
||||
self.conservative_analyst_llm = self._get_role_llm(
|
||||
"conservative_analyst", self.quick_thinking_llm
|
||||
)
|
||||
self.portfolio_manager_llm = self._get_role_llm(
|
||||
"portfolio_manager", self.deep_thinking_llm
|
||||
)
|
||||
|
||||
def _get_role_llm(self, role: str, fallback_llm: ChatOpenAI):
|
||||
return self.role_llms.get(role, fallback_llm)
|
||||
|
||||
def setup_graph(
|
||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
||||
|
|
@ -59,50 +92,52 @@ class GraphSetup:
|
|||
|
||||
if "market" in selected_analysts:
|
||||
analyst_nodes["market"] = create_market_analyst(
|
||||
self.quick_thinking_llm
|
||||
self.market_analyst_llm
|
||||
)
|
||||
delete_nodes["market"] = create_msg_delete()
|
||||
tool_nodes["market"] = self.tool_nodes["market"]
|
||||
|
||||
if "social" in selected_analysts:
|
||||
analyst_nodes["social"] = create_social_media_analyst(
|
||||
self.quick_thinking_llm
|
||||
self.social_analyst_llm
|
||||
)
|
||||
delete_nodes["social"] = create_msg_delete()
|
||||
tool_nodes["social"] = self.tool_nodes["social"]
|
||||
|
||||
if "news" in selected_analysts:
|
||||
analyst_nodes["news"] = create_news_analyst(
|
||||
self.quick_thinking_llm
|
||||
self.news_analyst_llm
|
||||
)
|
||||
delete_nodes["news"] = create_msg_delete()
|
||||
tool_nodes["news"] = self.tool_nodes["news"]
|
||||
|
||||
if "fundamentals" in selected_analysts:
|
||||
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
|
||||
self.quick_thinking_llm
|
||||
self.fundamentals_analyst_llm
|
||||
)
|
||||
delete_nodes["fundamentals"] = create_msg_delete()
|
||||
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
||||
|
||||
# Create researcher and manager nodes
|
||||
bull_researcher_node = create_bull_researcher(
|
||||
self.quick_thinking_llm, self.bull_memory
|
||||
self.bull_researcher_llm, self.bull_memory
|
||||
)
|
||||
bear_researcher_node = create_bear_researcher(
|
||||
self.quick_thinking_llm, self.bear_memory
|
||||
self.bear_researcher_llm, self.bear_memory
|
||||
)
|
||||
research_manager_node = create_research_manager(
|
||||
self.deep_thinking_llm, self.invest_judge_memory
|
||||
self.research_manager_llm, self.invest_judge_memory
|
||||
)
|
||||
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
|
||||
trader_node = create_trader(self.trader_llm, self.trader_memory)
|
||||
|
||||
# Create risk analysis nodes
|
||||
aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm)
|
||||
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
|
||||
conservative_analyst = create_conservative_debator(self.quick_thinking_llm)
|
||||
aggressive_analyst = create_aggressive_debator(self.aggressive_analyst_llm)
|
||||
neutral_analyst = create_neutral_debator(self.neutral_analyst_llm)
|
||||
conservative_analyst = create_conservative_debator(
|
||||
self.conservative_analyst_llm
|
||||
)
|
||||
portfolio_manager_node = create_portfolio_manager(
|
||||
self.deep_thinking_llm, self.portfolio_manager_memory
|
||||
self.portfolio_manager_llm, self.portfolio_manager_memory
|
||||
)
|
||||
|
||||
# Create workflow
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
# TradingAgents/graph/trading_graph.py
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import json
|
||||
from datetime import date
|
||||
from typing import Dict, Any, Tuple, List, Optional
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
|
|
@ -43,6 +41,23 @@ from .signal_processing import SignalProcessor
|
|||
class TradingAgentsGraph:
|
||||
"""Main class that orchestrates the trading agents framework."""
|
||||
|
||||
QUICK_THINKING_ROLES = {
|
||||
"market",
|
||||
"social",
|
||||
"news",
|
||||
"fundamentals",
|
||||
"bull_researcher",
|
||||
"bear_researcher",
|
||||
"trader",
|
||||
"aggressive_analyst",
|
||||
"neutral_analyst",
|
||||
"conservative_analyst",
|
||||
}
|
||||
DEEP_THINKING_ROLES = {
|
||||
"research_manager",
|
||||
"portfolio_manager",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
|
|
@ -59,7 +74,7 @@ class TradingAgentsGraph:
|
|||
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
|
||||
"""
|
||||
self.debug = debug
|
||||
self.config = config or DEFAULT_CONFIG
|
||||
self.config = self._build_config(config)
|
||||
self.callbacks = callbacks or []
|
||||
|
||||
# Update the interface's config
|
||||
|
|
@ -71,28 +86,9 @@ class TradingAgentsGraph:
|
|||
exist_ok=True,
|
||||
)
|
||||
|
||||
# Initialize LLMs with provider-specific thinking configuration
|
||||
llm_kwargs = self._get_provider_kwargs()
|
||||
|
||||
# Add callbacks to kwargs if provided (passed to LLM constructor)
|
||||
if self.callbacks:
|
||||
llm_kwargs["callbacks"] = self.callbacks
|
||||
|
||||
deep_client = create_llm_client(
|
||||
provider=self.config["llm_provider"],
|
||||
model=self.config["deep_think_llm"],
|
||||
base_url=self.config.get("backend_url"),
|
||||
**llm_kwargs,
|
||||
)
|
||||
quick_client = create_llm_client(
|
||||
provider=self.config["llm_provider"],
|
||||
model=self.config["quick_think_llm"],
|
||||
base_url=self.config.get("backend_url"),
|
||||
**llm_kwargs,
|
||||
)
|
||||
|
||||
self.deep_thinking_llm = deep_client.get_llm()
|
||||
self.quick_thinking_llm = quick_client.get_llm()
|
||||
self.quick_thinking_llm = self._create_legacy_llm("quick")
|
||||
self.deep_thinking_llm = self._create_legacy_llm("deep")
|
||||
self.role_llms = self._create_role_llms()
|
||||
|
||||
# Initialize memories
|
||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||
|
|
@ -119,6 +115,7 @@ class TradingAgentsGraph:
|
|||
self.invest_judge_memory,
|
||||
self.portfolio_manager_memory,
|
||||
self.conditional_logic,
|
||||
role_llms=self.role_llms,
|
||||
)
|
||||
|
||||
self.propagator = Propagator()
|
||||
|
|
@ -133,10 +130,81 @@ class TradingAgentsGraph:
|
|||
# Set up the graph
|
||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||
|
||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||
def _build_config(self, config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return ``DEFAULT_CONFIG`` with ``config`` deep-merged on top of it.

    A falsy ``config`` (``None`` or an empty dict) is treated as an empty
    override. The merge is delegated to ``_deep_merge_dicts``.
    """
    overrides: Dict[str, Any] = config if config else {}
    return self._deep_merge_dicts(DEFAULT_CONFIG, overrides)
|
||||
|
||||
def _deep_merge_dicts(
|
||||
self,
|
||||
base: Dict[str, Any],
|
||||
override: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
merged = deepcopy(base)
|
||||
for key, value in override.items():
|
||||
if isinstance(value, dict) and isinstance(merged.get(key), dict):
|
||||
merged[key] = self._deep_merge_dicts(merged[key], value)
|
||||
else:
|
||||
merged[key] = deepcopy(value)
|
||||
return merged
|
||||
|
||||
def _create_legacy_llm(self, thinker_depth: str):
    """Create an LLM from the flat (non-routed) config keys.

    Args:
        thinker_depth: ``"deep"`` selects the ``deep_think_llm`` model; any
            other value selects ``quick_think_llm``.

    Returns:
        Whatever ``get_llm()`` returns on the client built by
        ``create_llm_client`` for ``llm_provider`` and ``backend_url``.
    """
    model_key = {"deep": "deep_think_llm"}.get(thinker_depth, "quick_think_llm")
    provider = self.config["llm_provider"]

    client_kwargs = self._get_provider_kwargs(provider)
    if self.callbacks:
        # Callbacks are only forwarded when at least one was supplied.
        client_kwargs.update(callbacks=self.callbacks)

    return create_llm_client(
        provider=provider,
        model=self.config[model_key],
        base_url=self.config.get("backend_url"),
        **client_kwargs,
    ).get_llm()
|
||||
|
||||
def _create_role_llms(self) -> Dict[str, Any]:
|
||||
role_llms = {}
|
||||
for role in self.QUICK_THINKING_ROLES | self.DEEP_THINKING_ROLES:
|
||||
thinker_depth = "deep" if role in self.DEEP_THINKING_ROLES else "quick"
|
||||
role_llms[role] = self._create_routed_llm(role, thinker_depth)
|
||||
return role_llms
|
||||
|
||||
def _create_routed_llm(self, role: str, thinker_depth: str):
    """Create the LLM for ``role`` from its resolved routing entry.

    Args:
        role: Role name passed to ``_resolve_llm_config``.
        thinker_depth: Depth hint passed to ``_resolve_llm_config``.

    Returns:
        Whatever ``get_llm()`` returns on the client built by
        ``create_llm_client`` for the resolved provider, model and base URL.
    """
    resolved = self._resolve_llm_config(role, thinker_depth)
    routed_provider = resolved["provider"]

    # Provider-specific kwargs follow the routed provider.
    client_kwargs = self._get_provider_kwargs(routed_provider)
    if self.callbacks:
        client_kwargs.update(callbacks=self.callbacks)

    return create_llm_client(
        provider=routed_provider,
        model=resolved["model"],
        base_url=resolved.get("base_url"),
        **client_kwargs,
    ).get_llm()
|
||||
|
||||
def _resolve_llm_config(
|
||||
self,
|
||||
role: str,
|
||||
thinker_depth: str,
|
||||
) -> Dict[str, Any]:
|
||||
routing = self.config.get("llm_routing") or {}
|
||||
role_routes = routing.get("roles") or {}
|
||||
route = role_routes.get(role) or routing.get("default") or {}
|
||||
|
||||
model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm"
|
||||
provider = route.get("provider", self.config["llm_provider"]).lower()
|
||||
|
||||
return {
|
||||
"provider": provider,
|
||||
"model": route.get("model", self.config[model_key]),
|
||||
"base_url": route.get("base_url", self.config.get("backend_url")),
|
||||
}
|
||||
|
||||
def _get_provider_kwargs(self, provider: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get provider-specific kwargs for LLM client creation."""
|
||||
kwargs = {}
|
||||
provider = self.config.get("llm_provider", "").lower()
|
||||
provider = (provider or self.config.get("llm_provider", "")).lower()
|
||||
|
||||
if provider == "google":
|
||||
thinking_level = self.config.get("google_thinking_level")
|
||||
|
|
|
|||
Loading…
Reference in New Issue