TradingAgents/cli/display.py

417 lines
14 KiB
Python

from typing import Any
from rich import box
from rich.columns import Columns
from rich.console import Console
from rich.layout import Layout
from rich.markdown import Markdown
from rich.panel import Panel
from rich.spinner import Spinner
from rich.table import Table
from rich.text import Text
from cli.models import AgentStatus
from cli.state import message_buffer
# Module-level rich Console; display_complete_report prints the final report through it.
console = Console()
def create_layout() -> Layout:
    """Build the live-dashboard layout.

    Structure (names are the keys later used by ``update_display``)::

        header (3 rows)
        main
          upper (ratio 3)
            progress (ratio 2) | messages (ratio 3)
          analysis (ratio 5)
        footer (3 rows)
    """
    root = Layout()
    main_area = Layout(name="main")
    root.split_column(
        Layout(name="header", size=3),
        main_area,
        Layout(name="footer", size=3),
    )

    top_half = Layout(name="upper", ratio=3)
    main_area.split_column(top_half, Layout(name="analysis", ratio=5))
    top_half.split_row(
        Layout(name="progress", ratio=2),
        Layout(name="messages", ratio=3),
    )
    return root
# Teams and agents shown in the "Progress" panel, in display order. Agent names
# are the keys looked up in ``message_buffer.agent_status``.
_PROGRESS_TEAMS: dict[str, list[str]] = {
    "Analyst Team": [
        "Market Analyst",
        "Social Analyst",
        "News Analyst",
        "Fundamentals Analyst",
    ],
    "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
    "Trading Team": ["Trader"],
    "Risk Management": ["Risky Analyst", "Neutral Analyst", "Safe Analyst"],
    "Portfolio Management": ["Portfolio Manager"],
}

# Maximum number of entries rendered in the "Messages & Tools" panel.
_MAX_VISIBLE_MESSAGES = 12


def _truncate_text(text: str, limit: int) -> str:
    """Return *text* cut to at most *limit* characters, ending in "..." when cut."""
    return text if len(text) <= limit else text[: limit - 3] + "..."


def _status_cell(status: AgentStatus) -> Spinner | str:
    """Render one agent status: a spinner while in progress, colored markup otherwise."""
    if status == AgentStatus.IN_PROGRESS:
        return Spinner("dots", text="[blue]in_progress[/blue]", style="bold cyan")
    status_color = {
        AgentStatus.PENDING: "yellow",
        AgentStatus.COMPLETED: "green",
        AgentStatus.ERROR: "red",
    }.get(status, "white")
    return f"[{status_color}]{status.value}[/{status_color}]"


def _build_header_panel() -> Panel:
    """Build the static welcome banner."""
    return Panel(
        "[bold green]Welcome to TradingAgents CLI[/bold green]\n"
        "[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]",
        title="Welcome to TradingAgents",
        border_style="green",
        padding=(1, 2),
        expand=True,
    )


def _build_progress_panel() -> Panel:
    """Build the "Progress" panel listing each agent's status grouped by team."""
    progress_table = Table(
        show_header=True,
        header_style="bold magenta",
        show_footer=False,
        box=box.SIMPLE_HEAD,
        title=None,
        padding=(0, 2),
        expand=True,
    )
    progress_table.add_column("Team", style="cyan", justify="center", width=20)
    progress_table.add_column("Agent", style="green", justify="center", width=20)
    progress_table.add_column("Status", style="yellow", justify="center", width=20)

    for team, agents in _PROGRESS_TEAMS.items():
        for index, agent in enumerate(agents):
            # Only the first row of a team carries the team name.
            team_cell = team if index == 0 else ""
            # NOTE(review): assumes message_buffer.agent_status has an entry for
            # every agent listed in _PROGRESS_TEAMS — confirm against cli.state.
            status = message_buffer.agent_status[agent]
            progress_table.add_row(team_cell, agent, _status_cell(status))
        # Dim separator row after each team (including the last one).
        progress_table.add_row("-" * 20, "-" * 20, "-" * 20, style="dim")

    return Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2))


