diff --git a/cli/main.py b/cli/main.py index adda48fc..cb62488c 100644 --- a/cli/main.py +++ b/cli/main.py @@ -5,6 +5,7 @@ from pathlib import Path from functools import wraps from rich.console import Console from dotenv import load_dotenv +import copy # Load environment variables from .env file load_dotenv() @@ -465,16 +466,12 @@ def get_user_selections(): with open("./cli/static/welcome.txt", "r", encoding="utf-8") as f: welcome_ascii = f.read() - # Create welcome box content welcome_content = f"{welcome_ascii}\n" welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n" welcome_content += "[bold]Workflow Steps:[/bold]\n" welcome_content += "I. Analyst Team → II. Research Team → III. Trader → IV. Risk Management → V. Portfolio Management\n\n" - welcome_content += ( - "[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]" - ) + welcome_content += "[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]" - # Create and center the welcome box welcome_box = Panel( welcome_content, border_style="green", @@ -484,13 +481,11 @@ def get_user_selections(): ) console.print(Align.center(welcome_box)) console.print() - console.print() # Add vertical space before announcements + console.print() - # Fetch and display announcements (silent on failure) announcements = fetch_announcements() display_announcements(console, announcements) - # Create a boxed questionnaire for each step def create_question_box(title, prompt, default=None): box_content = f"[bold]{title}[/bold]\n" box_content += f"[dim]{prompt}[/dim]" @@ -498,7 +493,7 @@ def get_user_selections(): box_content += f"\n[dim]Default: {default}[/dim]" return Panel(box_content, border_style="blue", padding=(1, 2)) - # Step 1: Ticker symbol + # Step 1: Ticker console.print( create_question_box( "Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY" @@ -517,10 +512,11 @@ def get_user_selections(): ) analysis_date = get_analysis_date() - # Step 3: Select analysts + # Step 3: Analysts console.print( create_question_box( - "Step 3: Analysts Team", "Select your LLM analyst agents for the analysis" + "Step 3: Analysts Team", + "Select your LLM analyst agents for the analysis", ) ) selected_analysts = select_analysts() @@ -536,54 +532,117 @@ def get_user_selections(): ) selected_research_depth = select_research_depth() - # Step 5: OpenAI backend + # Step 5: Routing mode console.print( create_question_box( - "Step 5: OpenAI backend", "Select which service to talk to" + "Step 5: LLM Routing", + "Choose whether to use one LLM for all agents or customize by stage", ) ) - selected_llm_provider, backend_url = select_llm_provider() - - # Step 6: Thinking agents - console.print( - create_question_box( - "Step 6: Thinking Agents", "Select your thinking agents for analysis" - ) - ) - selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) - selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) + routing_mode = select_routing_mode() - # Step 7: Provider-specific thinking configuration + llm_routing = {"default": None, "roles": {}} + + # Provider-specific thinking configuration thinking_level = None reasoning_effort = None - provider_lower = selected_llm_provider.lower() - if provider_lower == "google": + if routing_mode == "single": console.print( create_question_box( - "Step 7: Thinking Mode", - "Configure Gemini thinking mode" + "Step 6: Default LLM", + "Select one provider/model for all agents", ) ) - thinking_level = ask_gemini_thinking_config() - elif provider_lower == "openai": + default_llm = select_llm_bundle("All Agents") + llm_routing["default"] = default_llm + + provider = default_llm["provider"] + if provider == "google": + console.print( + create_question_box( + "Step 7: Thinking Mode", + "Configure Google/Gemini thinking mode", + ) + ) + thinking_level = ask_gemini_thinking_config() + elif provider == "openai": + console.print( + create_question_box( + "Step 7: Reasoning Effort", + "Configure OpenAI reasoning effort level", + ) + ) + reasoning_effort = ask_openai_reasoning_effort() + + elif routing_mode == "stage": console.print( create_question_box( - "Step 7: Reasoning Effort", - "Configure OpenAI reasoning effort level" + "Step 6: Default LLM", + "Select the default provider/model used unless a stage overrides it", ) ) - reasoning_effort = ask_openai_reasoning_effort() + default_llm = select_llm_bundle("Default LLM") + llm_routing["default"] = default_llm + + console.print( + create_question_box( + "Step 7: Stage Overrides", + "Select LLMs for each stage", + ) + ) + analyst_llm = select_llm_bundle("Analyst Team", depth_hint="quick") + research_llm = select_llm_bundle("Research Team", depth_hint="deep") + trader_llm = select_llm_bundle("Trader", depth_hint="deep") + risk_llm = select_llm_bundle("Risk Team", depth_hint="deep") + pm_llm = select_llm_bundle("Portfolio Manager", depth_hint="deep") + + for role in ["market", "social", "news", "fundamentals"]: + llm_routing["roles"][role] = analyst_llm + + for role in ["bull_researcher", "bear_researcher", "research_manager"]: + llm_routing["roles"][role] = research_llm + + llm_routing["roles"]["trader"] = trader_llm + + for role in ["aggressive_analyst", "neutral_analyst", "conservative_analyst"]: + llm_routing["roles"][role] = risk_llm + + llm_routing["roles"]["portfolio_manager"] = pm_llm + + providers_used = { + cfg["provider"] + for cfg in [default_llm, analyst_llm, research_llm, trader_llm, risk_llm, pm_llm] + if cfg is not None and "provider" in cfg + } + + if "google" in providers_used: + console.print( + create_question_box( + "Step 8: Thinking Mode", + "Configure Google/Gemini thinking mode", + ) + ) + thinking_level = ask_gemini_thinking_config() + + if "openai" in providers_used: + console.print( + create_question_box( + "Step 8: Reasoning Effort", + "Configure OpenAI reasoning effort level", + ) + ) + reasoning_effort = ask_openai_reasoning_effort() + + else: + raise ValueError(f"Unsupported routing mode: {routing_mode}") return { "ticker": selected_ticker, "analysis_date": analysis_date, "analysts": selected_analysts, "research_depth": selected_research_depth, - "llm_provider": selected_llm_provider.lower(), - "backend_url": backend_url, - "shallow_thinker": selected_shallow_thinker, - "deep_thinker": selected_deep_thinker, + "llm_routing": llm_routing, "google_thinking_level": thinking_level, "openai_reasoning_effort": reasoning_effort, } @@ -612,33 +671,62 @@ def get_analysis_date(): "[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]" ) +def get_role_llm_label(llm_routing: dict, role: str) -> str: + """Return a human-readable provider/model label for a role.""" + default_cfg = llm_routing.get("default", {}) + role_cfg = llm_routing.get("roles", {}).get(role) or default_cfg -def save_report_to_disk(final_state, ticker: str, save_path: Path): - """Save complete analysis report to disk with organized subfolders.""" + provider = role_cfg.get("provider", "unknown") + model = role_cfg.get("model", "unknown") + return f"{provider} / {model}" + + +def with_llm_header(content: str, llm_label: str) -> str: + """Prefix saved markdown content with LLM metadata.""" + return f"> LLM: {llm_label}\n\n{content}" + +def save_report_to_disk(final_state, ticker: str, save_path: Path, llm_routing: dict): + """Save complete analysis report to disk with organized subfolders and LLM metadata.""" save_path.mkdir(parents=True, exist_ok=True) sections = [] # 1. Analysts analysts_dir = save_path / "1_analysts" analyst_parts = [] + if final_state.get("market_report"): analysts_dir.mkdir(exist_ok=True) - (analysts_dir / "market.md").write_text(final_state["market_report"]) - analyst_parts.append(("Market Analyst", final_state["market_report"])) + llm_label = get_role_llm_label(llm_routing, "market") + content = with_llm_header(final_state["market_report"], llm_label) + (analysts_dir / "market.md").write_text(content, encoding="utf-8") + analyst_parts.append(("Market Analyst", llm_label, final_state["market_report"])) + if final_state.get("sentiment_report"): analysts_dir.mkdir(exist_ok=True) - (analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"]) - analyst_parts.append(("Social Analyst", final_state["sentiment_report"])) + llm_label = get_role_llm_label(llm_routing, "social") + content = with_llm_header(final_state["sentiment_report"], llm_label) + (analysts_dir / "sentiment.md").write_text(content, encoding="utf-8") + analyst_parts.append(("Social Analyst", llm_label, final_state["sentiment_report"])) + if final_state.get("news_report"): analysts_dir.mkdir(exist_ok=True) - (analysts_dir / "news.md").write_text(final_state["news_report"]) - analyst_parts.append(("News Analyst", final_state["news_report"])) + llm_label = get_role_llm_label(llm_routing, "news") + content = with_llm_header(final_state["news_report"], llm_label) + (analysts_dir / "news.md").write_text(content, encoding="utf-8") + analyst_parts.append(("News Analyst", llm_label, final_state["news_report"])) + if final_state.get("fundamentals_report"): analysts_dir.mkdir(exist_ok=True) - (analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"]) - analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"])) + llm_label = get_role_llm_label(llm_routing, "fundamentals") + content = with_llm_header(final_state["fundamentals_report"], llm_label) + (analysts_dir / "fundamentals.md").write_text(content, encoding="utf-8") + analyst_parts.append(("Fundamentals Analyst", llm_label, final_state["fundamentals_report"])) + if analyst_parts: - content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts) + content = "\n\n".join( + f"### {name}\n*LLM: {llm_label}*\n\n{text}" + for name, llm_label, text in analyst_parts + ) sections.append(f"## I. Analyst Team Reports\n\n{content}") # 2. Research @@ -646,122 +734,168 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path): research_dir = save_path / "2_research" debate = final_state["investment_debate_state"] research_parts = [] + if debate.get("bull_history"): research_dir.mkdir(exist_ok=True) - (research_dir / "bull.md").write_text(debate["bull_history"]) - research_parts.append(("Bull Researcher", debate["bull_history"])) + llm_label = get_role_llm_label(llm_routing, "bull_researcher") + content = with_llm_header(debate["bull_history"], llm_label) + (research_dir / "bull.md").write_text(content, encoding="utf-8") + research_parts.append(("Bull Researcher", llm_label, debate["bull_history"])) + if debate.get("bear_history"): research_dir.mkdir(exist_ok=True) - (research_dir / "bear.md").write_text(debate["bear_history"]) - research_parts.append(("Bear Researcher", debate["bear_history"])) + llm_label = get_role_llm_label(llm_routing, "bear_researcher") + content = with_llm_header(debate["bear_history"], llm_label) + (research_dir / "bear.md").write_text(content, encoding="utf-8") + research_parts.append(("Bear Researcher", llm_label, debate["bear_history"])) + if debate.get("judge_decision"): research_dir.mkdir(exist_ok=True) - (research_dir / "manager.md").write_text(debate["judge_decision"]) - research_parts.append(("Research Manager", debate["judge_decision"])) + llm_label = get_role_llm_label(llm_routing, "research_manager") + content = with_llm_header(debate["judge_decision"], llm_label) + (research_dir / "manager.md").write_text(content, encoding="utf-8") + research_parts.append(("Research Manager", llm_label, debate["judge_decision"])) + if research_parts: - content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts) + content = "\n\n".join( + f"### {name}\n*LLM: {llm_label}*\n\n{text}" + for name, llm_label, text in research_parts + ) sections.append(f"## II. Research Team Decision\n\n{content}") # 3. Trading if final_state.get("trader_investment_plan"): trading_dir = save_path / "3_trading" trading_dir.mkdir(exist_ok=True) - (trading_dir / "trader.md").write_text(final_state["trader_investment_plan"]) - sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}") + llm_label = get_role_llm_label(llm_routing, "trader") + content = with_llm_header(final_state["trader_investment_plan"], llm_label) + (trading_dir / "trader.md").write_text(content, encoding="utf-8") + sections.append( + f"## III. Trading Team Plan\n\n### Trader\n*LLM: {llm_label}*\n\n{final_state['trader_investment_plan']}" + ) # 4. Risk Management if final_state.get("risk_debate_state"): risk_dir = save_path / "4_risk" risk = final_state["risk_debate_state"] risk_parts = [] + if risk.get("aggressive_history"): risk_dir.mkdir(exist_ok=True) - (risk_dir / "aggressive.md").write_text(risk["aggressive_history"]) - risk_parts.append(("Aggressive Analyst", risk["aggressive_history"])) + llm_label = get_role_llm_label(llm_routing, "aggressive_analyst") + content = with_llm_header(risk["aggressive_history"], llm_label) + (risk_dir / "aggressive.md").write_text(content, encoding="utf-8") + risk_parts.append(("Aggressive Analyst", llm_label, risk["aggressive_history"])) + if risk.get("conservative_history"): risk_dir.mkdir(exist_ok=True) - (risk_dir / "conservative.md").write_text(risk["conservative_history"]) - risk_parts.append(("Conservative Analyst", risk["conservative_history"])) + llm_label = get_role_llm_label(llm_routing, "conservative_analyst") + content = with_llm_header(risk["conservative_history"], llm_label) + (risk_dir / "conservative.md").write_text(content, encoding="utf-8") + risk_parts.append(("Conservative Analyst", llm_label, risk["conservative_history"])) + if risk.get("neutral_history"): risk_dir.mkdir(exist_ok=True) - (risk_dir / "neutral.md").write_text(risk["neutral_history"]) - risk_parts.append(("Neutral Analyst", risk["neutral_history"])) + llm_label = get_role_llm_label(llm_routing, "neutral_analyst") + content = with_llm_header(risk["neutral_history"], llm_label) + (risk_dir / "neutral.md").write_text(content, encoding="utf-8") + risk_parts.append(("Neutral Analyst", llm_label, risk["neutral_history"])) + if risk_parts: - content = "\n\n".join(f"### {name}\n{text}" for name, text in risk_parts) + content = "\n\n".join( + f"### {name}\n*LLM: {llm_label}*\n\n{text}" + for name, llm_label, text in risk_parts + ) sections.append(f"## IV. Risk Management Team Decision\n\n{content}") # 5. Portfolio Manager if risk.get("judge_decision"): portfolio_dir = save_path / "5_portfolio" portfolio_dir.mkdir(exist_ok=True) - (portfolio_dir / "decision.md").write_text(risk["judge_decision"]) - sections.append(f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}") + llm_label = get_role_llm_label(llm_routing, "portfolio_manager") + content = with_llm_header(risk["judge_decision"], llm_label) + (portfolio_dir / "decision.md").write_text(content, encoding="utf-8") + sections.append( + f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n*LLM: {llm_label}*\n\n{risk['judge_decision']}" + ) - # Write consolidated report - header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - (save_path / "complete_report.md").write_text(header + "\n\n".join(sections)) + header = ( + f"# Trading Analysis Report: {ticker}\n\n" + f"Generated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + ) + (save_path / "complete_report.md").write_text(header + "\n\n".join(sections), encoding="utf-8") return save_path / "complete_report.md" -def display_complete_report(final_state): - """Display the complete analysis report sequentially (avoids truncation).""" +def display_complete_report(final_state, llm_routing: dict): + """Display the complete analysis report sequentially with LLM metadata.""" console.print() console.print(Rule("Complete Analysis Report", style="bold green")) # I. Analyst Team Reports analysts = [] if final_state.get("market_report"): - analysts.append(("Market Analyst", final_state["market_report"])) + analysts.append(("Market Analyst", get_role_llm_label(llm_routing, "market"), final_state["market_report"])) if final_state.get("sentiment_report"): - analysts.append(("Social Analyst", final_state["sentiment_report"])) + analysts.append(("Social Analyst", get_role_llm_label(llm_routing, "social"), final_state["sentiment_report"])) if final_state.get("news_report"): - analysts.append(("News Analyst", final_state["news_report"])) + analysts.append(("News Analyst", get_role_llm_label(llm_routing, "news"), final_state["news_report"])) if final_state.get("fundamentals_report"): - analysts.append(("Fundamentals Analyst", final_state["fundamentals_report"])) + analysts.append(("Fundamentals Analyst", get_role_llm_label(llm_routing, "fundamentals"), final_state["fundamentals_report"])) + if analysts: console.print(Panel("[bold]I. Analyst Team Reports[/bold]", border_style="cyan")) - for title, content in analysts: - console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2))) + for title, llm_label, content in analysts: + body = f"*LLM: {llm_label}*\n\n{content}" + console.print(Panel(Markdown(body), title=title, border_style="blue", padding=(1, 2))) # II. Research Team Reports if final_state.get("investment_debate_state"): debate = final_state["investment_debate_state"] research = [] if debate.get("bull_history"): - research.append(("Bull Researcher", debate["bull_history"])) + research.append(("Bull Researcher", get_role_llm_label(llm_routing, "bull_researcher"), debate["bull_history"])) if debate.get("bear_history"): - research.append(("Bear Researcher", debate["bear_history"])) + research.append(("Bear Researcher", get_role_llm_label(llm_routing, "bear_researcher"), debate["bear_history"])) if debate.get("judge_decision"): - research.append(("Research Manager", debate["judge_decision"])) + research.append(("Research Manager", get_role_llm_label(llm_routing, "research_manager"), debate["judge_decision"])) + if research: console.print(Panel("[bold]II. Research Team Decision[/bold]", border_style="magenta")) - for title, content in research: - console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2))) + for title, llm_label, content in research: + body = f"*LLM: {llm_label}*\n\n{content}" + console.print(Panel(Markdown(body), title=title, border_style="blue", padding=(1, 2))) # III. Trading Team if final_state.get("trader_investment_plan"): + llm_label = get_role_llm_label(llm_routing, "trader") console.print(Panel("[bold]III. Trading Team Plan[/bold]", border_style="yellow")) - console.print(Panel(Markdown(final_state["trader_investment_plan"]), title="Trader", border_style="blue", padding=(1, 2))) + body = f"*LLM: {llm_label}*\n\n{final_state['trader_investment_plan']}" + console.print(Panel(Markdown(body), title="Trader", border_style="blue", padding=(1, 2))) # IV. Risk Management Team if final_state.get("risk_debate_state"): risk = final_state["risk_debate_state"] risk_reports = [] if risk.get("aggressive_history"): - risk_reports.append(("Aggressive Analyst", risk["aggressive_history"])) + risk_reports.append(("Aggressive Analyst", get_role_llm_label(llm_routing, "aggressive_analyst"), risk["aggressive_history"])) if risk.get("conservative_history"): - risk_reports.append(("Conservative Analyst", risk["conservative_history"])) + risk_reports.append(("Conservative Analyst", get_role_llm_label(llm_routing, "conservative_analyst"), risk["conservative_history"])) if risk.get("neutral_history"): - risk_reports.append(("Neutral Analyst", risk["neutral_history"])) + risk_reports.append(("Neutral Analyst", get_role_llm_label(llm_routing, "neutral_analyst"), risk["neutral_history"])) + if risk_reports: console.print(Panel("[bold]IV. Risk Management Team Decision[/bold]", border_style="red")) - for title, content in risk_reports: - console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2))) + for title, llm_label, content in risk_reports: + body = f"*LLM: {llm_label}*\n\n{content}" + console.print(Panel(Markdown(body), title=title, border_style="blue", padding=(1, 2))) # V. Portfolio Manager Decision if risk.get("judge_decision"): + llm_label = get_role_llm_label(llm_routing, "portfolio_manager") console.print(Panel("[bold]V. Portfolio Manager Decision[/bold]", border_style="green")) - console.print(Panel(Markdown(risk["judge_decision"]), title="Portfolio Manager", border_style="blue", padding=(1, 2))) + body = f"*LLM: {llm_label}*\n\n{risk['judge_decision']}" + console.print(Panel(Markdown(body), title="Portfolio Manager", border_style="blue", padding=(1, 2))) def update_research_team_status(status): @@ -901,14 +1035,10 @@ def run_analysis(): selections = get_user_selections() # Create config with selected research depth - config = DEFAULT_CONFIG.copy() + config = copy.deepcopy(DEFAULT_CONFIG) config["max_debate_rounds"] = selections["research_depth"] config["max_risk_discuss_rounds"] = selections["research_depth"] - config["quick_think_llm"] = selections["shallow_thinker"] - config["deep_think_llm"] = selections["deep_thinker"] - config["backend_url"] = selections["backend_url"] - config["llm_provider"] = selections["llm_provider"].lower() - # Provider-specific thinking configuration + config["llm_routing"] = selections["llm_routing"] config["google_thinking_level"] = selections.get("google_thinking_level") config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") @@ -1155,7 +1285,12 @@ def run_analysis(): ).strip() save_path = Path(save_path_str) try: - report_file = save_report_to_disk(final_state, selections["ticker"], save_path) + report_file = save_report_to_disk( + final_state, + selections["ticker"], + save_path, + config["llm_routing"], + ) console.print(f"\n[green]✓ Report saved to:[/green] {save_path.resolve()}") console.print(f" [dim]Complete report:[/dim] {report_file.name}") except Exception as e: @@ -1164,7 +1299,7 @@ def run_analysis(): # Prompt to display full report display_choice = typer.prompt("\nDisplay full report on screen?", default="Y").strip().upper() if display_choice in ("Y", "YES", ""): - display_complete_report(final_state) + display_complete_report(final_state, config["llm_routing"]) @app.command() diff --git a/cli/utils.py b/cli/utils.py index 5a8ec16c..9750e716 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -126,12 +126,8 @@ def select_research_depth() -> int: return choice -def select_shallow_thinking_agent(provider) -> str: - """Select shallow thinking llm engine using an interactive selection.""" - - # Define shallow thinking llm engine options with their corresponding model names - # Ordering: medium → light → heavy (balanced first for quick tasks) - # Within same tier, newer models first +def select_shallow_thinking_agent(provider: str) -> str: + """Select quick-thinking LLM engine using an interactive selection.""" SHALLOW_AGENT_OPTIONS = { "openai": [ ("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"), @@ -166,11 +162,15 @@ def select_shallow_thinking_agent(provider) -> str: ], } + provider = provider.lower() + if provider not in SHALLOW_AGENT_OPTIONS: + raise ValueError(f"Unsupported quick-thinking provider: {provider}") + choice = questionary.select( "Select Your [Quick-Thinking LLM Engine]:", choices=[ questionary.Choice(display, value=value) - for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()] + for display, value in SHALLOW_AGENT_OPTIONS[provider] ], instruction="\n- Use arrow keys to navigate\n- Press Enter to select", style=questionary.Style( @@ -183,20 +183,14 @@ def select_shallow_thinking_agent(provider) -> str: ).ask() if choice is None: - console.print( - "\n[red]No shallow thinking llm engine selected. Exiting...[/red]" - ) + console.print("\n[red]No quick-thinking LLM engine selected. Exiting...[/red]") exit(1) return choice -def select_deep_thinking_agent(provider) -> str: - """Select deep thinking llm engine using an interactive selection.""" - - # Define deep thinking llm engine options with their corresponding model names - # Ordering: heavy → medium → light (most capable first for deep tasks) - # Within same tier, newer models first +def select_deep_thinking_agent(provider: str) -> str: + """Select deep-thinking LLM engine using an interactive selection.""" DEEP_AGENT_OPTIONS = { "openai": [ ("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"), @@ -233,11 +227,15 @@ def select_deep_thinking_agent(provider) -> str: ], } + provider = provider.lower() + if provider not in DEEP_AGENT_OPTIONS: + raise ValueError(f"Unsupported deep-thinking provider: {provider}") + choice = questionary.select( "Select Your [Deep-Thinking LLM Engine]:", choices=[ questionary.Choice(display, value=value) - for display, value in DEEP_AGENT_OPTIONS[provider.lower()] + for display, value in DEEP_AGENT_OPTIONS[provider] ], instruction="\n- Use arrow keys to navigate\n- Press Enter to select", style=questionary.Style( @@ -250,27 +248,26 @@ def select_deep_thinking_agent(provider) -> str: ).ask() if choice is None: - console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]") + console.print("\n[red]No deep-thinking LLM engine selected. Exiting...[/red]") exit(1) return choice -def select_llm_provider() -> tuple[str, str]: - """Select the OpenAI api url using interactive selection.""" - # Define OpenAI api options with their corresponding endpoints +def select_provider(label: str) -> tuple[str, str]: + """Select an LLM provider and base URL.""" BASE_URLS = [ ("OpenAI", "https://api.openai.com/v1"), ("Google", "https://generativelanguage.googleapis.com/v1"), - ("Anthropic", "https://api.anthropic.com/"), + ("Anthropic", "https://api.anthropic.com"), ("xAI", "https://api.x.ai/v1"), - ("Openrouter", "https://openrouter.ai/api/v1"), + ("OpenRouter", "https://openrouter.ai/api/v1"), ("Ollama", "http://localhost:11434/v1"), ] - + choice = questionary.select( - "Select your LLM Provider:", + f"Select your {label} Provider:", choices=[ - questionary.Choice(display, value=(display, value)) + questionary.Choice(display, value=(display.lower(), value)) for display, value in BASE_URLS ], instruction="\n- Use arrow keys to navigate\n- Press Enter to select", @@ -282,15 +279,14 @@ def select_llm_provider() -> tuple[str, str]: ] ), ).ask() - - if choice is None: - console.print("\n[red]no OpenAI backend selected. Exiting...[/red]") - exit(1) - - display_name, url = choice - print(f"You selected: {display_name}\tURL: {url}") - return display_name, url + if choice is None: + console.print(f"\n[red]No {label.lower()} provider selected. Exiting...[/red]") + exit(1) + + provider, url = choice + console.print(f"You selected {label}: [green]{provider}[/green] URL: {url}") + return provider, url def ask_openai_reasoning_effort() -> str: @@ -329,3 +325,34 @@ def ask_gemini_thinking_config() -> str | None: ("pointer", "fg:green noinherit"), ]), ).ask() + +def select_routing_mode() -> str: + choice = questionary.select( + "How would you like to configure LLMs?", + choices=[ + questionary.Choice("Use one LLM for all agents", "single"), + questionary.Choice("Customize by stage", "stage"), + ], + ).ask() + + if choice is None: + console.print("\n[red]No routing mode selected. Exiting...[/red]") + exit(1) + + return choice + +def select_llm_bundle(label: str, depth_hint: str | None = None) -> dict: + provider, base_url = select_provider(label) + + if depth_hint == "quick": + model = select_shallow_thinking_agent(provider) + elif depth_hint == "deep": + model = select_deep_thinking_agent(provider) + else: + model = select_deep_thinking_agent(provider) + + return { + "provider": provider, + "model": model, + "base_url": base_url, + } \ No newline at end of file diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index ecf0dc29..5f58b874 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -7,28 +7,53 @@ DEFAULT_CONFIG = { os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "dataflows/data_cache", ), - # LLM settings - "llm_provider": "openai", - "deep_think_llm": "gpt-5.2", - "quick_think_llm": "gpt-5-mini", - "backend_url": "https://api.openai.com/v1", + + # LLM routing settings + # Default: all agents use the same provider/model unless explicitly overridden + "llm_routing": { + "default": { + "provider": "openai", + "model": "gpt-5-mini", + "base_url": "https://api.openai.com/v1", + }, + "roles": { + # Optional per-role overrides. + # Leave as None to inherit from "default". + "market": None, + "social": None, + "news": None, + "fundamentals": None, + "bull_researcher": None, + "bear_researcher": None, + "research_manager": None, + "trader": None, + "aggressive_analyst": None, + "neutral_analyst": None, + "conservative_analyst": None, + "portfolio_manager": None, + }, + }, + # Provider-specific thinking configuration - "google_thinking_level": None, # "high", "minimal", etc. - "openai_reasoning_effort": None, # "medium", "high", "low" + # These apply whenever that provider is used. + "google_thinking_level": None, # e.g. "high", "minimal" + "openai_reasoning_effort": None, # e.g. "medium", "high", "low" + # Debate and discussion settings "max_debate_rounds": 1, "max_risk_discuss_rounds": 1, "max_recur_limit": 100, + # Data vendor configuration - # Category-level configuration (default for all tools in category) "data_vendors": { - "core_stock_apis": "yfinance", # Options: alpha_vantage, yfinance - "technical_indicators": "yfinance", # Options: alpha_vantage, yfinance - "fundamental_data": "yfinance", # Options: alpha_vantage, yfinance - "news_data": "yfinance", # Options: alpha_vantage, yfinance + "core_stock_apis": "yfinance", + "technical_indicators": "yfinance", + "fundamental_data": "yfinance", + "news_data": "yfinance", }, + # Tool-level configuration (takes precedence over category-level) "tool_vendors": { - # Example: "get_stock_data": "alpha_vantage", # Override category default + # Example: "get_stock_data": "alpha_vantage", }, -} +} \ No newline at end of file diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 772efe7f..d38a13ba 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -1,7 +1,4 @@ -# TradingAgents/graph/setup.py - from typing import Dict, Any -from langchain_openai import ChatOpenAI from langgraph.graph import END, StateGraph, START from langgraph.prebuilt import ToolNode @@ -16,8 +13,7 @@ class GraphSetup: def __init__( self, - quick_thinking_llm: ChatOpenAI, - deep_thinking_llm: ChatOpenAI, + role_llms: Dict[str, Any], tool_nodes: Dict[str, ToolNode], bull_memory, bear_memory, @@ -27,8 +23,7 @@ class GraphSetup: conditional_logic: ConditionalLogic, ): """Initialize with required components.""" - self.quick_thinking_llm = quick_thinking_llm - self.deep_thinking_llm = deep_thinking_llm + self.role_llms = role_llms self.tool_nodes = tool_nodes self.bull_memory = bull_memory self.bear_memory = bear_memory @@ -37,78 +32,78 @@ class GraphSetup: self.risk_manager_memory = risk_manager_memory self.conditional_logic = conditional_logic + def _get_llm(self, role: str): + if role not in self.role_llms: + raise ValueError(f"Missing LLM assignment for role: {role}") + return self.role_llms[role] + def setup_graph( self, selected_analysts=["market", "social", "news", "fundamentals"] ): - """Set up and compile the agent workflow graph. - - Args: - selected_analysts (list): List of analyst types to include. Options are: - - "market": Market analyst - - "social": Social media analyst - - "news": News analyst - - "fundamentals": Fundamentals analyst - """ + """Set up and compile the agent workflow graph.""" if len(selected_analysts) == 0: raise ValueError("Trading Agents Graph Setup Error: no analysts selected!") - # Create analyst nodes analyst_nodes = {} delete_nodes = {} tool_nodes = {} if "market" in selected_analysts: analyst_nodes["market"] = create_market_analyst( - self.quick_thinking_llm + self._get_llm("market") ) delete_nodes["market"] = create_msg_delete() tool_nodes["market"] = self.tool_nodes["market"] if "social" in selected_analysts: analyst_nodes["social"] = create_social_media_analyst( - self.quick_thinking_llm + self._get_llm("social") ) delete_nodes["social"] = create_msg_delete() tool_nodes["social"] = self.tool_nodes["social"] if "news" in selected_analysts: analyst_nodes["news"] = create_news_analyst( - self.quick_thinking_llm + self._get_llm("news") ) delete_nodes["news"] = create_msg_delete() tool_nodes["news"] = self.tool_nodes["news"] if "fundamentals" in selected_analysts: analyst_nodes["fundamentals"] = create_fundamentals_analyst( - self.quick_thinking_llm + self._get_llm("fundamentals") ) delete_nodes["fundamentals"] = create_msg_delete() tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] - # Create researcher and manager nodes bull_researcher_node = create_bull_researcher( - self.quick_thinking_llm, self.bull_memory + self._get_llm("bull_researcher"), self.bull_memory ) bear_researcher_node = create_bear_researcher( - self.quick_thinking_llm, self.bear_memory + self._get_llm("bear_researcher"), self.bear_memory ) research_manager_node = create_research_manager( - self.deep_thinking_llm, self.invest_judge_memory + self._get_llm("research_manager"), self.invest_judge_memory + ) + trader_node = create_trader( + self._get_llm("trader"), self.trader_memory ) - trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) - # Create risk analysis nodes - aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm) - neutral_analyst = create_neutral_debator(self.quick_thinking_llm) - conservative_analyst = create_conservative_debator(self.quick_thinking_llm) + aggressive_analyst = create_aggressive_debator( + self._get_llm("aggressive_analyst") + ) + neutral_analyst = create_neutral_debator( + self._get_llm("neutral_analyst") + ) + conservative_analyst = create_conservative_debator( + self._get_llm("conservative_analyst") + ) risk_manager_node = create_risk_manager( - self.deep_thinking_llm, self.risk_manager_memory + self._get_llm("portfolio_manager"), self.risk_manager_memory ) - # Create workflow workflow = StateGraph(AgentState) - # Add analyst nodes to the graph for analyst_type, node in analyst_nodes.items(): workflow.add_node(f"{analyst_type.capitalize()} Analyst", node) workflow.add_node( @@ -116,7 +111,6 @@ class GraphSetup: ) workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type]) - # Add other nodes workflow.add_node("Bull Researcher", bull_researcher_node) workflow.add_node("Bear Researcher", bear_researcher_node) workflow.add_node("Research Manager", research_manager_node) @@ -126,18 +120,14 @@ class GraphSetup: workflow.add_node("Conservative Analyst", conservative_analyst) workflow.add_node("Risk Judge", risk_manager_node) - # Define edges - # Start with the first analyst first_analyst = selected_analysts[0] workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst") - # Connect analysts in sequence for i, analyst_type in enumerate(selected_analysts): current_analyst = f"{analyst_type.capitalize()} Analyst" current_tools = f"tools_{analyst_type}" current_clear = f"Msg Clear {analyst_type.capitalize()}" - # Add conditional edges for current analyst workflow.add_conditional_edges( current_analyst, getattr(self.conditional_logic, f"should_continue_{analyst_type}"), @@ -145,14 +135,12 @@ class GraphSetup: ) workflow.add_edge(current_tools, current_analyst) - # Connect to next analyst or to Bull Researcher if this is the last analyst if i < len(selected_analysts) - 1: next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst" workflow.add_edge(current_clear, next_analyst) else: workflow.add_edge(current_clear, "Bull Researcher") - # Add remaining edges workflow.add_conditional_edges( "Bull Researcher", self.conditional_logic.should_continue_debate, @@ -198,5 +186,4 @@ class GraphSetup: workflow.add_edge("Risk Judge", END) - # Compile and return - return workflow.compile() + return workflow.compile() \ No newline at end of file diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index c7ef0f98..5102b6fa 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -71,28 +71,51 @@ class TradingAgentsGraph: exist_ok=True, ) - # Initialize LLMs with provider-specific thinking configuration - llm_kwargs = self._get_provider_kwargs() + # Initialize LLMs from llm_routing. + llm_routing = self.config.get("llm_routing", {}) + default_llm_cfg = llm_routing.get("default") + role_overrides = llm_routing.get("roles", {}) - # Add callbacks to kwargs if provided (passed to LLM constructor) - if self.callbacks: - llm_kwargs["callbacks"] = self.callbacks + if default_llm_cfg is None: + raise ValueError("config['llm_routing']['default'] must be set") - deep_client = create_llm_client( - provider=self.config["llm_provider"], - model=self.config["deep_think_llm"], - base_url=self.config.get("backend_url"), - **llm_kwargs, - ) - quick_client = create_llm_client( - provider=self.config["llm_provider"], - model=self.config["quick_think_llm"], - base_url=self.config.get("backend_url"), - **llm_kwargs, - ) + all_roles = [ + "market", + "social", + "news", + "fundamentals", + "bull_researcher", + "bear_researcher", + "research_manager", + "trader", + "aggressive_analyst", + "neutral_analyst", + "conservative_analyst", + "portfolio_manager", + ] - self.deep_thinking_llm = deep_client.get_llm() - self.quick_thinking_llm = quick_client.get_llm() + self.role_llms = {} + + for role in all_roles: + llm_cfg = role_overrides.get(role) or default_llm_cfg + + provider = llm_cfg["provider"] + model = llm_cfg["model"] + base_url = llm_cfg.get("base_url") + + llm_kwargs = self._get_provider_kwargs(provider) + + if self.callbacks: + llm_kwargs["callbacks"] = self.callbacks + + client = create_llm_client( + provider=provider, + model=model, + base_url=base_url, + **llm_kwargs, + ) + + self.role_llms[role] = client.get_llm() # Initialize memories self.bull_memory = FinancialSituationMemory("bull_memory", self.config) @@ -110,8 +133,7 @@ class TradingAgentsGraph: max_risk_discuss_rounds=self.config["max_risk_discuss_rounds"], ) self.graph_setup = GraphSetup( - self.quick_thinking_llm, - self.deep_thinking_llm, + self.role_llms, self.tool_nodes, self.bull_memory, self.bear_memory, @@ -122,8 +144,8 @@ class TradingAgentsGraph: ) self.propagator = Propagator() - self.reflector = Reflector(self.quick_thinking_llm) - self.signal_processor = SignalProcessor(self.quick_thinking_llm) + self.reflector = Reflector(self.role_llms["trader"]) + self.signal_processor = SignalProcessor(self.role_llms["portfolio_manager"]) # State tracking self.curr_state = None @@ -132,11 +154,10 @@ class TradingAgentsGraph: # Set up the graph self.graph = self.graph_setup.setup_graph(selected_analysts) - - def _get_provider_kwargs(self) -> Dict[str, Any]: + def _get_provider_kwargs(self, provider: str) -> Dict[str, Any]: """Get provider-specific kwargs for LLM client creation.""" kwargs = {} - provider = self.config.get("llm_provider", "").lower() + provider = (provider or "").lower() if provider == "google": thinking_level = self.config.get("google_thinking_level") @@ -150,6 +171,7 @@ class TradingAgentsGraph: return kwargs + def _create_tool_nodes(self) -> Dict[str, ToolNode]: """Create tool nodes for different data sources using abstract methods.""" return {