417 lines
14 KiB
Python
417 lines
14 KiB
Python
from typing import Optional, Dict, Any
|
|
|
|
from rich.console import Console
|
|
from cli.models import AgentStatus
|
|
from rich.panel import Panel
|
|
from rich.spinner import Spinner
|
|
from rich.markdown import Markdown
|
|
from rich.layout import Layout
|
|
from rich.text import Text
|
|
from rich.table import Table
|
|
from rich.columns import Columns
|
|
from rich import box
|
|
|
|
from cli.state import message_buffer
|
|
|
|
# Module-level rich console; display_complete_report() prints through it.
console = Console()
|
|
|
|
|
|
def create_layout() -> Layout:
    """Build the nested rich ``Layout`` used by the live display.

    The root is split vertically into ``header`` (3 rows), ``main`` and
    ``footer`` (3 rows). ``main`` is split vertically into ``upper`` and
    ``analysis`` with a 3:5 ratio, and ``upper`` is split horizontally into
    ``progress`` and ``messages`` with a 2:3 ratio.

    Returns:
        The root ``Layout``; regions are addressable by name, e.g.
        ``layout["messages"]``.
    """
    root = Layout()

    # Top level: fixed-height header and footer around a flexible body.
    header = Layout(name="header", size=3)
    body = Layout(name="main")
    footer = Layout(name="footer", size=3)
    root.split_column(header, body, footer)

    # Body: status widgets on top, the current report underneath.
    upper = Layout(name="upper", ratio=3)
    analysis = Layout(name="analysis", ratio=5)
    body.split_column(upper, analysis)

    # Upper band: progress table next to the message/tool log.
    progress = Layout(name="progress", ratio=2)
    messages = Layout(name="messages", ratio=3)
    upper.split_row(progress, messages)

    return root
|
|
|
|
|
|
def _status_cell(status: AgentStatus) -> Any:
    """Return the renderable shown in the "Status" column for one agent.

    An in-progress agent gets an animated ``Spinner``; every other status is
    rendered as its ``.value`` wrapped in a colour markup tag (white when the
    status has no dedicated colour).
    """
    if status == AgentStatus.IN_PROGRESS:
        return Spinner("dots", text="[blue]in_progress[/blue]", style="bold cyan")

    status_color = {
        AgentStatus.PENDING: "yellow",
        AgentStatus.COMPLETED: "green",
        AgentStatus.ERROR: "red",
    }.get(status, "white")
    return f"[{status_color}]{status.value}[/{status_color}]"