def _collect_messages() -> list[tuple[Any, str, str]]:
    """Merge buffered tool calls and messages into one list sorted by timestamp.

    Each entry is ``(timestamp, type_label, text)``. Tool-call text is capped at
    100 characters and message text at 200 characters.
    """
    entries: list[tuple[Any, str, str]] = []
    for timestamp, tool_name, args in message_buffer.tool_calls:
        # Stringify before truncating so non-string args (e.g. dicts) are capped too.
        args_text = args if isinstance(args, str) else str(args)
        entries.append((timestamp, "Tool", f"{tool_name}: {_truncate_text(args_text, 100)}"))
    for timestamp, msg_type, content in message_buffer.messages:
        entries.append((timestamp, msg_type, _truncate_text(extract_content_string(content), 200)))
    # list.sort is stable: entries with equal timestamps keep tool-calls-first order.
    # Presumably timestamps are comparable "HH:MM:SS" strings — confirm in cli.state.
    entries.sort(key=lambda entry: entry[0])
    return entries


def _build_messages_panel(spinner_text: str | None) -> Panel:
    """Build the "Messages & Tools" panel showing the most recent entries."""
    messages_table = Table(
        show_header=True,
        header_style="bold magenta",
        show_footer=False,
        expand=True,
        box=box.MINIMAL,
        show_lines=True,
        padding=(0, 1),
    )
    messages_table.add_column("Time", style="cyan", width=8, justify="center")
    messages_table.add_column("Type", style="green", width=10, justify="center")
    messages_table.add_column("Content", style="white", no_wrap=False, ratio=1)

    all_messages = _collect_messages()
    for timestamp, msg_type, content in all_messages[-_MAX_VISIBLE_MESSAGES:]:
        # Wrap in Text (not a plain str) so brackets in content are not parsed as markup.
        messages_table.add_row(timestamp, msg_type, Text(content, overflow="fold"))
    if spinner_text:
        messages_table.add_row("", "Spinner", spinner_text)
    if len(all_messages) > _MAX_VISIBLE_MESSAGES:
        # rich.Table has no table-level ``footer`` attribute (footers are per
        # column and hidden by show_footer=False), so assigning one was silently
        # ignored. A caption is rendered beneath the table.
        messages_table.caption = (
            f"[dim]Showing last {_MAX_VISIBLE_MESSAGES} of {len(all_messages)} messages[/dim]"
        )

    return Panel(
        messages_table,
        title="Messages & Tools",
        border_style="blue",
        padding=(1, 2),
    )


def _build_analysis_panel() -> Panel:
    """Build the "Current Report" panel from message_buffer.current_report."""
    report = message_buffer.current_report
    body = Markdown(report) if report else "[italic]Waiting for analysis report...[/italic]"
    return Panel(body, title="Current Report", border_style="green", padding=(1, 2))


def _build_footer_panel() -> Panel:
    """Build the footer with tool-call, LLM-call and generated-report counters."""
    tool_calls_count = len(message_buffer.tool_calls)
    # "Reasoning" messages are counted as LLM calls.
    llm_calls_count = sum(
        1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning"
    )
    reports_count = sum(
        1 for content in message_buffer.report_sections.values() if content is not None
    )
    stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
    stats_table.add_column("Stats", justify="center")
    stats_table.add_row(
        f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
    )
    return Panel(stats_table, border_style="grey50")


def update_display(layout: Layout, spinner_text: str | None = None) -> None:
    """Refresh every region of a dashboard layout from ``message_buffer``.

    Args:
        layout: Layout built by ``create_layout``; its "header", "progress",
            "messages", "analysis" and "footer" regions are replaced.
        spinner_text: Optional text appended as a final "Spinner" row in the
            messages panel; passed to the table as-is. Falsy values add no row.
    """
    layout["header"].update(_build_header_panel())
    layout["progress"].update(_build_progress_panel())
    layout["messages"].update(_build_messages_panel(spinner_text))
    layout["analysis"].update(_build_analysis_panel())
    layout["footer"].update(_build_footer_panel())
