diff --git a/cli/main.py b/cli/main.py index fb97d189..1648efba 100644 --- a/cli/main.py +++ b/cli/main.py @@ -459,10 +459,20 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non layout["footer"].update(Panel(stats_table, border_style="grey50")) +def _ask_provider_thinking_config(provider: str): + """Ask for provider-specific thinking config. Returns (thinking_level, reasoning_effort).""" + provider_lower = provider.lower() + if provider_lower == "google": + return ask_gemini_thinking_config(), None + elif provider_lower in ("openai", "xai"): + return None, ask_openai_reasoning_effort() + return None, None + + def get_user_selections(): """Get all user selections before starting the analysis display.""" # Display ASCII art welcome message - with open("./cli/static/welcome.txt", "r") as f: + with open("./cli/static/welcome.txt", "r", encoding="utf-8") as f: welcome_ascii = f.read() # Create welcome box content @@ -536,83 +546,65 @@ def get_user_selections(): ) selected_research_depth = select_research_depth() - # Step 5: OpenAI backend + # Step 5: Quick-thinking provider + model console.print( create_question_box( - "Step 5: OpenAI backend", "Select which service to talk to" + "Step 5: Quick-Thinking Setup", + "Provider and model for analysts & risk debaters (fast, high volume)" ) ) - selected_llm_provider, backend_url = select_llm_provider() - - # Step 6: Thinking agents + quick_provider, quick_backend_url = select_llm_provider() + selected_shallow_thinker = select_shallow_thinking_agent(quick_provider) + quick_thinking_level, quick_reasoning_effort = _ask_provider_thinking_config(quick_provider) + + # Step 6: Mid-thinking provider + model console.print( create_question_box( - "Step 6: Thinking Agents", "Select your thinking agents for analysis" + "Step 6: Mid-Thinking Setup", + "Provider and model for researchers & trader (reasoning, argument formation)" ) ) - selected_shallow_thinker = 
select_shallow_thinking_agent(selected_llm_provider) - selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) + mid_provider, mid_backend_url = select_llm_provider() + selected_mid_thinker = select_mid_thinking_agent(mid_provider) + mid_thinking_level, mid_reasoning_effort = _ask_provider_thinking_config(mid_provider) - # Step 7: Provider-specific thinking configuration - thinking_level = None - reasoning_effort = None - - provider_lower = selected_llm_provider.lower() - if provider_lower == "google": - console.print( - create_question_box( - "Step 7: Thinking Mode", - "Configure Gemini thinking mode" - ) + # Step 7: Deep-thinking provider + model + console.print( + create_question_box( + "Step 7: Deep-Thinking Setup", + "Provider and model for investment judge & risk manager (final decisions)" ) - thinking_level = ask_gemini_thinking_config() - elif provider_lower == "openai": - console.print( - create_question_box( - "Step 7: Reasoning Effort", - "Configure OpenAI reasoning effort level" - ) - ) - reasoning_effort = ask_openai_reasoning_effort() + ) + deep_provider, deep_backend_url = select_llm_provider() + selected_deep_thinker = select_deep_thinking_agent(deep_provider) + deep_thinking_level, deep_reasoning_effort = _ask_provider_thinking_config(deep_provider) return { "ticker": selected_ticker, "analysis_date": analysis_date, "analysts": selected_analysts, "research_depth": selected_research_depth, - "llm_provider": selected_llm_provider.lower(), - "backend_url": backend_url, + # Quick + "quick_provider": quick_provider.lower(), + "quick_backend_url": quick_backend_url, "shallow_thinker": selected_shallow_thinker, + "quick_thinking_level": quick_thinking_level, + "quick_reasoning_effort": quick_reasoning_effort, + # Mid + "mid_provider": mid_provider.lower(), + "mid_backend_url": mid_backend_url, + "mid_thinker": selected_mid_thinker, + "mid_thinking_level": mid_thinking_level, + "mid_reasoning_effort": mid_reasoning_effort, + # Deep + 
"deep_provider": deep_provider.lower(), + "deep_backend_url": deep_backend_url, "deep_thinker": selected_deep_thinker, - "google_thinking_level": thinking_level, - "openai_reasoning_effort": reasoning_effort, + "deep_thinking_level": deep_thinking_level, + "deep_reasoning_effort": deep_reasoning_effort, } -def get_ticker(): - """Get ticker symbol from user input.""" - return typer.prompt("", default="SPY") - - -def get_analysis_date(): - """Get the analysis date from user input.""" - while True: - date_str = typer.prompt( - "", default=datetime.datetime.now().strftime("%Y-%m-%d") - ) - try: - # Validate date format and ensure it's not in the future - analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d") - if analysis_date.date() > datetime.datetime.now().date(): - console.print("[red]Error: Analysis date cannot be in the future[/red]") - continue - return date_str - except ValueError: - console.print( - "[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]" - ) - - def save_report_to_disk(final_state, ticker: str, save_path: Path): """Save complete analysis report to disk with organized subfolders.""" save_path.mkdir(parents=True, exist_ok=True) @@ -904,13 +896,25 @@ def run_analysis(): config = DEFAULT_CONFIG.copy() config["max_debate_rounds"] = selections["research_depth"] config["max_risk_discuss_rounds"] = selections["research_depth"] + # Per-role LLM configuration config["quick_think_llm"] = selections["shallow_thinker"] + config["quick_think_llm_provider"] = selections["quick_provider"] + config["quick_think_backend_url"] = selections["quick_backend_url"] + config["quick_think_google_thinking_level"] = selections.get("quick_thinking_level") + config["quick_think_openai_reasoning_effort"] = selections.get("quick_reasoning_effort") + config["mid_think_llm"] = selections["mid_thinker"] + config["mid_think_llm_provider"] = selections["mid_provider"] + config["mid_think_backend_url"] = selections["mid_backend_url"] + 
config["mid_think_google_thinking_level"] = selections.get("mid_thinking_level") + config["mid_think_openai_reasoning_effort"] = selections.get("mid_reasoning_effort") config["deep_think_llm"] = selections["deep_thinker"] - config["backend_url"] = selections["backend_url"] - config["llm_provider"] = selections["llm_provider"].lower() - # Provider-specific thinking configuration - config["google_thinking_level"] = selections.get("google_thinking_level") - config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") + config["deep_think_llm_provider"] = selections["deep_provider"] + config["deep_think_backend_url"] = selections["deep_backend_url"] + config["deep_think_google_thinking_level"] = selections.get("deep_thinking_level") + config["deep_think_openai_reasoning_effort"] = selections.get("deep_reasoning_effort") + # Keep shared llm_provider/backend_url as a fallback (use quick as default) + config["llm_provider"] = selections["quick_provider"] + config["backend_url"] = selections["quick_backend_url"] # Create stats callback handler for tracking LLM/tool calls stats_handler = StatsCallbackHandler() @@ -948,10 +952,10 @@ def run_analysis(): func(*args, **kwargs) timestamp, message_type, content = obj.messages[-1] content = content.replace("\n", " ") # Replace newlines with spaces - with open(log_file, "a") as f: + with open(log_file, "a", encoding="utf-8") as f: f.write(f"{timestamp} [{message_type}] {content}\n") return wrapper - + def save_tool_call_decorator(obj, func_name): func = getattr(obj, func_name) @wraps(func) @@ -959,7 +963,7 @@ def run_analysis(): func(*args, **kwargs) timestamp, tool_name, args = obj.tool_calls[-1] args_str = ", ".join(f"{k}={v}" for k, v in args.items()) - with open(log_file, "a") as f: + with open(log_file, "a", encoding="utf-8") as f: f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n") return wrapper @@ -972,7 +976,7 @@ def run_analysis(): content = obj.report_sections[section_name] if content: 
file_name = f"{section_name}.md" - with open(report_dir / file_name, "w") as f: + with open(report_dir / file_name, "w", encoding="utf-8") as f: f.write(content) return wrapper diff --git a/cli/utils.py b/cli/utils.py index aa097fb5..f02febfd 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,8 +1,24 @@ import questionary +import requests from typing import List, Optional, Tuple, Dict +from rich.console import Console from cli.models import AnalystType +console = Console() + + +def _fetch_ollama_models(base_url: str = "http://localhost:11434") -> list[tuple[str, str]]: + """Fetch available models from a running Ollama instance.""" + try: + resp = requests.get(f"{base_url}/api/tags", timeout=5) + resp.raise_for_status() + models = resp.json().get("models", []) + return [(m["name"], m["name"]) for m in models] if models else [] + except Exception: + return [] + + ANALYST_ORDER = [ ("Market Analyst", AnalystType.MARKET), ("Social Media Analyst", AnalystType.SOCIAL), @@ -125,48 +141,56 @@ def select_research_depth() -> int: def select_shallow_thinking_agent(provider) -> str: """Select shallow thinking llm engine using an interactive selection.""" - # Define shallow thinking llm engine options with their corresponding model names - SHALLOW_AGENT_OPTIONS = { - "openai": [ - ("GPT-5 Mini - Cost-optimized reasoning", "gpt-5-mini"), - ("GPT-5 Nano - Ultra-fast, high-throughput", "gpt-5-nano"), - ("GPT-5.2 - Latest flagship", "gpt-5.2"), - ("GPT-5.1 - Flexible reasoning", "gpt-5.1"), - ("GPT-4.1 - Smartest non-reasoning, 1M context", "gpt-4.1"), - ], - "anthropic": [ - ("Claude Haiku 4.5 - Fast + extended thinking", "claude-haiku-4-5"), - ("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"), - ("Claude Sonnet 4 - High-performance", "claude-sonnet-4-20250514"), - ], - "google": [ - ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), - ("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"), - ("Gemini 3 Pro - Reasoning-first", 
"gemini-3-pro-preview"), - ("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"), - ], - "xai": [ - ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), - ("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"), - ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), - ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), - ], - "openrouter": [ - ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), - ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), - ], - "ollama": [ - ("Qwen3:latest (8B, local)", "qwen3:latest"), - ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), - ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), - ], - } + _provider = provider.lower() + + if _provider == "ollama": + ollama_models = _fetch_ollama_models() + if not ollama_models: + console.print("[yellow]Could not reach Ollama — is it running? Enter a model name manually.[/yellow]") + model = questionary.text("Model name (e.g. qwen3.5:9b):").ask() + if not model: + console.print("\n[red]No model entered. 
Exiting...[/red]") + exit(1) + return model.strip() + options = ollama_models + else: + SHALLOW_AGENT_OPTIONS = { + "openai": [ + ("GPT-5 Mini - Cost-optimized reasoning", "gpt-5-mini"), + ("GPT-5 Nano - Ultra-fast, high-throughput", "gpt-5-nano"), + ("GPT-5.2 - Latest flagship", "gpt-5.2"), + ("GPT-5.1 - Flexible reasoning", "gpt-5.1"), + ("GPT-4.1 - Smartest non-reasoning, 1M context", "gpt-4.1"), + ], + "anthropic": [ + ("Claude Haiku 4.5 - Fast + extended thinking", "claude-haiku-4-5"), + ("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"), + ("Claude Sonnet 4 - High-performance", "claude-sonnet-4-20250514"), + ], + "google": [ + ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), + ("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"), + ("Gemini 3 Pro - Reasoning-first", "gemini-3-pro-preview"), + ("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"), + ], + "xai": [ + ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), + ("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"), + ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), + ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), + ], + "openrouter": [ + ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), + ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), + ], + } + options = SHALLOW_AGENT_OPTIONS[_provider] choice = questionary.select( "Select Your [Quick-Thinking LLM Engine]:", choices=[ questionary.Choice(display, value=value) - for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()] + for display, value in options ], instruction="\n- Use arrow keys to navigate\n- Press Enter to select", style=questionary.Style( @@ -187,54 +211,132 @@ def select_shallow_thinking_agent(provider) -> str: return choice +def select_mid_thinking_agent(provider) -> str: + """Select mid thinking llm engine using an 
interactive selection.""" + + _provider = provider.lower() + + if _provider == "ollama": + ollama_models = _fetch_ollama_models() + if not ollama_models: + console.print("[yellow]Could not reach Ollama — is it running? Enter a model name manually.[/yellow]") + model = questionary.text("Model name (e.g. qwen3.5:27b):").ask() + if not model: + console.print("\n[red]No model entered. Exiting...[/red]") + exit(1) + return model.strip() + options = ollama_models + else: + MID_AGENT_OPTIONS = { + "openai": [ + ("GPT-5.1 - Flexible reasoning", "gpt-5.1"), + ("GPT-5 - Advanced reasoning", "gpt-5"), + ("GPT-4.1 - Smartest non-reasoning, 1M context", "gpt-4.1"), + ("GPT-5 Mini - Cost-optimized reasoning", "gpt-5-mini"), + ], + "anthropic": [ + ("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"), + ("Claude Sonnet 4 - High-performance", "claude-sonnet-4-20250514"), + ("Claude Haiku 4.5 - Fast + extended thinking", "claude-haiku-4-5"), + ], + "google": [ + ("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"), + ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), + ("Gemini 3 Pro - Reasoning-first", "gemini-3-pro-preview"), + ], + "xai": [ + ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), + ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), + ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), + ], + "openrouter": [ + ("DeepSeek R1 - Strong open-source reasoning", "deepseek/deepseek-r1"), + ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), + ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), + ], + } + options = MID_AGENT_OPTIONS[_provider] + + choice = questionary.select( + "Select Your [Mid-Thinking LLM Engine]:", + choices=[ + questionary.Choice(display, value=value) + for display, value in options + ], + instruction="\n- Use arrow keys to navigate\n- Press Enter to select", + style=questionary.Style( + [ + 
("selected", "fg:magenta noinherit"), + ("highlighted", "fg:magenta noinherit"), + ("pointer", "fg:magenta noinherit"), + ] + ), + ).ask() + + if choice is None: + console.print("\n[red]No mid thinking llm engine selected. Exiting...[/red]") + exit(1) + + return choice + + def select_deep_thinking_agent(provider) -> str: """Select deep thinking llm engine using an interactive selection.""" - # Define deep thinking llm engine options with their corresponding model names - DEEP_AGENT_OPTIONS = { - "openai": [ - ("GPT-5.2 - Latest flagship", "gpt-5.2"), - ("GPT-5.1 - Flexible reasoning", "gpt-5.1"), - ("GPT-5 - Advanced reasoning", "gpt-5"), - ("GPT-4.1 - Smartest non-reasoning, 1M context", "gpt-4.1"), - ("GPT-5 Mini - Cost-optimized reasoning", "gpt-5-mini"), - ("GPT-5 Nano - Ultra-fast, high-throughput", "gpt-5-nano"), - ], - "anthropic": [ - ("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"), - ("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"), - ("Claude Opus 4.1 - Most capable model", "claude-opus-4-1-20250805"), - ("Claude Haiku 4.5 - Fast + extended thinking", "claude-haiku-4-5"), - ("Claude Sonnet 4 - High-performance", "claude-sonnet-4-20250514"), - ], - "google": [ - ("Gemini 3 Pro - Reasoning-first", "gemini-3-pro-preview"), - ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), - ("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"), - ], - "xai": [ - ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), - ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), - ("Grok 4 - Flagship model", "grok-4-0709"), - ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), - ("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"), - ], - "openrouter": [ - ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), - ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), - ], - "ollama": [ - 
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), - ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), - ("Qwen3:latest (8B, local)", "qwen3:latest"), - ], - } + _provider = provider.lower() + + if _provider == "ollama": + ollama_models = _fetch_ollama_models() + if not ollama_models: + console.print("[yellow]Could not reach Ollama — is it running? Enter a model name manually.[/yellow]") + model = questionary.text("Model name (e.g. qwen3.5:27b):").ask() + if not model: + console.print("\n[red]No model entered. Exiting...[/red]") + exit(1) + return model.strip() + options = ollama_models + else: + DEEP_AGENT_OPTIONS = { + "openai": [ + ("GPT-5.2 - Latest flagship", "gpt-5.2"), + ("GPT-5.1 - Flexible reasoning", "gpt-5.1"), + ("GPT-5 - Advanced reasoning", "gpt-5"), + ("GPT-4.1 - Smartest non-reasoning, 1M context", "gpt-4.1"), + ("GPT-5 Mini - Cost-optimized reasoning", "gpt-5-mini"), + ("GPT-5 Nano - Ultra-fast, high-throughput", "gpt-5-nano"), + ], + "anthropic": [ + ("Claude Sonnet 4.5 - Best for agents/coding", "claude-sonnet-4-5"), + ("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"), + ("Claude Opus 4.1 - Most capable model", "claude-opus-4-1-20250805"), + ("Claude Haiku 4.5 - Fast + extended thinking", "claude-haiku-4-5"), + ("Claude Sonnet 4 - High-performance", "claude-sonnet-4-20250514"), + ], + "google": [ + ("Gemini 3 Pro - Reasoning-first", "gemini-3-pro-preview"), + ("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"), + ("Gemini 2.5 Flash - Balanced, recommended", "gemini-2.5-flash"), + ], + "xai": [ + ("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"), + ("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"), + ("Grok 4 - Flagship model", "grok-4-0709"), + ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), + ("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"), + ], + "openrouter": [ + ("DeepSeek R1 
- Strong open-source reasoning", "deepseek/deepseek-r1"), + ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), + ("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), + ], + } + options = DEEP_AGENT_OPTIONS[_provider] choice = questionary.select( "Select Your [Deep-Thinking LLM Engine]:", choices=[ questionary.Choice(display, value=value) - for display, value in DEEP_AGENT_OPTIONS[provider.lower()] + for display, value in options ], instruction="\n- Use arrow keys to navigate\n- Press Enter to select", style=questionary.Style( diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index ecf0dc29..f84c7063 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -10,11 +10,26 @@ DEFAULT_CONFIG = { # LLM settings "llm_provider": "openai", "deep_think_llm": "gpt-5.2", + "mid_think_llm": None, # falls back to quick_think_llm when None "quick_think_llm": "gpt-5-mini", "backend_url": "https://api.openai.com/v1", - # Provider-specific thinking configuration + # Per-role provider overrides (fall back to llm_provider / backend_url when None) + "deep_think_llm_provider": None, # e.g. "google", "anthropic", "openai" + "deep_think_backend_url": None, # override backend URL for deep-think model + "mid_think_llm_provider": None, # e.g. "ollama" + "mid_think_backend_url": None, # override backend URL for mid-think model + "quick_think_llm_provider": None, # e.g. "openai", "ollama" + "quick_think_backend_url": None, # override backend URL for quick-think model + # Provider-specific thinking configuration (applies to all roles unless overridden) "google_thinking_level": None, # "high", "minimal", etc. 
"openai_reasoning_effort": None, # "medium", "high", "low" + # Per-role provider-specific thinking configuration + "deep_think_google_thinking_level": None, + "deep_think_openai_reasoning_effort": None, + "mid_think_google_thinking_level": None, + "mid_think_openai_reasoning_effort": None, + "quick_think_google_thinking_level": None, + "quick_think_openai_reasoning_effort": None, # Debate and discussion settings "max_debate_rounds": 1, "max_risk_discuss_rounds": 1, diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 772efe7f..9dc900a1 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -17,6 +17,7 @@ class GraphSetup: def __init__( self, quick_thinking_llm: ChatOpenAI, + mid_thinking_llm: ChatOpenAI, deep_thinking_llm: ChatOpenAI, tool_nodes: Dict[str, ToolNode], bull_memory, @@ -28,6 +29,7 @@ class GraphSetup: ): """Initialize with required components.""" self.quick_thinking_llm = quick_thinking_llm + self.mid_thinking_llm = mid_thinking_llm self.deep_thinking_llm = deep_thinking_llm self.tool_nodes = tool_nodes self.bull_memory = bull_memory @@ -87,15 +89,15 @@ class GraphSetup: # Create researcher and manager nodes bull_researcher_node = create_bull_researcher( - self.quick_thinking_llm, self.bull_memory + self.mid_thinking_llm, self.bull_memory ) bear_researcher_node = create_bear_researcher( - self.quick_thinking_llm, self.bear_memory + self.mid_thinking_llm, self.bear_memory ) research_manager_node = create_research_manager( self.deep_thinking_llm, self.invest_judge_memory ) - trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) + trader_node = create_trader(self.mid_thinking_llm, self.trader_memory) # Create risk analysis nodes aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm) diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 44ecca0c..ccbbfc9f 100644 --- a/tradingagents/graph/trading_graph.py +++ 
b/tradingagents/graph/trading_graph.py @@ -71,27 +71,65 @@ class TradingAgentsGraph: exist_ok=True, ) - # Initialize LLMs with provider-specific thinking configuration - llm_kwargs = self._get_provider_kwargs() + # Initialize LLMs with provider-specific thinking configuration. + # Per-role provider/backend_url keys take precedence over the shared ones. + deep_kwargs = self._get_provider_kwargs("deep_think") + mid_kwargs = self._get_provider_kwargs("mid_think") + quick_kwargs = self._get_provider_kwargs("quick_think") # Add callbacks to kwargs if provided (passed to LLM constructor) if self.callbacks: - llm_kwargs["callbacks"] = self.callbacks + deep_kwargs["callbacks"] = self.callbacks + mid_kwargs["callbacks"] = self.callbacks + quick_kwargs["callbacks"] = self.callbacks + + deep_provider = ( + self.config.get("deep_think_llm_provider") or self.config["llm_provider"] + ) + deep_backend_url = ( + self.config.get("deep_think_backend_url") or self.config.get("backend_url") + ) + quick_provider = ( + self.config.get("quick_think_llm_provider") or self.config["llm_provider"] + ) + quick_backend_url = ( + self.config.get("quick_think_backend_url") or self.config.get("backend_url") + ) + + # mid_think falls back to quick_think when not configured + mid_model = self.config.get("mid_think_llm") or self.config["quick_think_llm"] + mid_provider = ( + self.config.get("mid_think_llm_provider") + or self.config.get("quick_think_llm_provider") + or self.config["llm_provider"] + ) + mid_backend_url = ( + self.config.get("mid_think_backend_url") + or self.config.get("quick_think_backend_url") + or self.config.get("backend_url") + ) deep_client = create_llm_client( - provider=self.config["llm_provider"], + provider=deep_provider, model=self.config["deep_think_llm"], - base_url=self.config.get("backend_url"), - **llm_kwargs, + base_url=deep_backend_url, + **deep_kwargs, + ) + mid_client = create_llm_client( + provider=mid_provider, + model=mid_model, + base_url=mid_backend_url, + 
**mid_kwargs, ) quick_client = create_llm_client( - provider=self.config["llm_provider"], + provider=quick_provider, model=self.config["quick_think_llm"], - base_url=self.config.get("backend_url"), - **llm_kwargs, + base_url=quick_backend_url, + **quick_kwargs, ) self.deep_thinking_llm = deep_client.get_llm() + self.mid_thinking_llm = mid_client.get_llm() self.quick_thinking_llm = quick_client.get_llm() # Initialize memories @@ -108,6 +146,7 @@ class TradingAgentsGraph: self.conditional_logic = ConditionalLogic() self.graph_setup = GraphSetup( self.quick_thinking_llm, + self.mid_thinking_llm, self.deep_thinking_llm, self.tool_nodes, self.bull_memory, @@ -130,18 +169,33 @@ class TradingAgentsGraph: # Set up the graph self.graph = self.graph_setup.setup_graph(selected_analysts) - def _get_provider_kwargs(self) -> Dict[str, Any]: - """Get provider-specific kwargs for LLM client creation.""" + def _get_provider_kwargs(self, role: str = "") -> Dict[str, Any]: + """Get provider-specific kwargs for LLM client creation. + + Args: + role: One of "deep_think", "mid_think" or "quick_think". When provided the + per-role config keys take precedence over the shared keys.
+ """ kwargs = {} - provider = self.config.get("llm_provider", "").lower() + prefix = f"{role}_" if role else "" + provider = ( + self.config.get(f"{prefix}llm_provider") + or self.config.get("llm_provider", "") + ).lower() if provider == "google": - thinking_level = self.config.get("google_thinking_level") + thinking_level = ( + self.config.get(f"{prefix}google_thinking_level") + or self.config.get("google_thinking_level") + ) if thinking_level: kwargs["thinking_level"] = thinking_level - elif provider == "openai": - reasoning_effort = self.config.get("openai_reasoning_effort") + elif provider in ("openai", "xai", "openrouter", "ollama"): + reasoning_effort = ( + self.config.get(f"{prefix}openai_reasoning_effort") + or self.config.get("openai_reasoning_effort") + ) if reasoning_effort: kwargs["reasoning_effort"] = reasoning_effort