def update_display(layout: Layout, spinner_text: Optional[str] = None) -> None:
    """Refresh every region of ``layout`` from the shared ``message_buffer``.

    Regions updated: ``header`` (static banner), ``progress`` (per-agent
    status table), ``messages`` (most recent tool calls and messages),
    ``analysis`` (current report or a placeholder) and ``footer`` (counters).

    Args:
        layout: Layout exposing the regions named above (see
            ``create_layout``).
        spinner_text: Optional text appended as an extra "Spinner" row at the
            bottom of the messages table; skipped when falsy.
    """
    # ----- Header -----------------------------------------------------------
    layout["header"].update(
        Panel(
            "[bold green]Welcome to TradingAgents CLI[/bold green]\n"
            "[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]",
            title="Welcome to TradingAgents",
            border_style="green",
            padding=(1, 2),
            expand=True,
        )
    )

    # ----- Progress table ---------------------------------------------------
    progress_table = Table(
        show_header=True,
        header_style="bold magenta",
        show_footer=False,
        box=box.SIMPLE_HEAD,
        title=None,
        padding=(0, 2),
        expand=True,
    )
    progress_table.add_column("Team", style="cyan", justify="center", width=20)
    progress_table.add_column("Agent", style="green", justify="center", width=20)
    progress_table.add_column("Status", style="yellow", justify="center", width=20)

    teams = {
        "Analyst Team": [
            "Market Analyst",
            "Social Analyst",
            "News Analyst",
            "Fundamentals Analyst",
        ],
        "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
        "Trading Team": ["Trader"],
        "Risk Management": ["Risky Analyst", "Neutral Analyst", "Safe Analyst"],
        "Portfolio Management": ["Portfolio Manager"],
    }

    for team, agents in teams.items():
        for position, agent in enumerate(agents):
            # The team name is shown only on the first row of its group.
            progress_table.add_row(
                team if position == 0 else "",
                agent,
                _status_cell(message_buffer.agent_status[agent]),
            )
        # Visual separator after each team.
        progress_table.add_row("-" * 20, "-" * 20, "-" * 20, style="dim")

    layout["progress"].update(
        Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2))
    )

    # ----- Messages & tool calls -------------------------------------------
    messages_table = Table(
        show_header=True,
        header_style="bold magenta",
        show_footer=False,
        expand=True,
        box=box.MINIMAL,
        show_lines=True,
        padding=(0, 1),
    )
    messages_table.add_column("Time", style="cyan", width=8, justify="center")
    messages_table.add_column("Type", style="green", width=10, justify="center")
    messages_table.add_column("Content", style="white", no_wrap=False, ratio=1)

    all_messages = []

    for timestamp, tool_name, args in message_buffer.tool_calls:
        # Only string arguments are truncated; other types are formatted as-is.
        if isinstance(args, str) and len(args) > 100:
            args = args[:97] + "..."
        all_messages.append((timestamp, "Tool", f"{tool_name}: {args}"))

    for timestamp, msg_type, content in message_buffer.messages:
        # Reuse the module's shared flattening helper instead of an inline copy.
        content_str = extract_content_string(content)
        if len(content_str) > 200:
            content_str = content_str[:197] + "..."
        all_messages.append((timestamp, msg_type, content_str))

    # Merge both streams chronologically and keep only the tail.
    all_messages.sort(key=lambda entry: entry[0])
    max_messages = 12
    recent_messages = all_messages[-max_messages:]

    for timestamp, msg_type, content in recent_messages:
        messages_table.add_row(timestamp, msg_type, Text(content, overflow="fold"))

    if spinner_text:
        messages_table.add_row("", "Spinner", spinner_text)

    if len(all_messages) > max_messages:
        # Table has no table-wide ``footer``; assigning one is silently ignored.
        # ``caption`` is the attribute rich renders beneath the table.
        messages_table.caption = (
            f"[dim]Showing last {max_messages} of {len(all_messages)} messages[/dim]"
        )

    layout["messages"].update(
        Panel(
            messages_table,
            title="Messages & Tools",
            border_style="blue",
            padding=(1, 2),
        )
    )

    # ----- Current report ---------------------------------------------------
    if message_buffer.current_report:
        report_body = Markdown(message_buffer.current_report)
    else:
        report_body = "[italic]Waiting for analysis report...[/italic]"
    layout["analysis"].update(
        Panel(
            report_body,
            title="Current Report",
            border_style="green",
            padding=(1, 2),
        )
    )

    # ----- Footer statistics ------------------------------------------------
    tool_calls_count = len(message_buffer.tool_calls)
    # Messages typed "Reasoning" are what the footer reports as LLM calls.
    llm_calls_count = sum(
        1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning"
    )
    reports_count = sum(
        1 for content in message_buffer.report_sections.values() if content is not None
    )

    stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
    stats_table.add_column("Stats", justify="center")
    stats_table.add_row(
        f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
    )

    layout["footer"].update(Panel(stats_table, border_style="grey50"))
