Merge e5690d0388 into f362a160c3
This commit is contained in: commit f0d6ec3581

cli/main.py (329 changed lines)
@@ -5,6 +5,7 @@ from pathlib import Path
from functools import wraps
from rich.console import Console
from dotenv import load_dotenv
import copy

# Load environment variables from .env file
load_dotenv()

@@ -465,16 +466,12 @@ def get_user_selections():
with open("./cli/static/welcome.txt", "r", encoding="utf-8") as f:
welcome_ascii = f.read()

# Create welcome box content
welcome_content = f"{welcome_ascii}\n"
welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n"
welcome_content += "[bold]Workflow Steps:[/bold]\n"
welcome_content += "I. Analyst Team → II. Research Team → III. Trader → IV. Risk Management → V. Portfolio Management\n\n"
welcome_content += (
"[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]"
)
welcome_content += "[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]"

# Create and center the welcome box
welcome_box = Panel(
welcome_content,
border_style="green",

@@ -484,13 +481,11 @@ def get_user_selections():
)
console.print(Align.center(welcome_box))
console.print()
console.print() # Add vertical space before announcements
console.print()

# Fetch and display announcements (silent on failure)
announcements = fetch_announcements()
display_announcements(console, announcements)

# Create a boxed questionnaire for each step
def create_question_box(title, prompt, default=None):
box_content = f"[bold]{title}[/bold]\n"
box_content += f"[dim]{prompt}[/dim]"

@@ -498,7 +493,7 @@ def get_user_selections():
box_content += f"\n[dim]Default: {default}[/dim]"
return Panel(box_content, border_style="blue", padding=(1, 2))

# Step 1: Ticker symbol
# Step 1: Ticker
console.print(
create_question_box(
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY"

@@ -517,10 +512,11 @@ def get_user_selections():
)
analysis_date = get_analysis_date()

# Step 3: Select analysts
# Step 3: Analysts
console.print(
create_question_box(
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis"
"Step 3: Analysts Team",
"Select your LLM analyst agents for the analysis",
)
)
selected_analysts = select_analysts()
@@ -536,54 +532,117 @@ def get_user_selections():
)
selected_research_depth = select_research_depth()

# Step 5: OpenAI backend
# Step 5: Routing mode
console.print(
create_question_box(
"Step 5: OpenAI backend", "Select which service to talk to"
"Step 5: LLM Routing",
"Choose whether to use one LLM for all agents or customize by stage",
)
)
selected_llm_provider, backend_url = select_llm_provider()

# Step 6: Thinking agents
console.print(
create_question_box(
"Step 6: Thinking Agents", "Select your thinking agents for analysis"
)
)
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
routing_mode = select_routing_mode()

# Step 7: Provider-specific thinking configuration
llm_routing = {"default": None, "roles": {}}

# Provider-specific thinking configuration
thinking_level = None
reasoning_effort = None

provider_lower = selected_llm_provider.lower()
if provider_lower == "google":
if routing_mode == "single":
console.print(
create_question_box(
"Step 7: Thinking Mode",
"Configure Gemini thinking mode"
"Step 6: Default LLM",
"Select one provider/model for all agents",
)
)
thinking_level = ask_gemini_thinking_config()
elif provider_lower == "openai":
default_llm = select_llm_bundle("All Agents")
llm_routing["default"] = default_llm

provider = default_llm["provider"]
if provider == "google":
console.print(
create_question_box(
"Step 7: Thinking Mode",
"Configure Google/Gemini thinking mode",
)
)
thinking_level = ask_gemini_thinking_config()
elif provider == "openai":
console.print(
create_question_box(
"Step 7: Reasoning Effort",
"Configure OpenAI reasoning effort level",
)
)
reasoning_effort = ask_openai_reasoning_effort()

elif routing_mode == "stage":
console.print(
create_question_box(
"Step 7: Reasoning Effort",
"Configure OpenAI reasoning effort level"
"Step 6: Default LLM",
"Select the default provider/model used unless a stage overrides it",
)
)
reasoning_effort = ask_openai_reasoning_effort()
default_llm = select_llm_bundle("Default LLM")
llm_routing["default"] = default_llm

console.print(
create_question_box(
"Step 7: Stage Overrides",
"Select LLMs for each stage",
)
)
analyst_llm = select_llm_bundle("Analyst Team", depth_hint="quick")
research_llm = select_llm_bundle("Research Team", depth_hint="deep")
trader_llm = select_llm_bundle("Trader", depth_hint="deep")
risk_llm = select_llm_bundle("Risk Team", depth_hint="deep")
pm_llm = select_llm_bundle("Portfolio Manager", depth_hint="deep")

for role in ["market", "social", "news", "fundamentals"]:
llm_routing["roles"][role] = analyst_llm

for role in ["bull_researcher", "bear_researcher", "research_manager"]:
llm_routing["roles"][role] = research_llm

llm_routing["roles"]["trader"] = trader_llm

for role in ["aggressive_analyst", "neutral_analyst", "conservative_analyst"]:
llm_routing["roles"][role] = risk_llm

llm_routing["roles"]["portfolio_manager"] = pm_llm

providers_used = {
cfg["provider"]
for cfg in [default_llm, analyst_llm, research_llm, trader_llm, risk_llm, pm_llm]
if cfg is not None and "provider" in cfg
}

if "google" in providers_used:
console.print(
create_question_box(
"Step 8: Thinking Mode",
"Configure Google/Gemini thinking mode",
)
)
thinking_level = ask_gemini_thinking_config()

if "openai" in providers_used:
console.print(
create_question_box(
"Step 8: Reasoning Effort",
"Configure OpenAI reasoning effort level",
)
)
reasoning_effort = ask_openai_reasoning_effort()

else:
raise ValueError(f"Unsupported routing mode: {routing_mode}")

return {
"ticker": selected_ticker,
"analysis_date": analysis_date,
"analysts": selected_analysts,
"research_depth": selected_research_depth,
"llm_provider": selected_llm_provider.lower(),
"backend_url": backend_url,
"shallow_thinker": selected_shallow_thinker,
"deep_thinker": selected_deep_thinker,
"llm_routing": llm_routing,
"google_thinking_level": thinking_level,
"openai_reasoning_effort": reasoning_effort,
}
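For reference, the llm_routing dict assembled above ends up shaped roughly like this; a sketch only, where the provider/model values are placeholders for whatever the user picks and only two of the twelve roles are shown:

llm_routing = {
    "default": {"provider": "openai", "model": "gpt-5-mini", "base_url": "https://api.openai.com/v1"},
    "roles": {
        "market": {"provider": "openai", "model": "gpt-5-mini", "base_url": "https://api.openai.com/v1"},
        "trader": {"provider": "google", "model": "<selected-gemini-model>", "base_url": "https://generativelanguage.googleapis.com/v1"},
    },
}

In "single" mode the roles dict stays empty and every agent falls back to "default".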
@@ -612,33 +671,62 @@ def get_analysis_date():
"[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]"
)

def get_role_llm_label(llm_routing: dict, role: str) -> str:
"""Return a human-readable provider/model label for a role."""
default_cfg = llm_routing.get("default", {})
role_cfg = llm_routing.get("roles", {}).get(role) or default_cfg

def save_report_to_disk(final_state, ticker: str, save_path: Path):
"""Save complete analysis report to disk with organized subfolders."""
provider = role_cfg.get("provider", "unknown")
model = role_cfg.get("model", "unknown")
return f"{provider} / {model}"

def with_llm_header(content: str, llm_label: str) -> str:
"""Prefix saved markdown content with LLM metadata."""
return f"> LLM: {llm_label}\n\n{content}"
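As a concrete sketch of what the helper above produces (the label value is just an example), each per-agent markdown file now starts with a one-line blockquote before the original report body:

with_llm_header("Report body...", "openai / gpt-5-mini")
# returns "> LLM: openai / gpt-5-mini\n\nReport body..."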
def save_report_to_disk(final_state, ticker: str, save_path: Path, llm_routing: dict):
"""Save complete analysis report to disk with organized subfolders and LLM metadata."""
save_path.mkdir(parents=True, exist_ok=True)
sections = []

# 1. Analysts
analysts_dir = save_path / "1_analysts"
analyst_parts = []

if final_state.get("market_report"):
analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "market.md").write_text(final_state["market_report"])
analyst_parts.append(("Market Analyst", final_state["market_report"]))
llm_label = get_role_llm_label(llm_routing, "market")
content = with_llm_header(final_state["market_report"], llm_label)
(analysts_dir / "market.md").write_text(content, encoding="utf-8")
analyst_parts.append(("Market Analyst", llm_label, final_state["market_report"]))

if final_state.get("sentiment_report"):
analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "sentiment.md").write_text(final_state["sentiment_report"])
analyst_parts.append(("Social Analyst", final_state["sentiment_report"]))
llm_label = get_role_llm_label(llm_routing, "social")
content = with_llm_header(final_state["sentiment_report"], llm_label)
(analysts_dir / "sentiment.md").write_text(content, encoding="utf-8")
analyst_parts.append(("Social Analyst", llm_label, final_state["sentiment_report"]))

if final_state.get("news_report"):
analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "news.md").write_text(final_state["news_report"])
analyst_parts.append(("News Analyst", final_state["news_report"]))
llm_label = get_role_llm_label(llm_routing, "news")
content = with_llm_header(final_state["news_report"], llm_label)
(analysts_dir / "news.md").write_text(content, encoding="utf-8")
analyst_parts.append(("News Analyst", llm_label, final_state["news_report"]))

if final_state.get("fundamentals_report"):
analysts_dir.mkdir(exist_ok=True)
(analysts_dir / "fundamentals.md").write_text(final_state["fundamentals_report"])
analyst_parts.append(("Fundamentals Analyst", final_state["fundamentals_report"]))
llm_label = get_role_llm_label(llm_routing, "fundamentals")
content = with_llm_header(final_state["fundamentals_report"], llm_label)
(analysts_dir / "fundamentals.md").write_text(content, encoding="utf-8")
analyst_parts.append(("Fundamentals Analyst", llm_label, final_state["fundamentals_report"]))

if analyst_parts:
content = "\n\n".join(f"### {name}\n{text}" for name, text in analyst_parts)
content = "\n\n".join(
f"### {name}\n*LLM: {llm_label}*\n\n{text}"
for name, llm_label, text in analyst_parts
)
sections.append(f"## I. Analyst Team Reports\n\n{content}")

# 2. Research
@@ -646,122 +734,168 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
research_dir = save_path / "2_research"
debate = final_state["investment_debate_state"]
research_parts = []

if debate.get("bull_history"):
research_dir.mkdir(exist_ok=True)
(research_dir / "bull.md").write_text(debate["bull_history"])
research_parts.append(("Bull Researcher", debate["bull_history"]))
llm_label = get_role_llm_label(llm_routing, "bull_researcher")
content = with_llm_header(debate["bull_history"], llm_label)
(research_dir / "bull.md").write_text(content, encoding="utf-8")
research_parts.append(("Bull Researcher", llm_label, debate["bull_history"]))

if debate.get("bear_history"):
research_dir.mkdir(exist_ok=True)
(research_dir / "bear.md").write_text(debate["bear_history"])
research_parts.append(("Bear Researcher", debate["bear_history"]))
llm_label = get_role_llm_label(llm_routing, "bear_researcher")
content = with_llm_header(debate["bear_history"], llm_label)
(research_dir / "bear.md").write_text(content, encoding="utf-8")
research_parts.append(("Bear Researcher", llm_label, debate["bear_history"]))

if debate.get("judge_decision"):
research_dir.mkdir(exist_ok=True)
(research_dir / "manager.md").write_text(debate["judge_decision"])
research_parts.append(("Research Manager", debate["judge_decision"]))
llm_label = get_role_llm_label(llm_routing, "research_manager")
content = with_llm_header(debate["judge_decision"], llm_label)
(research_dir / "manager.md").write_text(content, encoding="utf-8")
research_parts.append(("Research Manager", llm_label, debate["judge_decision"]))

if research_parts:
content = "\n\n".join(f"### {name}\n{text}" for name, text in research_parts)
content = "\n\n".join(
f"### {name}\n*LLM: {llm_label}*\n\n{text}"
for name, llm_label, text in research_parts
)
sections.append(f"## II. Research Team Decision\n\n{content}")

# 3. Trading
if final_state.get("trader_investment_plan"):
trading_dir = save_path / "3_trading"
trading_dir.mkdir(exist_ok=True)
(trading_dir / "trader.md").write_text(final_state["trader_investment_plan"])
sections.append(f"## III. Trading Team Plan\n\n### Trader\n{final_state['trader_investment_plan']}")
llm_label = get_role_llm_label(llm_routing, "trader")
content = with_llm_header(final_state["trader_investment_plan"], llm_label)
(trading_dir / "trader.md").write_text(content, encoding="utf-8")
sections.append(
f"## III. Trading Team Plan\n\n### Trader\n*LLM: {llm_label}*\n\n{final_state['trader_investment_plan']}"
)

# 4. Risk Management
if final_state.get("risk_debate_state"):
risk_dir = save_path / "4_risk"
risk = final_state["risk_debate_state"]
risk_parts = []

if risk.get("aggressive_history"):
risk_dir.mkdir(exist_ok=True)
(risk_dir / "aggressive.md").write_text(risk["aggressive_history"])
risk_parts.append(("Aggressive Analyst", risk["aggressive_history"]))
llm_label = get_role_llm_label(llm_routing, "aggressive_analyst")
content = with_llm_header(risk["aggressive_history"], llm_label)
(risk_dir / "aggressive.md").write_text(content, encoding="utf-8")
risk_parts.append(("Aggressive Analyst", llm_label, risk["aggressive_history"]))

if risk.get("conservative_history"):
risk_dir.mkdir(exist_ok=True)
(risk_dir / "conservative.md").write_text(risk["conservative_history"])
risk_parts.append(("Conservative Analyst", risk["conservative_history"]))
llm_label = get_role_llm_label(llm_routing, "conservative_analyst")
content = with_llm_header(risk["conservative_history"], llm_label)
(risk_dir / "conservative.md").write_text(content, encoding="utf-8")
risk_parts.append(("Conservative Analyst", llm_label, risk["conservative_history"]))

if risk.get("neutral_history"):
risk_dir.mkdir(exist_ok=True)
(risk_dir / "neutral.md").write_text(risk["neutral_history"])
risk_parts.append(("Neutral Analyst", risk["neutral_history"]))
llm_label = get_role_llm_label(llm_routing, "neutral_analyst")
content = with_llm_header(risk["neutral_history"], llm_label)
(risk_dir / "neutral.md").write_text(content, encoding="utf-8")
risk_parts.append(("Neutral Analyst", llm_label, risk["neutral_history"]))

if risk_parts:
content = "\n\n".join(f"### {name}\n{text}" for name, text in risk_parts)
content = "\n\n".join(
f"### {name}\n*LLM: {llm_label}*\n\n{text}"
for name, llm_label, text in risk_parts
)
sections.append(f"## IV. Risk Management Team Decision\n\n{content}")

# 5. Portfolio Manager
if risk.get("judge_decision"):
portfolio_dir = save_path / "5_portfolio"
portfolio_dir.mkdir(exist_ok=True)
(portfolio_dir / "decision.md").write_text(risk["judge_decision"])
sections.append(f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n{risk['judge_decision']}")
llm_label = get_role_llm_label(llm_routing, "portfolio_manager")
content = with_llm_header(risk["judge_decision"], llm_label)
(portfolio_dir / "decision.md").write_text(content, encoding="utf-8")
sections.append(
f"## V. Portfolio Manager Decision\n\n### Portfolio Manager\n*LLM: {llm_label}*\n\n{risk['judge_decision']}"
)

# Write consolidated report
header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
(save_path / "complete_report.md").write_text(header + "\n\n".join(sections))
header = (
f"# Trading Analysis Report: {ticker}\n\n"
f"Generated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
)
(save_path / "complete_report.md").write_text(header + "\n\n".join(sections), encoding="utf-8")
return save_path / "complete_report.md"

def display_complete_report(final_state):
"""Display the complete analysis report sequentially (avoids truncation)."""
def display_complete_report(final_state, llm_routing: dict):
"""Display the complete analysis report sequentially with LLM metadata."""
console.print()
console.print(Rule("Complete Analysis Report", style="bold green"))

# I. Analyst Team Reports
analysts = []
if final_state.get("market_report"):
analysts.append(("Market Analyst", final_state["market_report"]))
analysts.append(("Market Analyst", get_role_llm_label(llm_routing, "market"), final_state["market_report"]))
if final_state.get("sentiment_report"):
analysts.append(("Social Analyst", final_state["sentiment_report"]))
analysts.append(("Social Analyst", get_role_llm_label(llm_routing, "social"), final_state["sentiment_report"]))
if final_state.get("news_report"):
analysts.append(("News Analyst", final_state["news_report"]))
analysts.append(("News Analyst", get_role_llm_label(llm_routing, "news"), final_state["news_report"]))
if final_state.get("fundamentals_report"):
analysts.append(("Fundamentals Analyst", final_state["fundamentals_report"]))
analysts.append(("Fundamentals Analyst", get_role_llm_label(llm_routing, "fundamentals"), final_state["fundamentals_report"]))

if analysts:
console.print(Panel("[bold]I. Analyst Team Reports[/bold]", border_style="cyan"))
for title, content in analysts:
console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2)))
for title, llm_label, content in analysts:
body = f"*LLM: {llm_label}*\n\n{content}"
console.print(Panel(Markdown(body), title=title, border_style="blue", padding=(1, 2)))

# II. Research Team Reports
if final_state.get("investment_debate_state"):
debate = final_state["investment_debate_state"]
research = []
if debate.get("bull_history"):
research.append(("Bull Researcher", debate["bull_history"]))
research.append(("Bull Researcher", get_role_llm_label(llm_routing, "bull_researcher"), debate["bull_history"]))
if debate.get("bear_history"):
research.append(("Bear Researcher", debate["bear_history"]))
research.append(("Bear Researcher", get_role_llm_label(llm_routing, "bear_researcher"), debate["bear_history"]))
if debate.get("judge_decision"):
research.append(("Research Manager", debate["judge_decision"]))
research.append(("Research Manager", get_role_llm_label(llm_routing, "research_manager"), debate["judge_decision"]))

if research:
console.print(Panel("[bold]II. Research Team Decision[/bold]", border_style="magenta"))
for title, content in research:
console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2)))
for title, llm_label, content in research:
body = f"*LLM: {llm_label}*\n\n{content}"
console.print(Panel(Markdown(body), title=title, border_style="blue", padding=(1, 2)))

# III. Trading Team
if final_state.get("trader_investment_plan"):
llm_label = get_role_llm_label(llm_routing, "trader")
console.print(Panel("[bold]III. Trading Team Plan[/bold]", border_style="yellow"))
console.print(Panel(Markdown(final_state["trader_investment_plan"]), title="Trader", border_style="blue", padding=(1, 2)))
body = f"*LLM: {llm_label}*\n\n{final_state['trader_investment_plan']}"
console.print(Panel(Markdown(body), title="Trader", border_style="blue", padding=(1, 2)))

# IV. Risk Management Team
if final_state.get("risk_debate_state"):
risk = final_state["risk_debate_state"]
risk_reports = []
if risk.get("aggressive_history"):
risk_reports.append(("Aggressive Analyst", risk["aggressive_history"]))
risk_reports.append(("Aggressive Analyst", get_role_llm_label(llm_routing, "aggressive_analyst"), risk["aggressive_history"]))
if risk.get("conservative_history"):
risk_reports.append(("Conservative Analyst", risk["conservative_history"]))
risk_reports.append(("Conservative Analyst", get_role_llm_label(llm_routing, "conservative_analyst"), risk["conservative_history"]))
if risk.get("neutral_history"):
risk_reports.append(("Neutral Analyst", risk["neutral_history"]))
risk_reports.append(("Neutral Analyst", get_role_llm_label(llm_routing, "neutral_analyst"), risk["neutral_history"]))

if risk_reports:
console.print(Panel("[bold]IV. Risk Management Team Decision[/bold]", border_style="red"))
for title, content in risk_reports:
console.print(Panel(Markdown(content), title=title, border_style="blue", padding=(1, 2)))
for title, llm_label, content in risk_reports:
body = f"*LLM: {llm_label}*\n\n{content}"
console.print(Panel(Markdown(body), title=title, border_style="blue", padding=(1, 2)))

# V. Portfolio Manager Decision
if risk.get("judge_decision"):
llm_label = get_role_llm_label(llm_routing, "portfolio_manager")
console.print(Panel("[bold]V. Portfolio Manager Decision[/bold]", border_style="green"))
console.print(Panel(Markdown(risk["judge_decision"]), title="Portfolio Manager", border_style="blue", padding=(1, 2)))
body = f"*LLM: {llm_label}*\n\n{risk['judge_decision']}"
console.print(Panel(Markdown(body), title="Portfolio Manager", border_style="blue", padding=(1, 2)))

def update_research_team_status(status):
@@ -901,14 +1035,10 @@ def run_analysis():
selections = get_user_selections()

# Create config with selected research depth
config = DEFAULT_CONFIG.copy()
config = copy.deepcopy(DEFAULT_CONFIG)
config["max_debate_rounds"] = selections["research_depth"]
config["max_risk_discuss_rounds"] = selections["research_depth"]
config["quick_think_llm"] = selections["shallow_thinker"]
config["deep_think_llm"] = selections["deep_thinker"]
config["backend_url"] = selections["backend_url"]
config["llm_provider"] = selections["llm_provider"].lower()
# Provider-specific thinking configuration
config["llm_routing"] = selections["llm_routing"]
config["google_thinking_level"] = selections.get("google_thinking_level")
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
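The move from DEFAULT_CONFIG.copy() to copy.deepcopy(DEFAULT_CONFIG) matters here because llm_routing is a nested dict: with a shallow copy, per-run role overrides would mutate DEFAULT_CONFIG itself. A minimal standalone illustration:

import copy

base = {"llm_routing": {"roles": {"trader": None}}}
shallow = base.copy()
shallow["llm_routing"]["roles"]["trader"] = "override"  # also changes base["llm_routing"]
deep = copy.deepcopy(base)
deep["llm_routing"]["roles"]["trader"] = "override"  # base stays untouched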
@@ -1155,7 +1285,12 @@ def run_analysis():
).strip()
save_path = Path(save_path_str)
try:
report_file = save_report_to_disk(final_state, selections["ticker"], save_path)
report_file = save_report_to_disk(
final_state,
selections["ticker"],
save_path,
config["llm_routing"],
)
console.print(f"\n[green]✓ Report saved to:[/green] {save_path.resolve()}")
console.print(f" [dim]Complete report:[/dim] {report_file.name}")
except Exception as e:

@@ -1164,7 +1299,7 @@ def run_analysis():
# Prompt to display full report
display_choice = typer.prompt("\nDisplay full report on screen?", default="Y").strip().upper()
if display_choice in ("Y", "YES", ""):
display_complete_report(final_state)
display_complete_report(final_state, config["llm_routing"])

@app.command()

cli/utils.py (95 changed lines)
@@ -126,12 +126,8 @@ def select_research_depth() -> int:
return choice

def select_shallow_thinking_agent(provider) -> str:
"""Select shallow thinking llm engine using an interactive selection."""

# Define shallow thinking llm engine options with their corresponding model names
# Ordering: medium → light → heavy (balanced first for quick tasks)
# Within same tier, newer models first
def select_shallow_thinking_agent(provider: str) -> str:
"""Select quick-thinking LLM engine using an interactive selection."""
SHALLOW_AGENT_OPTIONS = {
"openai": [
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),

@@ -166,11 +162,15 @@ def select_shallow_thinking_agent(provider) -> str:
],
}

provider = provider.lower()
if provider not in SHALLOW_AGENT_OPTIONS:
raise ValueError(f"Unsupported quick-thinking provider: {provider}")

choice = questionary.select(
"Select Your [Quick-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
for display, value in SHALLOW_AGENT_OPTIONS[provider]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(

@@ -183,20 +183,14 @@ def select_shallow_thinking_agent(provider) -> str:
).ask()

if choice is None:
console.print(
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
)
console.print("\n[red]No quick-thinking LLM engine selected. Exiting...[/red]")
exit(1)

return choice

def select_deep_thinking_agent(provider) -> str:
"""Select deep thinking llm engine using an interactive selection."""

# Define deep thinking llm engine options with their corresponding model names
# Ordering: heavy → medium → light (most capable first for deep tasks)
# Within same tier, newer models first
def select_deep_thinking_agent(provider: str) -> str:
"""Select deep-thinking LLM engine using an interactive selection."""
DEEP_AGENT_OPTIONS = {
"openai": [
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),

@@ -233,11 +227,15 @@ def select_deep_thinking_agent(provider) -> str:
],
}

provider = provider.lower()
if provider not in DEEP_AGENT_OPTIONS:
raise ValueError(f"Unsupported deep-thinking provider: {provider}")

choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
for display, value in DEEP_AGENT_OPTIONS[provider]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(

@@ -250,27 +248,26 @@ def select_deep_thinking_agent(provider) -> str:
).ask()

if choice is None:
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
console.print("\n[red]No deep-thinking LLM engine selected. Exiting...[/red]")
exit(1)

return choice

def select_llm_provider() -> tuple[str, str]:
"""Select the OpenAI api url using interactive selection."""
# Define OpenAI api options with their corresponding endpoints
def select_provider(label: str) -> tuple[str, str]:
"""Select an LLM provider and base URL."""
BASE_URLS = [
("OpenAI", "https://api.openai.com/v1"),
("Google", "https://generativelanguage.googleapis.com/v1"),
("Anthropic", "https://api.anthropic.com/"),
("Anthropic", "https://api.anthropic.com"),
("xAI", "https://api.x.ai/v1"),
("Openrouter", "https://openrouter.ai/api/v1"),
("OpenRouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
]

choice = questionary.select(
"Select your LLM Provider:",
f"Select your {label} Provider:",
choices=[
questionary.Choice(display, value=(display, value))
questionary.Choice(display, value=(display.lower(), value))
for display, value in BASE_URLS
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",

@@ -282,15 +279,14 @@ def select_llm_provider() -> tuple[str, str]:
]
),
).ask()

if choice is None:
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
exit(1)

display_name, url = choice
print(f"You selected: {display_name}\tURL: {url}")

return display_name, url
if choice is None:
console.print(f"\n[red]No {label.lower()} provider selected. Exiting...[/red]")
exit(1)

provider, url = choice
console.print(f"You selected {label}: [green]{provider}[/green] URL: {url}")
return provider, url

def ask_openai_reasoning_effort() -> str:

@@ -329,3 +325,34 @@ def ask_gemini_thinking_config() -> str | None:
("pointer", "fg:green noinherit"),
]),
).ask()

def select_routing_mode() -> str:
choice = questionary.select(
"How would you like to configure LLMs?",
choices=[
questionary.Choice("Use one LLM for all agents", "single"),
questionary.Choice("Customize by stage", "stage"),
],
).ask()

if choice is None:
console.print("\n[red]No routing mode selected. Exiting...[/red]")
exit(1)

return choice

def select_llm_bundle(label: str, depth_hint: str | None = None) -> dict:
provider, base_url = select_provider(label)

if depth_hint == "quick":
model = select_shallow_thinking_agent(provider)
elif depth_hint == "deep":
model = select_deep_thinking_agent(provider)
else:
model = select_deep_thinking_agent(provider)

return {
"provider": provider,
"model": model,
"base_url": base_url,
}
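Each select_llm_bundle call returns a plain dict that the CLI stores in llm_routing; a typical result, assuming the user picks the OpenAI provider and one of the deep-thinking options listed above, would be:

bundle = select_llm_bundle("Trader", depth_hint="deep")
# e.g. {"provider": "openai", "model": "gpt-5.4", "base_url": "https://api.openai.com/v1"}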
@@ -7,28 +7,53 @@ DEFAULT_CONFIG = {
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "gpt-5.2",
"quick_think_llm": "gpt-5-mini",
"backend_url": "https://api.openai.com/v1",

# LLM routing settings
# Default: all agents use the same provider/model unless explicitly overridden
"llm_routing": {
"default": {
"provider": "openai",
"model": "gpt-5-mini",
"base_url": "https://api.openai.com/v1",
},
"roles": {
# Optional per-role overrides.
# Leave as None to inherit from "default".
"market": None,
"social": None,
"news": None,
"fundamentals": None,
"bull_researcher": None,
"bear_researcher": None,
"research_manager": None,
"trader": None,
"aggressive_analyst": None,
"neutral_analyst": None,
"conservative_analyst": None,
"portfolio_manager": None,
},
},

# Provider-specific thinking configuration
"google_thinking_level": None, # "high", "minimal", etc.
"openai_reasoning_effort": None, # "medium", "high", "low"
# These apply whenever that provider is used.
"google_thinking_level": None, # e.g. "high", "minimal"
"openai_reasoning_effort": None, # e.g. "medium", "high", "low"

# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,
"max_recur_limit": 100,

# Data vendor configuration
# Category-level configuration (default for all tools in category)
"data_vendors": {
"core_stock_apis": "yfinance", # Options: alpha_vantage, yfinance
"technical_indicators": "yfinance", # Options: alpha_vantage, yfinance
"fundamental_data": "yfinance", # Options: alpha_vantage, yfinance
"news_data": "yfinance", # Options: alpha_vantage, yfinance
"core_stock_apis": "yfinance",
"technical_indicators": "yfinance",
"fundamental_data": "yfinance",
"news_data": "yfinance",
},

# Tool-level configuration (takes precedence over category-level)
"tool_vendors": {
# Example: "get_stock_data": "alpha_vantage", # Override category default
# Example: "get_stock_data": "alpha_vantage",
},
}
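With this shape, pinning a single stage to a different model outside the CLI is a one-line override; the model below is only an example, and copy / DEFAULT_CONFIG are assumed to be imported as in cli/main.py:

config = copy.deepcopy(DEFAULT_CONFIG)
config["llm_routing"]["roles"]["research_manager"] = {
    "provider": "openai",
    "model": "gpt-5.4",  # example; any model offered for that provider
    "base_url": "https://api.openai.com/v1",
}
# every other role still inherits llm_routing["default"]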
@@ -1,7 +1,4 @@
# TradingAgents/graph/setup.py

from typing import Dict, Any
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode

@@ -16,8 +13,7 @@ class GraphSetup:

def __init__(
self,
quick_thinking_llm: ChatOpenAI,
deep_thinking_llm: ChatOpenAI,
role_llms: Dict[str, Any],
tool_nodes: Dict[str, ToolNode],
bull_memory,
bear_memory,

@@ -27,8 +23,7 @@ class GraphSetup:
conditional_logic: ConditionalLogic,
):
"""Initialize with required components."""
self.quick_thinking_llm = quick_thinking_llm
self.deep_thinking_llm = deep_thinking_llm
self.role_llms = role_llms
self.tool_nodes = tool_nodes
self.bull_memory = bull_memory
self.bear_memory = bear_memory

@@ -37,78 +32,78 @@ class GraphSetup:
self.risk_manager_memory = risk_manager_memory
self.conditional_logic = conditional_logic

def _get_llm(self, role: str):
if role not in self.role_llms:
raise ValueError(f"Missing LLM assignment for role: {role}")
return self.role_llms[role]

def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"]
):
"""Set up and compile the agent workflow graph.

Args:
selected_analysts (list): List of analyst types to include. Options are:
- "market": Market analyst
- "social": Social media analyst
- "news": News analyst
- "fundamentals": Fundamentals analyst
"""
"""Set up and compile the agent workflow graph."""
if len(selected_analysts) == 0:
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")

# Create analyst nodes
analyst_nodes = {}
delete_nodes = {}
tool_nodes = {}

if "market" in selected_analysts:
analyst_nodes["market"] = create_market_analyst(
self.quick_thinking_llm
self._get_llm("market")
)
delete_nodes["market"] = create_msg_delete()
tool_nodes["market"] = self.tool_nodes["market"]

if "social" in selected_analysts:
analyst_nodes["social"] = create_social_media_analyst(
self.quick_thinking_llm
self._get_llm("social")
)
delete_nodes["social"] = create_msg_delete()
tool_nodes["social"] = self.tool_nodes["social"]

if "news" in selected_analysts:
analyst_nodes["news"] = create_news_analyst(
self.quick_thinking_llm
self._get_llm("news")
)
delete_nodes["news"] = create_msg_delete()
tool_nodes["news"] = self.tool_nodes["news"]

if "fundamentals" in selected_analysts:
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
self.quick_thinking_llm
self._get_llm("fundamentals")
)
delete_nodes["fundamentals"] = create_msg_delete()
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]

# Create researcher and manager nodes
bull_researcher_node = create_bull_researcher(
self.quick_thinking_llm, self.bull_memory
self._get_llm("bull_researcher"), self.bull_memory
)
bear_researcher_node = create_bear_researcher(
self.quick_thinking_llm, self.bear_memory
self._get_llm("bear_researcher"), self.bear_memory
)
research_manager_node = create_research_manager(
self.deep_thinking_llm, self.invest_judge_memory
self._get_llm("research_manager"), self.invest_judge_memory
)
trader_node = create_trader(
self._get_llm("trader"), self.trader_memory
)
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)

# Create risk analysis nodes
aggressive_analyst = create_aggressive_debator(self.quick_thinking_llm)
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
conservative_analyst = create_conservative_debator(self.quick_thinking_llm)
aggressive_analyst = create_aggressive_debator(
self._get_llm("aggressive_analyst")
)
neutral_analyst = create_neutral_debator(
self._get_llm("neutral_analyst")
)
conservative_analyst = create_conservative_debator(
self._get_llm("conservative_analyst")
)
risk_manager_node = create_risk_manager(
self.deep_thinking_llm, self.risk_manager_memory
self._get_llm("portfolio_manager"), self.risk_manager_memory
)

# Create workflow
workflow = StateGraph(AgentState)

# Add analyst nodes to the graph
for analyst_type, node in analyst_nodes.items():
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
workflow.add_node(

@@ -116,7 +111,6 @@ class GraphSetup:
)
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])

# Add other nodes
workflow.add_node("Bull Researcher", bull_researcher_node)
workflow.add_node("Bear Researcher", bear_researcher_node)
workflow.add_node("Research Manager", research_manager_node)

@@ -126,18 +120,14 @@ class GraphSetup:
workflow.add_node("Conservative Analyst", conservative_analyst)
workflow.add_node("Risk Judge", risk_manager_node)

# Define edges
# Start with the first analyst
first_analyst = selected_analysts[0]
workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst")

# Connect analysts in sequence
for i, analyst_type in enumerate(selected_analysts):
current_analyst = f"{analyst_type.capitalize()} Analyst"
current_tools = f"tools_{analyst_type}"
current_clear = f"Msg Clear {analyst_type.capitalize()}"

# Add conditional edges for current analyst
workflow.add_conditional_edges(
current_analyst,
getattr(self.conditional_logic, f"should_continue_{analyst_type}"),

@@ -145,14 +135,12 @@ class GraphSetup:
)
workflow.add_edge(current_tools, current_analyst)

# Connect to next analyst or to Bull Researcher if this is the last analyst
if i < len(selected_analysts) - 1:
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
workflow.add_edge(current_clear, next_analyst)
else:
workflow.add_edge(current_clear, "Bull Researcher")

# Add remaining edges
workflow.add_conditional_edges(
"Bull Researcher",
self.conditional_logic.should_continue_debate,

@@ -198,5 +186,4 @@ class GraphSetup:

workflow.add_edge("Risk Judge", END)

# Compile and return
return workflow.compile()
return workflow.compile()
@@ -71,28 +71,51 @@ class TradingAgentsGraph:
exist_ok=True,
)

# Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs()
# Initialize LLMs from llm_routing.
llm_routing = self.config.get("llm_routing", {})
default_llm_cfg = llm_routing.get("default")
role_overrides = llm_routing.get("roles", {})

# Add callbacks to kwargs if provided (passed to LLM constructor)
if self.callbacks:
llm_kwargs["callbacks"] = self.callbacks
if default_llm_cfg is None:
raise ValueError("config['llm_routing']['default'] must be set")

deep_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["deep_think_llm"],
base_url=self.config.get("backend_url"),
**llm_kwargs,
)
quick_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["quick_think_llm"],
base_url=self.config.get("backend_url"),
**llm_kwargs,
)
all_roles = [
"market",
"social",
"news",
"fundamentals",
"bull_researcher",
"bear_researcher",
"research_manager",
"trader",
"aggressive_analyst",
"neutral_analyst",
"conservative_analyst",
"portfolio_manager",
]

self.deep_thinking_llm = deep_client.get_llm()
self.quick_thinking_llm = quick_client.get_llm()
self.role_llms = {}

for role in all_roles:
llm_cfg = role_overrides.get(role) or default_llm_cfg

provider = llm_cfg["provider"]
model = llm_cfg["model"]
base_url = llm_cfg.get("base_url")

llm_kwargs = self._get_provider_kwargs(provider)

if self.callbacks:
llm_kwargs["callbacks"] = self.callbacks

client = create_llm_client(
provider=provider,
model=model,
base_url=base_url,
**llm_kwargs,
)

self.role_llms[role] = client.get_llm()

# Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)

@@ -110,8 +133,7 @@ class TradingAgentsGraph:
max_risk_discuss_rounds=self.config["max_risk_discuss_rounds"],
)
self.graph_setup = GraphSetup(
self.quick_thinking_llm,
self.deep_thinking_llm,
self.role_llms,
self.tool_nodes,
self.bull_memory,
self.bear_memory,

@@ -122,8 +144,8 @@ class TradingAgentsGraph:
)

self.propagator = Propagator()
self.reflector = Reflector(self.quick_thinking_llm)
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
self.reflector = Reflector(self.role_llms["trader"])
self.signal_processor = SignalProcessor(self.role_llms["portfolio_manager"])

# State tracking
self.curr_state = None

@@ -132,11 +154,10 @@ class TradingAgentsGraph:

# Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts)

def _get_provider_kwargs(self) -> Dict[str, Any]:
def _get_provider_kwargs(self, provider: str) -> Dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation."""
kwargs = {}
provider = self.config.get("llm_provider", "").lower()
provider = (provider or "").lower()

if provider == "google":
thinking_level = self.config.get("google_thinking_level")

@@ -150,6 +171,7 @@ class TradingAgentsGraph:

return kwargs

def _create_tool_nodes(self) -> Dict[str, ToolNode]:
"""Create tool nodes for different data sources using abstract methods."""
return {