# TradingAgents/cli/main.py
#
# 541 lines
# 23 KiB
# Python

from typing import Optional
import datetime
import typer
from pathlib import Path
from functools import wraps
from rich.console import Console
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
from rich.panel import Panel
from rich.spinner import Spinner
from rich.live import Live
from rich.columns import Columns
from rich.markdown import Markdown
from rich.layout import Layout
from rich.text import Text
from rich.live import Live
from rich.table import Table
from collections import deque
import time
from rich.tree import Tree
from rich import box
from rich.align import Align
from rich.rule import Rule
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from cli.models import AnalystType
from cli.utils import *
from cli.message_buffer import MessageBuffer
from cli.ui_display import create_layout, update_display
from cli.report_display import display_complete_report
from cli.helper_functions import update_research_team_status, extract_content_string
from cli.asset_detection import detect_asset_class, get_asset_class_display_name
# Shared Rich console used for every print in this module.
console = Console()

# Typer application object; CLI commands are registered on it below.
app = typer.Typer(
    help="TradingAgents CLI: Multi-Agents LLM Financial Trading Framework",
    name="TradingAgents",
    add_completion=True,  # Enable shell completion
)

# Module-wide message buffer instance shared by the analysis run and the display.
message_buffer = MessageBuffer()
def get_user_selections():
    """Collect every choice needed before the live analysis display starts.

    Values already present in ``DEFAULT_CONFIG`` are used as-is (and echoed to
    the console); anything missing is asked interactively.

    Returns:
        dict: A mapping with the keys ``ticker``, ``analysis_date``,
        ``asset_class``, ``analysts``, ``research_depth``, ``llm_provider``
        (lower-cased), ``backend_url``, ``shallow_thinker`` and
        ``deep_thinker``.
    """
    # Locate the banner next to this module so the CLI also works when it is
    # started from a directory other than the repository root; keep the
    # original CWD-relative location as a fallback.
    welcome_path = Path(__file__).resolve().parent / "static" / "welcome.txt"
    if not welcome_path.is_file():
        welcome_path = Path("./cli/static/welcome.txt")
    try:
        welcome_ascii = welcome_path.read_text(encoding="utf-8")
    except OSError:
        # The banner is purely decorative; a missing file must not block a run.
        welcome_ascii = ""

    # Create welcome box content
    welcome_content = f"{welcome_ascii}\n"
    welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n"
    welcome_content += "[bold]Workflow Steps:[/bold]\n"
    welcome_content += "I. Analyst Team → II. Research Team → III. Trader → IV. Risk Management → V. Portfolio Management\n\n"
    welcome_content += (
        "[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]"
    )

    # Create and center the welcome box
    welcome_box = Panel(
        welcome_content,
        border_style="green",
        padding=(1, 2),
        title="Welcome to TradingAgents",
        subtitle="Multi-Agents LLM Financial Trading Framework",
    )
    console.print(Align.center(welcome_box))
    console.print()  # Add a blank line after the welcome box

    def create_question_box(title, prompt, default=None):
        """Build the blue panel shown above each interactive question."""
        box_content = f"[bold]{title}[/bold]\n"
        box_content += f"[dim]{prompt}[/dim]"
        if default:
            box_content += f"\n[dim]Default: {default}[/dim]"
        return Panel(box_content, border_style="blue", padding=(1, 2))

    # Step 1: Ticker symbol
    console.print(
        create_question_box(
            "Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY"
        )
    )
    selected_ticker = get_ticker()

    # Auto-detect asset class from ticker
    asset_class = detect_asset_class(selected_ticker)
    console.print(
        f"[dim]→ Detected asset class: [bold]{get_asset_class_display_name(asset_class)}[/bold][/dim]\n"
    )

    # Step 2: Analysis date
    default_date = datetime.datetime.now().strftime("%Y-%m-%d")
    console.print(
        create_question_box(
            "Step 2: Analysis Date",
            "Enter the analysis date (YYYY-MM-DD)",
            default_date,
        )
    )
    analysis_date = get_analysis_date()

    # Step 3: Select analysts (use config default if available)
    selected_analysts = []
    configured_analysts = DEFAULT_CONFIG.get("default_analysts")
    if configured_analysts:
        # Convert analyst names to AnalystType; unknown names are ignored.
        analyst_map = {
            "market": AnalystType.MARKET,
            "news": AnalystType.NEWS,
            "social": AnalystType.SOCIAL,
            "fundamentals": AnalystType.FUNDAMENTALS,
        }
        selected_analysts = [
            analyst_map[name] for name in configured_analysts if name in analyst_map
        ]
        # Filter out fundamentals for commodities
        if asset_class == "commodity":
            selected_analysts = [
                analyst
                for analyst in selected_analysts
                if analyst != AnalystType.FUNDAMENTALS
            ]

    if selected_analysts:
        console.print(
            f"[dim]→ Using configured analysts: [bold]{', '.join(a.value for a in selected_analysts)}[/bold][/dim]\n"
        )
    else:
        # Either nothing is configured or filtering left no analyst at all;
        # ask interactively so the run never starts with an empty team.
        console.print(
            create_question_box(
                "Step 3: Analysts Team", "Select your LLM analyst agents for the analysis"
            )
        )
        selected_analysts = select_analysts(asset_class)
        console.print(
            f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
        )

    # Step 4: Research depth (use config default if available)
    if "max_debate_rounds" in DEFAULT_CONFIG:
        selected_research_depth = DEFAULT_CONFIG["max_debate_rounds"]
        depth_labels = {1: "Shallow", 2: "Medium", 3: "Deep"}
        console.print(
            f"[dim]→ Using configured research depth: [bold]{depth_labels.get(selected_research_depth, selected_research_depth)}[/bold][/dim]\n"
        )
    else:
        console.print(
            create_question_box(
                "Step 4: Research Depth", "Select your research depth level"
            )
        )
        selected_research_depth = select_research_depth()

    # Step 5: LLM backend (use config default if available)
    if "llm_provider" in DEFAULT_CONFIG and "backend_url" in DEFAULT_CONFIG:
        selected_llm_provider = DEFAULT_CONFIG["llm_provider"]
        backend_url = DEFAULT_CONFIG["backend_url"]
        console.print(
            f"[dim]→ Using configured LLM provider: [bold]{selected_llm_provider.upper()}[/bold] ({backend_url})[/dim]\n"
        )
    else:
        console.print(
            create_question_box(
                "Step 5: LLM Backend", "Select which service to talk to"
            )
        )
        selected_llm_provider, backend_url = select_llm_provider()

    # Step 6: Thinking agents (use config defaults if available)
    if "quick_think_llm" in DEFAULT_CONFIG and "deep_think_llm" in DEFAULT_CONFIG:
        selected_shallow_thinker = DEFAULT_CONFIG["quick_think_llm"]
        selected_deep_thinker = DEFAULT_CONFIG["deep_think_llm"]
        console.print(
            f"[dim]→ Using configured models:[/dim]\n"
            f"[dim] Quick thinking: [bold]{selected_shallow_thinker}[/bold][/dim]\n"
            f"[dim] Deep thinking: [bold]{selected_deep_thinker}[/bold][/dim]\n"
        )
    else:
        console.print(
            create_question_box(
                "Step 6: Thinking Agents", "Select your thinking agents for analysis"
            )
        )
        selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
        selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)

    return {
        "ticker": selected_ticker,
        "analysis_date": analysis_date,
        "asset_class": asset_class,
        "analysts": selected_analysts,
        "research_depth": selected_research_depth,
        "llm_provider": selected_llm_provider.lower(),
        "backend_url": backend_url,
        "shallow_thinker": selected_shallow_thinker,
        "deep_thinker": selected_deep_thinker,
    }