|
|
|
|
|
|
def display_complete_report(final_state: Dict[str, Any]) -> None:
    """Print the final multi-section analysis report to the module console.

    Sections are printed in a fixed order and each one is skipped when its
    data is missing or falsy:

    I.   Analyst reports (market, sentiment, news, fundamentals)
    II.  Research debate (bull, bear, judge decision)
    III. Trader investment plan
    IV.  Risk debate (risky, safe, neutral)
    V.   Portfolio manager decision (judge decision of the risk debate)

    Args:
        final_state: Mapping holding the report strings and the nested
            ``investment_debate_state`` / ``risk_debate_state`` mappings.
    """

    def agent_panels(source, entries):
        # One blue panel per (key, title) pair whose value is truthy, in order.
        return [
            Panel(
                Markdown(source[key]),
                title=title,
                border_style="blue",
                padding=(1, 2),
            )
            for key, title in entries
            if source.get(key)
        ]

    def print_section(body, heading, colour):
        # Wrap ``body`` in the outer, section-level panel and print it.
        console.print(
            Panel(body, title=heading, border_style=colour, padding=(1, 2))
        )

    def print_team(panels, heading, colour):
        # A team section is printed only when at least one member reported.
        if panels:
            print_section(Columns(panels, equal=True, expand=True), heading, colour)

    console.print("\n[bold green]Complete Analysis Report[/bold green]\n")

    # I. Analyst team
    analyst_entries = (
        ("market_report", "Market Analyst"),
        ("sentiment_report", "Social Analyst"),
        ("news_report", "News Analyst"),
        ("fundamentals_report", "Fundamentals Analyst"),
    )
    print_team(
        agent_panels(final_state, analyst_entries),
        "I. Analyst Team Reports",
        "cyan",
    )

    # II. Research team
    debate_state = final_state.get("investment_debate_state")
    if debate_state:
        research_entries = (
            ("bull_history", "Bull Researcher"),
            ("bear_history", "Bear Researcher"),
            ("judge_decision", "Research Manager"),
        )
        print_team(
            agent_panels(debate_state, research_entries),
            "II. Research Team Decision",
            "magenta",
        )

    # III. Trading team: a single panel nested directly, without Columns.
    trader_panels = agent_panels(
        final_state, (("trader_investment_plan", "Trader"),)
    )
    if trader_panels:
        print_section(trader_panels[0], "III. Trading Team Plan", "yellow")

    # IV and V. Risk management debate and the portfolio manager's verdict.
    risk_state = final_state.get("risk_debate_state")
    if risk_state:
        risk_entries = (
            ("risky_history", "Aggressive Analyst"),
            ("safe_history", "Conservative Analyst"),
            ("neutral_history", "Neutral Analyst"),
        )
        print_team(
            agent_panels(risk_state, risk_entries),
            "IV. Risk Management Team Decision",
            "red",
        )

        manager_panels = agent_panels(
            risk_state, (("judge_decision", "Portfolio Manager"),)
        )
        if manager_panels:
            print_section(
                manager_panels[0], "V. Portfolio Manager Decision", "green"
            )
|
|
|
|
|
|
def update_research_team_status(status: AgentStatus) -> None:
    """Set the same ``status`` on each agent in the research-team group.

    The group updated here is Bull Researcher, Bear Researcher, Research
    Manager and Trader; each is updated through
    ``message_buffer.update_agent_status``.

    Args:
        status: Status to assign to every agent in the group.
    """
    for member in (
        "Bull Researcher",
        "Bear Researcher",
        "Research Manager",
        "Trader",
    ):
        message_buffer.update_agent_status(member, status)
|
|
|
|
|
|
def extract_content_string(content: Any) -> str:
    """Flatten message content of several shapes into a single string.

    * ``str`` is returned unchanged.
    * ``list`` items are converted one by one and joined with single spaces:
      a dict with ``type == 'text'`` contributes its ``text`` value, a dict
      with ``type == 'tool_use'`` contributes ``"[Tool: <name>]"`` (``unknown``
      when there is no name), any other dict contributes nothing, and a
      non-dict item contributes ``str(item)``.
    * Anything else is passed through ``str()``.

    Args:
        content: Raw message content.

    Returns:
        The flattened text.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    text_parts = []
    for item in content:
        if not isinstance(item, dict):
            text_parts.append(str(item))
            continue

        item_type = item.get('type')
        if item_type == 'text':
            text = item.get('text', '')
            # str.join rejects non-strings: treat None as empty text and
            # coerce anything else so a malformed block cannot raise here.
            text_parts.append('' if text is None else str(text))
        elif item_type == 'tool_use':
            text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
        # Dicts of any other type are intentionally skipped.

    return ' '.join(text_parts)
|
|
|
|
|
|
def create_question_box(title: str, prompt: str, default: Optional[str] = None) -> Panel:
    """Build a blue-bordered panel presenting a question to the user.

    Args:
        title: Heading, rendered in bold on the first line.
        prompt: Explanatory text, rendered dim on the second line.
        default: Optional default answer; when truthy it is shown on a third
            dim line prefixed with ``Default:``.

    Returns:
        A ``Panel`` containing the assembled markup.
    """
    lines = [f"[bold]{title}[/bold]", f"[dim]{prompt}[/dim]"]
    if default:
        lines.append(f"[dim]Default: {default}[/dim]")
    return Panel("\n".join(lines), border_style="blue", padding=(1, 2))
|