diff --git a/cli/main.py b/cli/main.py index 53837db2..c72e92e5 100644 --- a/cli/main.py +++ b/cli/main.py @@ -3,6 +3,7 @@ import datetime import typer from pathlib import Path from functools import wraps +from copy import deepcopy from rich.console import Console from dotenv import load_dotenv @@ -920,13 +921,19 @@ def run_analysis(): selections = get_user_selections() # Create config with selected research depth - config = DEFAULT_CONFIG.copy() + config = deepcopy(DEFAULT_CONFIG) config["max_debate_rounds"] = selections["research_depth"] config["max_risk_discuss_rounds"] = selections["research_depth"] config["quick_think_llm"] = selections["shallow_thinker"] config["deep_think_llm"] = selections["deep_thinker"] config["backend_url"] = selections["backend_url"] config["llm_provider"] = selections["llm_provider"].lower() + config["llm_routing"] = build_llm_routing_config( + provider=selections["llm_provider"], + shallow_model=selections["shallow_thinker"], + deep_model=selections["deep_thinker"], + backend_url=selections["backend_url"], + ) # Provider-specific thinking configuration config["google_thinking_level"] = selections.get("google_thinking_level") config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") diff --git a/cli/utils.py b/cli/utils.py index 18abc3a7..3cd418c2 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,5 +1,5 @@ import questionary -from typing import List, Optional, Tuple, Dict +from typing import Any, Dict, List, Optional, Tuple from rich.console import Console @@ -300,6 +300,34 @@ def select_llm_provider() -> tuple[str, str]: return display_name, url +def build_llm_routing_config( + provider: str, + shallow_model: str, + deep_model: str, + backend_url: Optional[str] = None, +) -> Dict[str, Any]: + default_route: Dict[str, Any] = { + "provider": provider.lower(), + "model": shallow_model, + } + deep_route: Dict[str, Any] = { + "provider": provider.lower(), + "model": deep_model, + } + + if backend_url: + default_route["base_url"] = backend_url + deep_route["base_url"] = backend_url + + return { + "default": default_route, + "roles": { + "research_manager": deep_route.copy(), + "portfolio_manager": deep_route.copy(), + }, + } + + def ask_openai_reasoning_effort() -> str: """Ask for OpenAI reasoning effort level.""" choices = [ diff --git a/tests/test_llm_routing.py b/tests/test_llm_routing.py new file mode 100644 index 00000000..58ec406b --- /dev/null +++ b/tests/test_llm_routing.py @@ -0,0 +1,55 @@ +from copy import deepcopy + +from tradingagents.graph.trading_graph import TradingAgentsGraph + + +class DummyClient: + def __init__(self, provider, model, base_url=None, **kwargs): + self.provider = provider + self.model = model + self.base_url = base_url + self.kwargs = kwargs + + def get_llm(self): + return { + "provider": self.provider, + "model": self.model, + "base_url": self.base_url, + "kwargs": self.kwargs, + } + + +def test_role_specific_llm_config_overrides_default(monkeypatch): + monkeypatch.setattr( + "tradingagents.graph.trading_graph.create_llm_client", + lambda provider, model, base_url=None, **kwargs: DummyClient( + provider, model, base_url, **kwargs + ), + ) + monkeypatch.setattr( + "tradingagents.graph.trading_graph.FinancialSituationMemory", + lambda *args, **kwargs: object(), + ) + monkeypatch.setattr( + "tradingagents.graph.trading_graph.GraphSetup.setup_graph", + lambda self, selected_analysts: {"selected_analysts": selected_analysts}, + ) + + config = { + "llm_routing": { + "default": {"provider": "openai", "model": "gpt-5-mini"}, + "roles": { + "portfolio_manager": { + "provider": "openai", + "model": "gpt-5.2", + } + }, + } + } + + graph = TradingAgentsGraph( + selected_analysts=["market"], + config=deepcopy(config), + ) + + assert graph.graph_setup.portfolio_manager_llm["model"] == "gpt-5.2" diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 898e1e1e..2a3e7f2e 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -12,6 +12,10 @@ DEFAULT_CONFIG = { "deep_think_llm": "gpt-5.2", "quick_think_llm": "gpt-5-mini", "backend_url": "https://api.openai.com/v1", + "llm_routing": { + "default": {}, + "roles": {}, + }, # Provider-specific thinking configuration "google_thinking_level": None, # "high", "minimal", etc. "openai_reasoning_effort": None, # "medium", "high", "low" diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index e0771c65..2388f7c9 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -25,10 +25,12 @@ class GraphSetup: invest_judge_memory, portfolio_manager_memory, conditional_logic: ConditionalLogic, + role_llms: Dict[str, Any] | None = None, ): """Initialize with required components.""" self.quick_thinking_llm = quick_thinking_llm self.deep_thinking_llm = deep_thinking_llm + self.role_llms = role_llms or {} self.tool_nodes = tool_nodes self.bull_memory = bull_memory self.bear_memory = bear_memory @@ -36,6 +38,37 @@ class GraphSetup: self.invest_judge_memory = invest_judge_memory self.portfolio_manager_memory = portfolio_manager_memory self.conditional_logic = conditional_logic + self.market_analyst_llm = self._get_role_llm("market", self.quick_thinking_llm) + self.social_analyst_llm = self._get_role_llm("social", self.quick_thinking_llm) + self.news_analyst_llm = self._get_role_llm("news", self.quick_thinking_llm) + self.fundamentals_analyst_llm = self._get_role_llm( + "fundamentals", self.quick_thinking_llm + ) + self.bull_researcher_llm = self._get_role_llm( + "bull_researcher", self.quick_thinking_llm + ) + self.bear_researcher_llm = self._get_role_llm( + "bear_researcher", self.quick_thinking_llm + ) + self.research_manager_llm = self._get_role_llm( + "research_manager", self.deep_thinking_llm + ) + self.trader_llm = self._get_role_llm("trader", self.quick_thinking_llm) + self.aggressive_analyst_llm = self._get_role_llm( + "aggressive_analyst", self.quick_thinking_llm + ) + self.neutral_analyst_llm = self._get_role_llm( + "neutral_analyst", self.quick_thinking_llm + ) + self.conservative_analyst_llm = self._get_role_llm( + "conservative_analyst", self.quick_thinking_llm + ) + self.portfolio_manager_llm = self._get_role_llm( + "portfolio_manager", self.deep_thinking_llm + ) + + def _get_role_llm(self, role: str, fallback_llm: ChatOpenAI): + return self.role_llms.get(role, fallback_llm) def setup_graph( self, selected_analysts=["market", "social", "news", "fundamentals"] @@ -59,50 +92,52 @@ class GraphSetup: if "market" in selected_analysts: analyst_nodes["market"] = create_market_analyst( - self.quick_thinking_llm + self.market_analyst_llm ) delete_nodes["market"] = create_msg_delete() tool_nodes["market"] = self.tool_nodes["market"] if "social" in selected_analysts: analyst_nodes["social"] = create_social_media_analyst( - self.quick_thinking_llm + self.social_analyst_llm ) delete_nodes["social"] = create_msg_delete() tool_nodes["social"] = self.tool_nodes["social"] if "news" in selected_analysts: analyst_nodes["news"] = create_news_analyst( - self.quick_thinking_llm + self.news_analyst_llm ) delete_nodes["news"] = create_msg_delete() tool_nodes["news"] = self.tool_nodes["news"] if "fundamentals" in selected_analysts: analyst_nodes["fundamentals"] = create_fundamentals_analyst( - self.quick_thinking_llm + self.fundamentals_analyst_llm ) delete_nodes["fundamentals"] = create_msg_delete() tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] # Create researcher and manager nodes bull_researcher_node = create_bull_researcher( - self.quick_thinking_llm, self.bull_memory + self.bull_researcher_llm, self.bull_memory ) bear_researcher_node = create_bear_researcher( - self.quick_thinking_llm, self.bear_memory + self.bear_researcher_llm, self.bear_memory ) research_manager_node = create_research_manager( - self.deep_thinking_llm, self.invest_judge_memory + self.research_manager_llm, self.invest_judge_memory ) - trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) + trader_node = create_trader(self.trader_llm, self.trader_memory) # Create risk analysis nodes - aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm) - neutral_analyst = create_neutral_debator(self.quick_thinking_llm) - conservative_analyst = create_conservative_debator(self.quick_thinking_llm) + aggressive_analyst = create_aggressive_debator(self.aggressive_analyst_llm) + neutral_analyst = create_neutral_debator(self.neutral_analyst_llm) + conservative_analyst = create_conservative_debator( + self.conservative_analyst_llm + ) portfolio_manager_node = create_portfolio_manager( - self.deep_thinking_llm, self.portfolio_manager_memory + self.portfolio_manager_llm, self.portfolio_manager_memory ) # Create workflow diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index c8cd7492..ca6f10b4 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -1,10 +1,8 @@ # TradingAgents/graph/trading_graph.py import os -from pathlib import Path -import json -from datetime import date -from typing import Dict, Any, Tuple, List, Optional +from copy import deepcopy +from typing import Dict, Any, List, Optional from langgraph.prebuilt import ToolNode @@ -43,6 +41,23 @@ from .signal_processing import SignalProcessor class TradingAgentsGraph: """Main class that orchestrates the trading agents framework.""" + QUICK_THINKING_ROLES = { + "market", + "social", + "news", + "fundamentals", + "bull_researcher", + "bear_researcher", + "trader", + "aggressive_analyst", + "neutral_analyst", + "conservative_analyst", + } + DEEP_THINKING_ROLES = { + "research_manager", + "portfolio_manager", + } + def __init__( self, selected_analysts=["market", "social", "news", "fundamentals"], @@ -59,7 +74,7 @@ class TradingAgentsGraph: callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats) """ self.debug = debug - self.config = config or DEFAULT_CONFIG + self.config = self._build_config(config) self.callbacks = callbacks or [] # Update the interface's config @@ -71,28 +86,9 @@ class TradingAgentsGraph: exist_ok=True, ) - # Initialize LLMs with provider-specific thinking configuration - llm_kwargs = self._get_provider_kwargs() - - # Add callbacks to kwargs if provided (passed to LLM constructor) - if self.callbacks: - llm_kwargs["callbacks"] = self.callbacks - - deep_client = create_llm_client( - provider=self.config["llm_provider"], - model=self.config["deep_think_llm"], - base_url=self.config.get("backend_url"), - **llm_kwargs, - ) - quick_client = create_llm_client( - provider=self.config["llm_provider"], - model=self.config["quick_think_llm"], - base_url=self.config.get("backend_url"), - **llm_kwargs, - ) - - self.deep_thinking_llm = deep_client.get_llm() - self.quick_thinking_llm = quick_client.get_llm() + self.quick_thinking_llm = self._create_legacy_llm("quick") + self.deep_thinking_llm = self._create_legacy_llm("deep") + self.role_llms = self._create_role_llms() # Initialize memories self.bull_memory = FinancialSituationMemory("bull_memory", self.config) @@ -119,6 +115,7 @@ class TradingAgentsGraph: self.invest_judge_memory, self.portfolio_manager_memory, self.conditional_logic, + role_llms=self.role_llms, ) self.propagator = Propagator() @@ -133,10 +130,81 @@ class TradingAgentsGraph: # Set up the graph self.graph = self.graph_setup.setup_graph(selected_analysts) - def _get_provider_kwargs(self) -> Dict[str, Any]: + def _build_config(self, config: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """Merge user config over defaults without mutating the shared defaults.""" + return self._deep_merge_dicts(DEFAULT_CONFIG, config or {}) + + def _deep_merge_dicts( + self, + base: Dict[str, Any], + override: Dict[str, Any], + ) -> Dict[str, Any]: + merged = deepcopy(base) + for key, value in override.items(): + if isinstance(value, dict) and isinstance(merged.get(key), dict): + merged[key] = self._deep_merge_dicts(merged[key], value) + else: + merged[key] = deepcopy(value) + return merged + + def _create_legacy_llm(self, thinker_depth: str): + model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm" + provider = self.config["llm_provider"] + llm_kwargs = self._get_provider_kwargs(provider) + if self.callbacks: + llm_kwargs["callbacks"] = self.callbacks + + client = create_llm_client( + provider=provider, + model=self.config[model_key], + base_url=self.config.get("backend_url"), + **llm_kwargs, + ) + return client.get_llm() + + def _create_role_llms(self) -> Dict[str, Any]: + role_llms = {} + for role in self.QUICK_THINKING_ROLES | self.DEEP_THINKING_ROLES: + thinker_depth = "deep" if role in self.DEEP_THINKING_ROLES else "quick" + role_llms[role] = self._create_routed_llm(role, thinker_depth) + return role_llms + + def _create_routed_llm(self, role: str, thinker_depth: str): + llm_config = self._resolve_llm_config(role, thinker_depth) + llm_kwargs = self._get_provider_kwargs(llm_config["provider"]) + if self.callbacks: + llm_kwargs["callbacks"] = self.callbacks + + client = create_llm_client( + provider=llm_config["provider"], + model=llm_config["model"], + base_url=llm_config.get("base_url"), + **llm_kwargs, + ) + return client.get_llm() + + def _resolve_llm_config( + self, + role: str, + thinker_depth: str, + ) -> Dict[str, Any]: + routing = self.config.get("llm_routing") or {} + role_routes = routing.get("roles") or {} + route = role_routes.get(role) or routing.get("default") or {} + + model_key = "deep_think_llm" if thinker_depth == "deep" else "quick_think_llm" + provider = route.get("provider", self.config["llm_provider"]).lower() + + return { + "provider": provider, + "model": route.get("model", self.config[model_key]), + "base_url": route.get("base_url", self.config.get("backend_url")), + } + + def _get_provider_kwargs(self, provider: Optional[str] = None) -> Dict[str, Any]: """Get provider-specific kwargs for LLM client creation.""" kwargs = {} - provider = self.config.get("llm_provider", "").lower() + provider = (provider or self.config.get("llm_provider", "")).lower() if provider == "google": thinking_level = self.config.get("google_thinking_level")