def display_complete_report(final_state: dict[str, Any]) -> None:
    """Print the full multi-team analysis report contained in ``final_state``.

    Sections are printed in order:
      I.   Analyst Team Reports (market / sentiment / news / fundamentals reports)
      II.  Research Team Decision (from ``investment_debate_state``)
      III. Trading Team Plan (``trader_investment_plan``)
      IV.  Risk Management Team Decision (from ``risk_debate_state``)
      V.   Portfolio Manager Decision (``risk_debate_state["judge_decision"]``)

    Any section, or any individual report within it, whose source value is
    missing or falsy is skipped.
    """

    def agent_panel(markdown_text: str, agent_title: str) -> Panel:
        # Every individual agent report shares this blue-bordered look.
        return Panel(
            Markdown(markdown_text),
            title=agent_title,
            border_style="blue",
            padding=(1, 2),
        )

    def agent_panels(source: dict[str, Any], spec: list[tuple[str, str]]) -> list[Panel]:
        # One panel per (key, title) pair whose value in ``source`` is truthy.
        return [agent_panel(source[key], title) for key, title in spec if source.get(key)]

    def print_section(inner: Any, section_title: str, color: str) -> None:
        console.print(
            Panel(inner, title=section_title, border_style=color, padding=(1, 2))
        )

    console.print("\n[bold green]Complete Analysis Report[/bold green]\n")

    analyst_panels = agent_panels(
        final_state,
        [
            ("market_report", "Market Analyst"),
            ("sentiment_report", "Social Analyst"),
            ("news_report", "News Analyst"),
            ("fundamentals_report", "Fundamentals Analyst"),
        ],
    )
    if analyst_panels:
        print_section(
            Columns(analyst_panels, equal=True, expand=True),
            "I. Analyst Team Reports",
            "cyan",
        )

    if final_state.get("investment_debate_state"):
        debate_state = final_state["investment_debate_state"]
        research_panels = agent_panels(
            debate_state,
            [
                ("bull_history", "Bull Researcher"),
                ("bear_history", "Bear Researcher"),
                ("judge_decision", "Research Manager"),
            ],
        )
        if research_panels:
            print_section(
                Columns(research_panels, equal=True, expand=True),
                "II. Research Team Decision",
                "magenta",
            )

    if final_state.get("trader_investment_plan"):
        print_section(
            agent_panel(final_state["trader_investment_plan"], "Trader"),
            "III. Trading Team Plan",
            "yellow",
        )

    if final_state.get("risk_debate_state"):
        risk_state = final_state["risk_debate_state"]
        risk_panels = agent_panels(
            risk_state,
            [
                ("risky_history", "Aggressive Analyst"),
                ("safe_history", "Conservative Analyst"),
                ("neutral_history", "Neutral Analyst"),
            ],
        )
        if risk_panels:
            print_section(
                Columns(risk_panels, equal=True, expand=True),
                "IV. Risk Management Team Decision",
                "red",
            )
        # The portfolio manager's verdict is stored as the risk debate's judge decision.
        if risk_state.get("judge_decision"):
            print_section(
                agent_panel(risk_state["judge_decision"], "Portfolio Manager"),
                "V. Portfolio Manager Decision",
                "green",
            )
def update_research_team_status(status: AgentStatus) -> None:
    """Apply ``status`` to the Bull/Bear researchers, the Research Manager and the Trader.

    NOTE(review): "Trader" is included despite the function name — presumably so
    the trader row advances together with the research phase; confirm intent.
    """
    for agent_name in ("Bull Researcher", "Bear Researcher", "Research Manager", "Trader"):
        message_buffer.update_agent_status(agent_name, status)
def extract_content_string(content: Any) -> str:
    """Flatten message content into a single display string.

    - ``str`` content is returned unchanged.
    - ``list`` content is joined with single spaces, where dict parts with
      ``type == "text"`` contribute their ``"text"`` value (default ``""``),
      dict parts with ``type == "tool_use"`` contribute ``"[Tool: <name>]"``
      (name defaults to ``"unknown"``), other dict parts contribute nothing, and
      non-dict parts contribute ``str(part)``.
    - Any other value is converted with ``str()``.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    pieces: list[str] = []
    for part in content:
        if not isinstance(part, dict):
            pieces.append(str(part))
            continue
        part_type = part.get("type")
        if part_type == "text":
            pieces.append(part.get("text", ""))
        elif part_type == "tool_use":
            pieces.append(f"[Tool: {part.get('name', 'unknown')}]")
        # Dict parts of any other type add nothing to the output.
    return " ".join(pieces)
def create_question_box(title: str, prompt: str, default: str | None = None) -> Panel:
    """Build a blue prompt panel.

    The panel shows ``title`` in bold and ``prompt`` dimmed. When ``default`` is
    truthy, it also shows a dimmed "Default: ..." line. The arguments are
    interpolated into rich markup without escaping.
    """
    lines = [f"[bold]{title}[/bold]", f"[dim]{prompt}[/dim]"]
    if default:
        lines.append(f"[dim]Default: {default}[/dim]")
    return Panel("\n".join(lines), border_style="blue", padding=(1, 2))