def get_ticker():
    """Prompt for the ticker symbol to analyze.

    Returns:
        str: The entered symbol with surrounding whitespace removed, or
        ``"SPY"`` when the default is accepted or only whitespace is entered.
    """
    default_ticker = "SPY"
    # Trim defensively: the value is later used for directory names and
    # asset-class detection, where stray spaces would be harmful.
    ticker = typer.prompt("", default=default_ticker).strip()
    return ticker or default_ticker
def get_analysis_date():
    """Prompt until the user enters a valid analysis date that is not in the future.

    Returns:
        str: The accepted date, normalized to zero-padded ``YYYY-MM-DD``.
    """
    while True:
        today = datetime.datetime.now().date()
        date_str = typer.prompt("", default=today.strftime("%Y-%m-%d"))
        try:
            # Validate the date format
            analysis_date = datetime.datetime.strptime(
                date_str.strip(), "%Y-%m-%d"
            ).date()
        except ValueError:
            console.print(
                "[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]"
            )
            continue

        # Ensure the date is not in the future
        if analysis_date > today:
            console.print("[red]Error: Analysis date cannot be in the future[/red]")
            continue

        # strptime also accepts non-padded input such as "2024-1-5"; return the
        # canonical form so paths and downstream consumers always see the same
        # YYYY-MM-DD spelling.
        return analysis_date.isoformat()
def run_analysis():
    """Run the interactive analysis workflow and render it in a live display.

    Collects the user's selections, builds the ``TradingAgentsGraph`` from
    them, streams the graph while mirroring messages, tool calls and report
    sections into the module-level ``message_buffer``, and finally shows the
    complete report.

    Side effects:
        * Creates ``<results_dir>/<ticker>/<analysis_date>/`` containing a
          ``reports`` directory and a ``message_tool.log`` file.
        * Replaces ``add_message``, ``add_tool_call`` and
          ``update_report_section`` on ``message_buffer`` with wrappers that
          additionally persist each entry to disk.

    Raises:
        typer.Exit: With exit code 1 when no analyst is selected or when the
            graph stream yields no state at all.
    """
    # First get all user selections
    selections = get_user_selections()

    # Analyst type values in the order they were selected; the first entry is
    # shown as in progress and each finished analyst hands over to the next.
    analyst_order = [analyst.value for analyst in selections["analysts"]]
    if not analyst_order:
        console.print("[red]Error: No analysts selected; nothing to analyze.[/red]")
        raise typer.Exit(code=1)

    # Create config with selected research depth
    config = DEFAULT_CONFIG.copy()
    config["max_debate_rounds"] = selections["research_depth"]
    config["max_risk_discuss_rounds"] = selections["research_depth"]
    config["quick_think_llm"] = selections["shallow_thinker"]
    config["deep_think_llm"] = selections["deep_thinker"]
    config["backend_url"] = selections["backend_url"]
    config["llm_provider"] = selections["llm_provider"].lower()
    config["asset_class"] = selections["asset_class"]

    # Initialize the graph
    graph = TradingAgentsGraph(analyst_order, config=config, debug=True)

    # Create result directory
    results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
    results_dir.mkdir(parents=True, exist_ok=True)
    report_dir = results_dir / "reports"
    report_dir.mkdir(parents=True, exist_ok=True)
    log_file = results_dir / "message_tool.log"
    log_file.touch(exist_ok=True)

    def save_message_decorator(obj, func_name):
        """Wrap ``obj.<func_name>`` so each new message is appended to the log file."""
        func = getattr(obj, func_name)

        @wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            timestamp, message_type, content = obj.messages[-1]
            content = str(content).replace("\n", " ")  # Keep one log entry per line
            # Explicit UTF-8: model output regularly contains characters the
            # platform default encoding cannot represent.
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(f"{timestamp} [{message_type}] {content}\n")
            return result

        return wrapper

    def save_tool_call_decorator(obj, func_name):
        """Wrap ``obj.<func_name>`` so each new tool call is appended to the log file."""
        func = getattr(obj, func_name)

        @wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            timestamp, tool_name, tool_args = obj.tool_calls[-1]
            if isinstance(tool_args, dict):
                args_str = ", ".join(f"{k}={v}" for k, v in tool_args.items())
            else:
                # Arguments that are not a mapping are logged verbatim.
                args_str = str(tool_args)
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
            return result

        return wrapper

    def save_report_section_decorator(obj, func_name):
        """Wrap ``obj.<func_name>`` so each non-empty section is written to ``<section>.md``."""
        func = getattr(obj, func_name)

        @wraps(func)
        def wrapper(section_name, content):
            result = func(section_name, content)
            if section_name in obj.report_sections:
                # Persist what the buffer actually stored, not the raw argument.
                stored = obj.report_sections[section_name]
                if stored:
                    with open(report_dir / f"{section_name}.md", "w", encoding="utf-8") as f:
                        f.write(stored)
            return result

        return wrapper

    message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
    message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
    message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")

    def analyst_label(analyst_type):
        """Return the agent-status name for an analyst type value, e.g. 'Market Analyst'."""
        return f"{analyst_type.capitalize()} Analyst"

    # State key that carries each analyst type's finished report.
    report_keys = {
        "market": "market_report",
        "social": "sentiment_report",
        "news": "news_report",
        "fundamentals": "fundamentals_report",
    }

    # Risk debate responses and the agent each one belongs to.
    risk_analysts = [
        ("current_risky_response", "Risky Analyst"),
        ("current_safe_response", "Safe Analyst"),
        ("current_neutral_response", "Neutral Analyst"),
    ]

    # Now start the display layout
    layout = create_layout()

    with Live(layout, refresh_per_second=4):
        # Initial display
        update_display(layout, message_buffer)

        # Add initial messages
        message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
        message_buffer.add_message(
            "System", f"Analysis date: {selections['analysis_date']}"
        )
        message_buffer.add_message(
            "System",
            f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
        )
        update_display(layout, message_buffer)

        # Reset agent statuses
        for agent in message_buffer.agent_status:
            message_buffer.update_agent_status(agent, "pending")

        # Reset report sections
        for section in message_buffer.report_sections:
            message_buffer.report_sections[section] = None
        message_buffer.current_report = None
        message_buffer.final_report = None

        # Update agent status to in_progress for the first analyst
        message_buffer.update_agent_status(analyst_label(analyst_order[0]), "in_progress")
        update_display(layout, message_buffer)

        # Create spinner text
        spinner_text = (
            f"Analyzing {selections['ticker']} on {selections['analysis_date']}..."
        )
        update_display(layout, message_buffer, spinner_text)

        # Initialize state and get graph args
        init_agent_state = graph.propagator.create_initial_state(
            selections["ticker"], selections["analysis_date"]
        )
        # CRITICAL: Add asset_class to state so market analyst can branch correctly
        init_agent_state["asset_class"] = selections["asset_class"]
        args = graph.propagator.get_graph_args()

        # Stream the analysis
        trace = []
        for chunk in graph.graph.stream(init_agent_state, **args):
            if len(chunk["messages"]) > 0:
                # Get the last message from the chunk
                last_message = chunk["messages"][-1]

                # Extract message content and type
                if hasattr(last_message, "content"):
                    content = extract_content_string(last_message.content)
                    msg_type = "Reasoning"
                else:
                    content = str(last_message)
                    msg_type = "System"

                # Add message to buffer
                message_buffer.add_message(msg_type, content)

                # If it's a tool call, add it to tool calls. ``or []`` also
                # covers a ``tool_calls`` attribute that is present but None.
                for tool_call in getattr(last_message, "tool_calls", None) or []:
                    # Handle both dictionary and object tool calls
                    if isinstance(tool_call, dict):
                        message_buffer.add_tool_call(
                            tool_call["name"], tool_call["args"]
                        )
                    else:
                        message_buffer.add_tool_call(tool_call.name, tool_call.args)

                # Analyst Team Reports. Walk the analysts in their selected
                # order so the hand-off goes to the next analyst that was
                # actually chosen, and the research team starts once the last
                # selected analyst has reported (whichever analyst that is).
                for position, analyst_type in enumerate(analyst_order):
                    report_key = report_keys.get(analyst_type)
                    if report_key is None:
                        continue
                    if report_key in chunk and chunk[report_key]:
                        message_buffer.update_report_section(report_key, chunk[report_key])
                        message_buffer.update_agent_status(
                            analyst_label(analyst_type), "completed"
                        )
                        if position + 1 < len(analyst_order):
                            message_buffer.update_agent_status(
                                analyst_label(analyst_order[position + 1]), "in_progress"
                            )
                        else:
                            update_research_team_status(message_buffer, "in_progress")

                # Research Team - Handle Investment Debate State
                if (
                    "investment_debate_state" in chunk
                    and chunk["investment_debate_state"]
                ):
                    debate_state = chunk["investment_debate_state"]

                    # Update Bull Researcher status and report
                    if "bull_history" in debate_state and debate_state["bull_history"]:
                        # Keep all research team members in progress
                        update_research_team_status(message_buffer, "in_progress")
                        # Extract latest bull response (last line of the history)
                        latest_bull = debate_state["bull_history"].split("\n")[-1]
                        if latest_bull:
                            message_buffer.add_message("Reasoning", latest_bull)
                            # Update research report with bull's latest analysis
                            message_buffer.update_report_section(
                                "investment_plan",
                                f"### Bull Researcher Analysis\n{latest_bull}",
                            )

                    # Update Bear Researcher status and report
                    if "bear_history" in debate_state and debate_state["bear_history"]:
                        # Keep all research team members in progress
                        update_research_team_status(message_buffer, "in_progress")
                        # Extract latest bear response (last line of the history)
                        latest_bear = debate_state["bear_history"].split("\n")[-1]
                        if latest_bear:
                            message_buffer.add_message("Reasoning", latest_bear)
                            # Update research report with bear's latest analysis
                            message_buffer.update_report_section(
                                "investment_plan",
                                f"{message_buffer.report_sections['investment_plan']}\n\n### Bear Researcher Analysis\n{latest_bear}",
                            )

                    # Update Research Manager status and final decision
                    if (
                        "judge_decision" in debate_state
                        and debate_state["judge_decision"]
                    ):
                        # Keep all research team members in progress until final decision
                        update_research_team_status(message_buffer, "in_progress")
                        message_buffer.add_message(
                            "Reasoning",
                            f"Research Manager: {debate_state['judge_decision']}",
                        )
                        # Update research report with final decision
                        message_buffer.update_report_section(
                            "investment_plan",
                            f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}",
                        )
                        # Mark all research team members as completed
                        update_research_team_status(message_buffer, "completed")
                        # Set first risk analyst to in_progress
                        message_buffer.update_agent_status(
                            "Risky Analyst", "in_progress"
                        )

                # Trading Team
                if (
                    "trader_investment_plan" in chunk
                    and chunk["trader_investment_plan"]
                ):
                    message_buffer.update_report_section(
                        "trader_investment_plan", chunk["trader_investment_plan"]
                    )
                    # Set first risk analyst to in_progress
                    message_buffer.update_agent_status("Risky Analyst", "in_progress")

                # Risk Management Team - Handle Risk Debate State
                if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
                    risk_state = chunk["risk_debate_state"]

                    # Handle all risk analysts with a mapping
                    for response_key, analyst_name in risk_analysts:
                        if response_key in risk_state and risk_state[response_key]:
                            message_buffer.update_agent_status(analyst_name, "in_progress")
                            message_buffer.add_message(
                                "Reasoning",
                                f"{analyst_name}: {risk_state[response_key]}",
                            )
                            message_buffer.update_report_section(
                                "final_trade_decision",
                                f"### {analyst_name} Analysis\n{risk_state[response_key]}",
                            )

                    # Update Portfolio Manager status and final decision
                    if "judge_decision" in risk_state and risk_state["judge_decision"]:
                        message_buffer.update_agent_status(
                            "Portfolio Manager", "in_progress"
                        )
                        message_buffer.add_message(
                            "Reasoning",
                            f"Portfolio Manager: {risk_state['judge_decision']}",
                        )
                        # Update risk report with final decision only
                        message_buffer.update_report_section(
                            "final_trade_decision",
                            f"### Portfolio Manager Decision\n{risk_state['judge_decision']}",
                        )
                        # Mark risk analysts as completed
                        message_buffer.update_agent_status("Risky Analyst", "completed")
                        message_buffer.update_agent_status("Safe Analyst", "completed")
                        message_buffer.update_agent_status(
                            "Neutral Analyst", "completed"
                        )
                        message_buffer.update_agent_status(
                            "Portfolio Manager", "completed"
                        )

                # Update the display
                update_display(layout, message_buffer)

            trace.append(chunk)

        # Without at least one streamed state there is no final state to report.
        if not trace:
            console.print("[red]Error: The analysis produced no output.[/red]")
            raise typer.Exit(code=1)

        # Get final state and decision
        final_state = trace[-1]
        decision = graph.process_signal(final_state["final_trade_decision"])

        # Update all agent statuses to completed
        for agent in message_buffer.agent_status:
            message_buffer.update_agent_status(agent, "completed")

        message_buffer.add_message(
            "Analysis", f"Completed analysis for {selections['analysis_date']}"
        )

        # Update final report sections
        for section in message_buffer.report_sections.keys():
            if section in final_state:
                message_buffer.update_report_section(section, final_state[section])

        # Display the complete final report
        display_complete_report(final_state)
        update_display(layout, message_buffer)
# CLI command registered on the Typer app; it simply delegates to run_analysis().
# Documented with a comment so that no docstring is added to the command function.
@app.command()
def analyze():
    run_analysis()
# Invoke the Typer application when this file is executed as a script.
if __name__ == "__main__":
    app()