refactor: split cli/main.py into modular components
Extract cli/main.py (1916 lines) into focused modules: - cli/state.py: MessageBuffer class for state management - cli/display.py: Layout, progress tables, and report display functions - cli/discovery.py: Trending stock discovery flow and UI - cli/analysis.py: Stock analysis flow and chunk processing - cli/backtest_cmd.py: Backtesting command and strategies main.py reduced from 1916 to 110 lines, serving as entry point only 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
9c252fdc2c
commit
293df9c552
|
|
@ -0,0 +1,442 @@
|
|||
import datetime
|
||||
from pathlib import Path
|
||||
from functools import wraps
|
||||
from typing import List
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.live import Live
|
||||
from rich.align import Align
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
from cli.state import message_buffer
|
||||
from cli.models import AnalystType
|
||||
from cli.display import (
|
||||
create_layout,
|
||||
update_display,
|
||||
display_complete_report,
|
||||
update_research_team_status,
|
||||
extract_content_string,
|
||||
create_question_box,
|
||||
console,
|
||||
)
|
||||
from cli.utils import (
|
||||
select_analysts,
|
||||
select_research_depth,
|
||||
select_shallow_thinking_agent,
|
||||
select_deep_thinking_agent,
|
||||
select_llm_provider,
|
||||
loading,
|
||||
)
|
||||
|
||||
|
||||
def get_ticker():
|
||||
return typer.prompt("", default="SPY")
|
||||
|
||||
|
||||
def get_analysis_date():
|
||||
while True:
|
||||
date_str = typer.prompt(
|
||||
"", default=datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
)
|
||||
try:
|
||||
analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d")
|
||||
if analysis_date.date() > datetime.datetime.now().date():
|
||||
console.print("[red]Error: Analysis date cannot be in the future[/red]")
|
||||
continue
|
||||
return date_str
|
||||
except ValueError:
|
||||
console.print(
|
||||
"[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]"
|
||||
)
|
||||
|
||||
|
||||
def get_user_selections():
|
||||
with open("./cli/static/welcome.txt", "r") as f:
|
||||
welcome_ascii = f.read()
|
||||
|
||||
welcome_content = f"{welcome_ascii}\n"
|
||||
welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n"
|
||||
welcome_content += "[bold]Workflow Steps:[/bold]\n"
|
||||
welcome_content += "I. Analyst Team -> II. Research Team -> III. Trader -> IV. Risk Management -> V. Portfolio Management\n\n"
|
||||
welcome_content += "[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]"
|
||||
|
||||
welcome_box = Panel(
|
||||
welcome_content,
|
||||
border_style="green",
|
||||
padding=(1, 2),
|
||||
title="Welcome to TradingAgents",
|
||||
subtitle="Multi-Agents LLM Financial Trading Framework",
|
||||
)
|
||||
console.print(Align.center(welcome_box))
|
||||
console.print()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY"
|
||||
)
|
||||
)
|
||||
selected_ticker = get_ticker()
|
||||
|
||||
default_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 2: Analysis Date",
|
||||
"Enter the analysis date (YYYY-MM-DD)",
|
||||
default_date,
|
||||
)
|
||||
)
|
||||
analysis_date = get_analysis_date()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis"
|
||||
)
|
||||
)
|
||||
selected_analysts = select_analysts()
|
||||
console.print(
|
||||
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
|
||||
)
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 4: Research Depth", "Select your research depth level"
|
||||
)
|
||||
)
|
||||
selected_research_depth = select_research_depth()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 5: OpenAI backend", "Select which service to talk to"
|
||||
)
|
||||
)
|
||||
selected_llm_provider, backend_url = select_llm_provider()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 6: Thinking Agents", "Select your thinking agents for analysis"
|
||||
)
|
||||
)
|
||||
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
||||
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
|
||||
|
||||
return {
|
||||
"ticker": selected_ticker,
|
||||
"analysis_date": analysis_date,
|
||||
"analysts": selected_analysts,
|
||||
"research_depth": selected_research_depth,
|
||||
"llm_provider": selected_llm_provider.lower(),
|
||||
"backend_url": backend_url,
|
||||
"shallow_thinker": selected_shallow_thinker,
|
||||
"deep_thinker": selected_deep_thinker,
|
||||
}
|
||||
|
||||
|
||||
def process_chunk_for_display(chunk, selected_analysts: List[AnalystType]):
|
||||
if "market_report" in chunk and chunk["market_report"]:
|
||||
message_buffer.update_report_section("market_report", chunk["market_report"])
|
||||
message_buffer.update_agent_status("Market Analyst", "completed")
|
||||
if AnalystType.SOCIAL in selected_analysts:
|
||||
message_buffer.update_agent_status("Social Analyst", "in_progress")
|
||||
|
||||
if "sentiment_report" in chunk and chunk["sentiment_report"]:
|
||||
message_buffer.update_report_section("sentiment_report", chunk["sentiment_report"])
|
||||
message_buffer.update_agent_status("Social Analyst", "completed")
|
||||
if AnalystType.NEWS in selected_analysts:
|
||||
message_buffer.update_agent_status("News Analyst", "in_progress")
|
||||
|
||||
if "news_report" in chunk and chunk["news_report"]:
|
||||
message_buffer.update_report_section("news_report", chunk["news_report"])
|
||||
message_buffer.update_agent_status("News Analyst", "completed")
|
||||
if AnalystType.FUNDAMENTALS in selected_analysts:
|
||||
message_buffer.update_agent_status("Fundamentals Analyst", "in_progress")
|
||||
|
||||
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
|
||||
message_buffer.update_report_section("fundamentals_report", chunk["fundamentals_report"])
|
||||
message_buffer.update_agent_status("Fundamentals Analyst", "completed")
|
||||
update_research_team_status("in_progress")
|
||||
|
||||
if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
|
||||
debate_state = chunk["investment_debate_state"]
|
||||
|
||||
if "bull_history" in debate_state and debate_state["bull_history"]:
|
||||
update_research_team_status("in_progress")
|
||||
bull_responses = debate_state["bull_history"].split("\n")
|
||||
latest_bull = bull_responses[-1] if bull_responses else ""
|
||||
if latest_bull:
|
||||
message_buffer.add_message("Reasoning", latest_bull)
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan",
|
||||
f"### Bull Researcher Analysis\n{latest_bull}",
|
||||
)
|
||||
|
||||
if "bear_history" in debate_state and debate_state["bear_history"]:
|
||||
update_research_team_status("in_progress")
|
||||
bear_responses = debate_state["bear_history"].split("\n")
|
||||
latest_bear = bear_responses[-1] if bear_responses else ""
|
||||
if latest_bear:
|
||||
message_buffer.add_message("Reasoning", latest_bear)
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan",
|
||||
f"{message_buffer.report_sections['investment_plan']}\n\n### Bear Researcher Analysis\n{latest_bear}",
|
||||
)
|
||||
|
||||
if "judge_decision" in debate_state and debate_state["judge_decision"]:
|
||||
update_research_team_status("in_progress")
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Research Manager: {debate_state['judge_decision']}",
|
||||
)
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan",
|
||||
f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}",
|
||||
)
|
||||
update_research_team_status("completed")
|
||||
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
||||
|
||||
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
|
||||
message_buffer.update_report_section("trader_investment_plan", chunk["trader_investment_plan"])
|
||||
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
||||
|
||||
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
|
||||
risk_state = chunk["risk_debate_state"]
|
||||
|
||||
if "current_risky_response" in risk_state and risk_state["current_risky_response"]:
|
||||
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Risky Analyst: {risk_state['current_risky_response']}",
|
||||
)
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision",
|
||||
f"### Risky Analyst Analysis\n{risk_state['current_risky_response']}",
|
||||
)
|
||||
|
||||
if "current_safe_response" in risk_state and risk_state["current_safe_response"]:
|
||||
message_buffer.update_agent_status("Safe Analyst", "in_progress")
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Safe Analyst: {risk_state['current_safe_response']}",
|
||||
)
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision",
|
||||
f"### Safe Analyst Analysis\n{risk_state['current_safe_response']}",
|
||||
)
|
||||
|
||||
if "current_neutral_response" in risk_state and risk_state["current_neutral_response"]:
|
||||
message_buffer.update_agent_status("Neutral Analyst", "in_progress")
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Neutral Analyst: {risk_state['current_neutral_response']}",
|
||||
)
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision",
|
||||
f"### Neutral Analyst Analysis\n{risk_state['current_neutral_response']}",
|
||||
)
|
||||
|
||||
if "judge_decision" in risk_state and risk_state["judge_decision"]:
|
||||
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Portfolio Manager: {risk_state['judge_decision']}",
|
||||
)
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision",
|
||||
f"### Portfolio Manager Decision\n{risk_state['judge_decision']}",
|
||||
)
|
||||
message_buffer.update_agent_status("Risky Analyst", "completed")
|
||||
message_buffer.update_agent_status("Safe Analyst", "completed")
|
||||
message_buffer.update_agent_status("Neutral Analyst", "completed")
|
||||
message_buffer.update_agent_status("Portfolio Manager", "completed")
|
||||
|
||||
|
||||
def setup_logging_decorators(report_dir, log_file):
|
||||
def save_message_decorator(obj, func_name):
|
||||
func = getattr(obj, func_name)
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
func(*args, **kwargs)
|
||||
timestamp, message_type, content = obj.messages[-1]
|
||||
content = content.replace("\n", " ")
|
||||
with open(log_file, "a") as f:
|
||||
f.write(f"{timestamp} [{message_type}] {content}\n")
|
||||
return wrapper
|
||||
|
||||
def save_tool_call_decorator(obj, func_name):
|
||||
func = getattr(obj, func_name)
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
func(*args, **kwargs)
|
||||
timestamp, tool_name, tool_args = obj.tool_calls[-1]
|
||||
args_str = ", ".join(f"{k}={v}" for k, v in tool_args.items())
|
||||
with open(log_file, "a") as f:
|
||||
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
|
||||
return wrapper
|
||||
|
||||
def save_report_section_decorator(obj, func_name):
|
||||
func = getattr(obj, func_name)
|
||||
@wraps(func)
|
||||
def wrapper(section_name, content):
|
||||
func(section_name, content)
|
||||
if section_name in obj.report_sections and obj.report_sections[section_name] is not None:
|
||||
section_content = obj.report_sections[section_name]
|
||||
if section_content:
|
||||
file_name = f"{section_name}.md"
|
||||
with open(report_dir / file_name, "w") as f:
|
||||
f.write(section_content)
|
||||
return wrapper
|
||||
|
||||
return save_message_decorator, save_tool_call_decorator, save_report_section_decorator
|
||||
|
||||
|
||||
def run_analysis_for_ticker(ticker: str, config: dict):
|
||||
analysis_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Analysts Team",
|
||||
"Select your LLM analyst agents for the analysis"
|
||||
)
|
||||
)
|
||||
selected_analysts = select_analysts()
|
||||
console.print(
|
||||
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
|
||||
)
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Research Depth",
|
||||
"Select your research depth level"
|
||||
)
|
||||
)
|
||||
selected_research_depth = select_research_depth()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Deep-Thinking Model",
|
||||
"Select the model for deep analysis"
|
||||
)
|
||||
)
|
||||
llm_provider = config.get("llm_provider", "openai")
|
||||
selected_deep_thinker = select_deep_thinking_agent(llm_provider.capitalize())
|
||||
|
||||
config["max_debate_rounds"] = selected_research_depth
|
||||
config["max_risk_discuss_rounds"] = selected_research_depth
|
||||
config["deep_think_llm"] = selected_deep_thinker
|
||||
|
||||
_run_analysis_with_config(ticker, analysis_date, selected_analysts, config)
|
||||
|
||||
|
||||
def run_analysis():
|
||||
selections = get_user_selections()
|
||||
|
||||
config = get_config()
|
||||
config["max_debate_rounds"] = selections["research_depth"]
|
||||
config["max_risk_discuss_rounds"] = selections["research_depth"]
|
||||
config["quick_think_llm"] = selections["shallow_thinker"]
|
||||
config["deep_think_llm"] = selections["deep_thinker"]
|
||||
config["backend_url"] = selections["backend_url"]
|
||||
config["llm_provider"] = selections["llm_provider"].lower()
|
||||
|
||||
_run_analysis_with_config(
|
||||
selections["ticker"],
|
||||
selections["analysis_date"],
|
||||
selections["analysts"],
|
||||
config
|
||||
)
|
||||
|
||||
|
||||
def _run_analysis_with_config(ticker: str, analysis_date: str, selected_analysts: List[AnalystType], config: dict):
|
||||
with loading("Initializing trading agents...", show_elapsed=True):
|
||||
graph = TradingAgentsGraph(
|
||||
[analyst.value for analyst in selected_analysts], config=config, debug=True
|
||||
)
|
||||
|
||||
results_dir = Path(config["results_dir"]) / ticker / analysis_date
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
report_dir = results_dir / "reports"
|
||||
report_dir.mkdir(parents=True, exist_ok=True)
|
||||
log_file = results_dir / "message_tool.log"
|
||||
log_file.touch(exist_ok=True)
|
||||
|
||||
save_message_decorator, save_tool_call_decorator, save_report_section_decorator = \
|
||||
setup_logging_decorators(report_dir, log_file)
|
||||
|
||||
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
|
||||
message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
|
||||
message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
|
||||
|
||||
layout = create_layout()
|
||||
|
||||
with Live(layout, refresh_per_second=4):
|
||||
update_display(layout)
|
||||
|
||||
message_buffer.add_message("System", f"Selected ticker: {ticker}")
|
||||
message_buffer.add_message("System", f"Analysis date: {analysis_date}")
|
||||
message_buffer.add_message(
|
||||
"System",
|
||||
f"Selected analysts: {', '.join(analyst.value for analyst in selected_analysts)}",
|
||||
)
|
||||
update_display(layout)
|
||||
|
||||
for agent in message_buffer.agent_status:
|
||||
message_buffer.update_agent_status(agent, "pending")
|
||||
|
||||
for section in message_buffer.report_sections:
|
||||
message_buffer.report_sections[section] = None
|
||||
message_buffer.current_report = None
|
||||
message_buffer.final_report = None
|
||||
|
||||
first_analyst = f"{selected_analysts[0].value.capitalize()} Analyst"
|
||||
message_buffer.update_agent_status(first_analyst, "in_progress")
|
||||
update_display(layout)
|
||||
|
||||
spinner_text = f"Analyzing {ticker} on {analysis_date}..."
|
||||
update_display(layout, spinner_text)
|
||||
|
||||
init_agent_state = graph.propagator.create_initial_state(ticker, analysis_date)
|
||||
args = graph.propagator.get_graph_args()
|
||||
|
||||
trace = []
|
||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||
if len(chunk["messages"]) > 0:
|
||||
last_message = chunk["messages"][-1]
|
||||
|
||||
if hasattr(last_message, "content"):
|
||||
content = extract_content_string(last_message.content)
|
||||
msg_type = "Reasoning"
|
||||
else:
|
||||
content = str(last_message)
|
||||
msg_type = "System"
|
||||
|
||||
message_buffer.add_message(msg_type, content)
|
||||
|
||||
if hasattr(last_message, "tool_calls"):
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
|
||||
process_chunk_for_display(chunk, selected_analysts)
|
||||
update_display(layout)
|
||||
|
||||
trace.append(chunk)
|
||||
|
||||
final_state = trace[-1]
|
||||
decision = graph.process_signal(final_state["final_trade_decision"])
|
||||
|
||||
for agent in message_buffer.agent_status:
|
||||
message_buffer.update_agent_status(agent, "completed")
|
||||
|
||||
message_buffer.add_message("Analysis", f"Completed analysis for {analysis_date}")
|
||||
|
||||
for section in message_buffer.report_sections.keys():
|
||||
if section in final_state:
|
||||
message_buffer.update_report_section(section, final_state[section])
|
||||
|
||||
display_complete_report(final_state)
|
||||
update_display(layout)
|
||||
|
|
@ -0,0 +1,210 @@
|
|||
import datetime
|
||||
from decimal import Decimal
|
||||
from datetime import date as date_type
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich import box
|
||||
|
||||
from tradingagents.backtesting import SimpleBacktestEngine, DataLoader
|
||||
from tradingagents.models.backtest import BacktestConfig
|
||||
from tradingagents.models.portfolio import PortfolioConfig
|
||||
|
||||
from cli.display import create_question_box
|
||||
from cli.utils import loading
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def sma_buy(ticker, trading_date, ctx):
|
||||
loader = ctx["data_loader"]
|
||||
ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date)
|
||||
if len(ohlcv.bars) < 20:
|
||||
return False
|
||||
prices = [float(b.close) for b in ohlcv.bars[-20:]]
|
||||
sma = sum(prices) / len(prices)
|
||||
current = float(ohlcv.bars[-1].close)
|
||||
return current > sma * 1.02
|
||||
|
||||
|
||||
def sma_sell(ticker, trading_date, ctx):
|
||||
loader = ctx["data_loader"]
|
||||
ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date)
|
||||
if len(ohlcv.bars) < 20:
|
||||
return False
|
||||
prices = [float(b.close) for b in ohlcv.bars[-20:]]
|
||||
sma = sum(prices) / len(prices)
|
||||
current = float(ohlcv.bars[-1].close)
|
||||
return current < sma * 0.98
|
||||
|
||||
|
||||
def rsi_buy(ticker, trading_date, ctx):
|
||||
loader = ctx["data_loader"]
|
||||
ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date)
|
||||
if len(ohlcv.bars) < 15:
|
||||
return False
|
||||
changes = []
|
||||
for i in range(1, min(15, len(ohlcv.bars))):
|
||||
changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i-1].close))
|
||||
gains = [c for c in changes if c > 0]
|
||||
losses = [-c for c in changes if c < 0]
|
||||
avg_gain = sum(gains) / 14 if gains else 0.001
|
||||
avg_loss = sum(losses) / 14 if losses else 0.001
|
||||
rs = avg_gain / avg_loss if avg_loss else 100
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
return rsi < 30
|
||||
|
||||
|
||||
def rsi_sell(ticker, trading_date, ctx):
|
||||
loader = ctx["data_loader"]
|
||||
ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date)
|
||||
if len(ohlcv.bars) < 15:
|
||||
return False
|
||||
changes = []
|
||||
for i in range(1, min(15, len(ohlcv.bars))):
|
||||
changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i-1].close))
|
||||
gains = [c for c in changes if c > 0]
|
||||
losses = [-c for c in changes if c < 0]
|
||||
avg_gain = sum(gains) / 14 if gains else 0.001
|
||||
avg_loss = sum(losses) / 14 if losses else 0.001
|
||||
rs = avg_gain / avg_loss if avg_loss else 100
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
return rsi > 70
|
||||
|
||||
|
||||
def hold_buy(ticker, trading_date, ctx):
|
||||
return ctx.get("day_index", 0) == 5
|
||||
|
||||
|
||||
def hold_sell(ticker, trading_date, ctx):
|
||||
return False
|
||||
|
||||
|
||||
STRATEGIES = {
|
||||
"sma": (sma_buy, sma_sell),
|
||||
"rsi": (rsi_buy, rsi_sell),
|
||||
"hold": (hold_buy, hold_sell),
|
||||
}
|
||||
|
||||
|
||||
def run_backtest(
|
||||
ticker: str = None,
|
||||
start_date: str = None,
|
||||
end_date: str = None,
|
||||
initial_cash: float = 100000.0,
|
||||
strategy: str = "sma",
|
||||
):
|
||||
if not ticker:
|
||||
console.print(create_question_box("Ticker Symbol", "Enter the ticker symbol to backtest", "AAPL"))
|
||||
ticker = typer.prompt("", default="AAPL")
|
||||
|
||||
if not start_date:
|
||||
default_start = (datetime.datetime.now() - datetime.timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
console.print(create_question_box("Start Date", "Enter backtest start date (YYYY-MM-DD)", default_start))
|
||||
start_date = typer.prompt("", default=default_start)
|
||||
|
||||
if not end_date:
|
||||
default_end = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
console.print(create_question_box("End Date", "Enter backtest end date (YYYY-MM-DD)", default_end))
|
||||
end_date = typer.prompt("", default=default_end)
|
||||
|
||||
try:
|
||||
start = datetime.datetime.strptime(start_date, "%Y-%m-%d").date()
|
||||
end = datetime.datetime.strptime(end_date, "%Y-%m-%d").date()
|
||||
except ValueError:
|
||||
console.print("[red]Invalid date format. Use YYYY-MM-DD[/red]")
|
||||
return
|
||||
|
||||
if start >= end:
|
||||
console.print("[red]Start date must be before end date[/red]")
|
||||
return
|
||||
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
f"[bold]Backtest Configuration[/bold]\n\n"
|
||||
f"Ticker: [cyan]{ticker.upper()}[/cyan]\n"
|
||||
f"Period: [cyan]{start_date}[/cyan] to [cyan]{end_date}[/cyan]\n"
|
||||
f"Initial Cash: [cyan]${initial_cash:,.2f}[/cyan]\n"
|
||||
f"Strategy: [cyan]{strategy}[/cyan]",
|
||||
title="Configuration",
|
||||
border_style="blue",
|
||||
))
|
||||
console.print()
|
||||
|
||||
if strategy not in STRATEGIES:
|
||||
console.print(f"[red]Unknown strategy: {strategy}. Use: sma, rsi, or hold[/red]")
|
||||
return
|
||||
|
||||
buy_fn, sell_fn = STRATEGIES[strategy]
|
||||
|
||||
config = BacktestConfig(
|
||||
name=f"{strategy.upper()} Backtest - {ticker.upper()}",
|
||||
tickers=[ticker.upper()],
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
portfolio_config=PortfolioConfig(
|
||||
initial_cash=Decimal(str(initial_cash)),
|
||||
commission_per_trade=Decimal("1"),
|
||||
slippage_percent=Decimal("0.05"),
|
||||
),
|
||||
warmup_period=5,
|
||||
)
|
||||
|
||||
with loading("Running backtest...", show_elapsed=True):
|
||||
engine = SimpleBacktestEngine(config, buy_signal=buy_fn, sell_signal=sell_fn)
|
||||
result = engine.run()
|
||||
|
||||
console.print()
|
||||
|
||||
if result.status == "failed":
|
||||
console.print(f"[red]Backtest failed: {result.error_message}[/red]")
|
||||
return
|
||||
|
||||
metrics = result.metrics
|
||||
trade_log = result.trade_log
|
||||
|
||||
performance_table = Table(title="Performance Metrics", box=box.ROUNDED)
|
||||
performance_table.add_column("Metric", style="cyan")
|
||||
performance_table.add_column("Value", style="green")
|
||||
|
||||
performance_table.add_row("Total Return", f"${float(metrics.total_return):,.2f}")
|
||||
performance_table.add_row("Total Return %", f"{float(metrics.total_return_percent):.2f}%")
|
||||
performance_table.add_row("Annualized Return", f"{float(metrics.annualized_return):.2f}%")
|
||||
performance_table.add_row("Sharpe Ratio", f"{float(metrics.sharpe_ratio):.2f}" if metrics.sharpe_ratio else "N/A")
|
||||
performance_table.add_row("Sortino Ratio", f"{float(metrics.sortino_ratio):.2f}" if metrics.sortino_ratio else "N/A")
|
||||
performance_table.add_row("Max Drawdown", f"{float(metrics.max_drawdown_percent):.2f}%")
|
||||
performance_table.add_row("Volatility (Ann.)", f"{float(metrics.annualized_volatility):.2f}%")
|
||||
|
||||
console.print(performance_table)
|
||||
console.print()
|
||||
|
||||
trading_table = Table(title="Trading Statistics", box=box.ROUNDED)
|
||||
trading_table.add_column("Metric", style="cyan")
|
||||
trading_table.add_column("Value", style="green")
|
||||
|
||||
trading_table.add_row("Total Trades", str(trade_log.total_trades))
|
||||
trading_table.add_row("Winning Trades", str(trade_log.winning_trades))
|
||||
trading_table.add_row("Losing Trades", str(trade_log.losing_trades))
|
||||
trading_table.add_row("Win Rate", f"{float(trade_log.win_rate):.1f}%" if trade_log.win_rate else "N/A")
|
||||
trading_table.add_row("Profit Factor", f"{float(trade_log.profit_factor):.2f}" if trade_log.profit_factor else "N/A")
|
||||
trading_table.add_row("Avg Win", f"${float(trade_log.avg_win):,.2f}" if trade_log.avg_win else "N/A")
|
||||
trading_table.add_row("Avg Loss", f"${float(trade_log.avg_loss):,.2f}" if trade_log.avg_loss else "N/A")
|
||||
|
||||
console.print(trading_table)
|
||||
console.print()
|
||||
|
||||
summary_table = Table(title="Portfolio Summary", box=box.ROUNDED)
|
||||
summary_table.add_column("Metric", style="cyan")
|
||||
summary_table.add_column("Value", style="green")
|
||||
|
||||
summary_table.add_row("Start Equity", f"${float(metrics.start_equity):,.2f}")
|
||||
summary_table.add_row("End Equity", f"${float(metrics.end_equity):,.2f}")
|
||||
summary_table.add_row("Trading Days", str(metrics.trading_days))
|
||||
summary_table.add_row("Duration", f"{result.duration_seconds:.1f} seconds")
|
||||
|
||||
console.print(summary_table)
|
||||
console.print()
|
||||
|
||||
console.print(f"[green]Backtest completed successfully![/green]")
|
||||
|
|
@ -0,0 +1,407 @@
|
|||
import time
|
||||
from typing import Optional, List
|
||||
|
||||
import questionary
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.rule import Rule
|
||||
from rich import box
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.dataflows.config import get_config
|
||||
from tradingagents.agents.discovery.models import (
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
TrendingStock,
|
||||
Sector,
|
||||
EventCategory,
|
||||
)
|
||||
from tradingagents.agents.discovery.persistence import save_discovery_result
|
||||
|
||||
from cli.display import create_question_box
|
||||
from cli.utils import (
|
||||
select_llm_provider,
|
||||
select_shallow_thinking_agent,
|
||||
loading,
|
||||
MultiStageLoader,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
LOOKBACK_OPTIONS = [
|
||||
("Last hour (1h)", "1h"),
|
||||
("Last 6 hours (6h)", "6h"),
|
||||
("Last 24 hours (24h)", "24h"),
|
||||
("Last 7 days (7d)", "7d"),
|
||||
]
|
||||
|
||||
SECTOR_OPTIONS = [
|
||||
("Technology", Sector.TECHNOLOGY),
|
||||
("Healthcare", Sector.HEALTHCARE),
|
||||
("Finance", Sector.FINANCE),
|
||||
("Energy", Sector.ENERGY),
|
||||
("Consumer Goods", Sector.CONSUMER_GOODS),
|
||||
("Industrials", Sector.INDUSTRIALS),
|
||||
("Other", Sector.OTHER),
|
||||
]
|
||||
|
||||
EVENT_OPTIONS = [
|
||||
("Earnings", EventCategory.EARNINGS),
|
||||
("Merger/Acquisition", EventCategory.MERGER_ACQUISITION),
|
||||
("Regulatory", EventCategory.REGULATORY),
|
||||
("Product Launch", EventCategory.PRODUCT_LAUNCH),
|
||||
("Executive Change", EventCategory.EXECUTIVE_CHANGE),
|
||||
("Other", EventCategory.OTHER),
|
||||
]
|
||||
|
||||
|
||||
def select_lookback_period() -> str:
|
||||
choice = questionary.select(
|
||||
"Select lookback period:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value) for display, value in LOOKBACK_OPTIONS
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
[
|
||||
("selected", "fg:cyan noinherit"),
|
||||
("highlighted", "fg:cyan noinherit"),
|
||||
("pointer", "fg:cyan noinherit"),
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
if choice is None:
|
||||
console.print("\n[red]No lookback period selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
return choice
|
||||
|
||||
|
||||
def select_sector_filter() -> Optional[List[Sector]]:
|
||||
use_filter = questionary.confirm(
|
||||
"Filter by sector?",
|
||||
default=False,
|
||||
style=questionary.Style(
|
||||
[
|
||||
("selected", "fg:cyan noinherit"),
|
||||
("highlighted", "fg:cyan noinherit"),
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
if not use_filter:
|
||||
return None
|
||||
|
||||
choices = questionary.checkbox(
|
||||
"Select sectors to include:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value) for display, value in SECTOR_OPTIONS
|
||||
],
|
||||
instruction="\n- Press Space to select/unselect\n- Press 'a' to select all\n- Press Enter when done",
|
||||
style=questionary.Style(
|
||||
[
|
||||
("checkbox-selected", "fg:cyan"),
|
||||
("selected", "fg:cyan noinherit"),
|
||||
("highlighted", "noinherit"),
|
||||
("pointer", "noinherit"),
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
if not choices:
|
||||
return None
|
||||
|
||||
return choices
|
||||
|
||||
|
||||
def select_event_filter() -> Optional[List[EventCategory]]:
|
||||
use_filter = questionary.confirm(
|
||||
"Filter by event type?",
|
||||
default=False,
|
||||
style=questionary.Style(
|
||||
[
|
||||
("selected", "fg:cyan noinherit"),
|
||||
("highlighted", "fg:cyan noinherit"),
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
if not use_filter:
|
||||
return None
|
||||
|
||||
choices = questionary.checkbox(
|
||||
"Select event types to include:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value) for display, value in EVENT_OPTIONS
|
||||
],
|
||||
instruction="\n- Press Space to select/unselect\n- Press 'a' to select all\n- Press Enter when done",
|
||||
style=questionary.Style(
|
||||
[
|
||||
("checkbox-selected", "fg:cyan"),
|
||||
("selected", "fg:cyan noinherit"),
|
||||
("highlighted", "noinherit"),
|
||||
("pointer", "noinherit"),
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
if not choices:
|
||||
return None
|
||||
|
||||
return choices
|
||||
|
||||
|
||||
def create_discovery_results_table(trending_stocks: List[TrendingStock]) -> Table:
|
||||
table = Table(
|
||||
show_header=True,
|
||||
header_style="bold magenta",
|
||||
box=box.ROUNDED,
|
||||
title="Trending Stocks",
|
||||
title_style="bold green",
|
||||
expand=True,
|
||||
)
|
||||
|
||||
table.add_column("Rank", style="cyan", justify="center", width=6)
|
||||
table.add_column("Ticker", style="bold yellow", justify="center", width=10)
|
||||
table.add_column("Company", style="white", justify="left", width=25)
|
||||
table.add_column("Score", style="green", justify="right", width=10)
|
||||
table.add_column("Mentions", style="blue", justify="center", width=10)
|
||||
table.add_column("Event Type", style="magenta", justify="center", width=18)
|
||||
|
||||
for rank, stock in enumerate(trending_stocks, 1):
|
||||
if rank <= 3:
|
||||
rank_display = f"[bold green]{rank}[/bold green]"
|
||||
ticker_display = f"[bold yellow]{stock.ticker}[/bold yellow]"
|
||||
else:
|
||||
rank_display = str(rank)
|
||||
ticker_display = stock.ticker
|
||||
|
||||
table.add_row(
|
||||
rank_display,
|
||||
ticker_display,
|
||||
stock.company_name[:25] if len(stock.company_name) > 25 else stock.company_name,
|
||||
f"{stock.score:.2f}",
|
||||
str(stock.mention_count),
|
||||
stock.event_type.value.replace("_", " ").title(),
|
||||
)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def create_stock_detail_panel(stock: TrendingStock, rank: int) -> Panel:
|
||||
sentiment_label = "positive" if stock.sentiment > 0.3 else "negative" if stock.sentiment < -0.3 else "neutral"
|
||||
sentiment_color = "green" if stock.sentiment > 0.3 else "red" if stock.sentiment < -0.3 else "yellow"
|
||||
|
||||
content = f"""[bold]Rank #{rank}: {stock.ticker} - {stock.company_name}[/bold]
|
||||
|
||||
[cyan]Score:[/cyan] {stock.score:.2f}
|
||||
[cyan]Sentiment:[/cyan] [{sentiment_color}]{stock.sentiment:.2f} ({sentiment_label})[/{sentiment_color}]
|
||||
[cyan]Sector:[/cyan] {stock.sector.value.replace("_", " ").title()}
|
||||
[cyan]Event Type:[/cyan] {stock.event_type.value.replace("_", " ").title()}
|
||||
[cyan]Mentions:[/cyan] {stock.mention_count}
|
||||
|
||||
[bold]News Summary:[/bold]
|
||||
{stock.news_summary}
|
||||
|
||||
[bold]Top Source Articles:[/bold]"""
|
||||
|
||||
for i, article in enumerate(stock.source_articles[:3], 1):
|
||||
content += f"\n {i}. [{article.title[:50]}...] - {article.source}"
|
||||
|
||||
return Panel(
|
||||
content,
|
||||
title=f"Stock Details: {stock.ticker}",
|
||||
border_style="cyan",
|
||||
padding=(1, 2),
|
||||
)
|
||||
|
||||
|
||||
def select_stock_for_detail(trending_stocks: List[TrendingStock]) -> Optional[TrendingStock]:
|
||||
if not trending_stocks:
|
||||
return None
|
||||
|
||||
choices = [
|
||||
questionary.Choice(
|
||||
f"{i+1}. {stock.ticker} - {stock.company_name} (Score: {stock.score:.2f})",
|
||||
value=stock
|
||||
)
|
||||
for i, stock in enumerate(trending_stocks)
|
||||
]
|
||||
choices.append(questionary.Choice("Back to menu", value=None))
|
||||
|
||||
selected = questionary.select(
|
||||
"Select a stock to view details:",
|
||||
choices=choices,
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
[
|
||||
("selected", "fg:cyan noinherit"),
|
||||
("highlighted", "fg:cyan noinherit"),
|
||||
("pointer", "fg:cyan noinherit"),
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def discover_trending_flow(run_analysis_callback=None):
|
||||
console.print(Rule("[bold green]Discover Trending Stocks[/bold green]"))
|
||||
console.print()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 1: Lookback Period",
|
||||
"Select how far back to search for trending stocks"
|
||||
)
|
||||
)
|
||||
lookback_period = select_lookback_period()
|
||||
console.print(f"[green]Selected lookback period:[/green] {lookback_period}")
|
||||
console.print()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 2: Sector Filter (Optional)",
|
||||
"Optionally filter results by sector"
|
||||
)
|
||||
)
|
||||
sector_filter = select_sector_filter()
|
||||
if sector_filter:
|
||||
console.print(f"[green]Selected sectors:[/green] {', '.join(s.value for s in sector_filter)}")
|
||||
else:
|
||||
console.print("[dim]No sector filter applied[/dim]")
|
||||
console.print()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 3: Event Filter (Optional)",
|
||||
"Optionally filter results by event type"
|
||||
)
|
||||
)
|
||||
event_filter = select_event_filter()
|
||||
if event_filter:
|
||||
console.print(f"[green]Selected events:[/green] {', '.join(e.value for e in event_filter)}")
|
||||
else:
|
||||
console.print("[dim]No event filter applied[/dim]")
|
||||
console.print()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 4: LLM Provider",
|
||||
"Select your LLM provider for entity extraction"
|
||||
)
|
||||
)
|
||||
selected_llm_provider, backend_url = select_llm_provider()
|
||||
console.print()
|
||||
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 5: Quick-Thinking Model",
|
||||
"Select the model for entity extraction"
|
||||
)
|
||||
)
|
||||
selected_model = select_shallow_thinking_agent(selected_llm_provider)
|
||||
console.print()
|
||||
|
||||
config = get_config()
|
||||
config["llm_provider"] = selected_llm_provider.lower()
|
||||
config["backend_url"] = backend_url
|
||||
config["quick_think_llm"] = selected_model
|
||||
config["deep_think_llm"] = selected_model
|
||||
|
||||
request = DiscoveryRequest(
|
||||
lookback_period=lookback_period,
|
||||
sector_filter=sector_filter,
|
||||
event_filter=event_filter,
|
||||
max_results=config.get("discovery_max_results", 20),
|
||||
)
|
||||
|
||||
discovery_stages = [
|
||||
"Initializing analysis engine",
|
||||
"Fetching news sources",
|
||||
"Extracting stock entities",
|
||||
"Resolving ticker symbols",
|
||||
"Calculating trending scores",
|
||||
]
|
||||
|
||||
result = None
|
||||
|
||||
with MultiStageLoader(discovery_stages, title="Discovery Progress") as loader:
|
||||
try:
|
||||
loader.next_stage()
|
||||
graph = TradingAgentsGraph(config=config, debug=False)
|
||||
|
||||
loader.next_stage()
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
loader.next_stage()
|
||||
time.sleep(0.3)
|
||||
|
||||
loader.next_stage()
|
||||
time.sleep(0.3)
|
||||
|
||||
except (ValueError, KeyError, RuntimeError, ConnectionError, TimeoutError) as e:
|
||||
console.print(f"\n[red]Error during discovery: {e}[/red]")
|
||||
return
|
||||
|
||||
if result is None:
|
||||
console.print("\n[red]Discovery failed. Please try again.[/red]")
|
||||
return
|
||||
|
||||
if result.status == DiscoveryStatus.FAILED:
|
||||
console.print(f"\n[red]Discovery failed: {result.error_message}[/red]")
|
||||
return
|
||||
|
||||
if result.status == DiscoveryStatus.COMPLETED:
|
||||
try:
|
||||
with loading("Saving discovery results..."):
|
||||
save_path = save_discovery_result(result)
|
||||
console.print(f"\n[dim]Results saved to: {save_path}[/dim]")
|
||||
except (IOError, OSError, ValueError) as e:
|
||||
console.print(f"\n[yellow]Warning: Could not save results: {e}[/yellow]")
|
||||
|
||||
console.print()
|
||||
|
||||
if not result.trending_stocks:
|
||||
console.print("[yellow]No trending stocks found matching your criteria.[/yellow]")
|
||||
return
|
||||
|
||||
console.print(f"[green]Found {len(result.trending_stocks)} trending stocks[/green]")
|
||||
console.print()
|
||||
|
||||
results_table = create_discovery_results_table(result.trending_stocks)
|
||||
console.print(results_table)
|
||||
console.print()
|
||||
|
||||
while True:
|
||||
selected_stock = select_stock_for_detail(result.trending_stocks)
|
||||
|
||||
if selected_stock is None:
|
||||
break
|
||||
|
||||
rank = result.trending_stocks.index(selected_stock) + 1
|
||||
detail_panel = create_stock_detail_panel(selected_stock, rank)
|
||||
console.print()
|
||||
console.print(detail_panel)
|
||||
console.print()
|
||||
|
||||
analyze_choice = questionary.confirm(
|
||||
f"Analyze {selected_stock.ticker}?",
|
||||
default=False,
|
||||
style=questionary.Style(
|
||||
[
|
||||
("selected", "fg:green noinherit"),
|
||||
("highlighted", "fg:green noinherit"),
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
if analyze_choice and run_analysis_callback:
|
||||
console.print()
|
||||
with loading(f"Preparing analysis for {selected_stock.ticker}...", spinner_style="loading"):
|
||||
time.sleep(0.5)
|
||||
run_analysis_callback(selected_stock.ticker, config)
|
||||
break
|
||||
|
|
@ -0,0 +1,413 @@
|
|||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.spinner import Spinner
|
||||
from rich.markdown import Markdown
|
||||
from rich.layout import Layout
|
||||
from rich.text import Text
|
||||
from rich.table import Table
|
||||
from rich.columns import Columns
|
||||
from rich import box
|
||||
|
||||
from cli.state import message_buffer
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def create_layout():
|
||||
layout = Layout()
|
||||
layout.split_column(
|
||||
Layout(name="header", size=3),
|
||||
Layout(name="main"),
|
||||
Layout(name="footer", size=3),
|
||||
)
|
||||
layout["main"].split_column(
|
||||
Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5)
|
||||
)
|
||||
layout["upper"].split_row(
|
||||
Layout(name="progress", ratio=2), Layout(name="messages", ratio=3)
|
||||
)
|
||||
return layout
|
||||
|
||||
|
||||
def update_display(layout, spinner_text=None):
|
||||
layout["header"].update(
|
||||
Panel(
|
||||
"[bold green]Welcome to TradingAgents CLI[/bold green]\n"
|
||||
"[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]",
|
||||
title="Welcome to TradingAgents",
|
||||
border_style="green",
|
||||
padding=(1, 2),
|
||||
expand=True,
|
||||
)
|
||||
)
|
||||
|
||||
progress_table = Table(
|
||||
show_header=True,
|
||||
header_style="bold magenta",
|
||||
show_footer=False,
|
||||
box=box.SIMPLE_HEAD,
|
||||
title=None,
|
||||
padding=(0, 2),
|
||||
expand=True,
|
||||
)
|
||||
progress_table.add_column("Team", style="cyan", justify="center", width=20)
|
||||
progress_table.add_column("Agent", style="green", justify="center", width=20)
|
||||
progress_table.add_column("Status", style="yellow", justify="center", width=20)
|
||||
|
||||
teams = {
|
||||
"Analyst Team": [
|
||||
"Market Analyst",
|
||||
"Social Analyst",
|
||||
"News Analyst",
|
||||
"Fundamentals Analyst",
|
||||
],
|
||||
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
|
||||
"Trading Team": ["Trader"],
|
||||
"Risk Management": ["Risky Analyst", "Neutral Analyst", "Safe Analyst"],
|
||||
"Portfolio Management": ["Portfolio Manager"],
|
||||
}
|
||||
|
||||
for team, agents in teams.items():
|
||||
first_agent = agents[0]
|
||||
status = message_buffer.agent_status[first_agent]
|
||||
if status == "in_progress":
|
||||
spinner = Spinner(
|
||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
||||
)
|
||||
status_cell = spinner
|
||||
else:
|
||||
status_color = {
|
||||
"pending": "yellow",
|
||||
"completed": "green",
|
||||
"error": "red",
|
||||
}.get(status, "white")
|
||||
status_cell = f"[{status_color}]{status}[/{status_color}]"
|
||||
progress_table.add_row(team, first_agent, status_cell)
|
||||
|
||||
for agent in agents[1:]:
|
||||
status = message_buffer.agent_status[agent]
|
||||
if status == "in_progress":
|
||||
spinner = Spinner(
|
||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
||||
)
|
||||
status_cell = spinner
|
||||
else:
|
||||
status_color = {
|
||||
"pending": "yellow",
|
||||
"completed": "green",
|
||||
"error": "red",
|
||||
}.get(status, "white")
|
||||
status_cell = f"[{status_color}]{status}[/{status_color}]"
|
||||
progress_table.add_row("", agent, status_cell)
|
||||
|
||||
progress_table.add_row("-" * 20, "-" * 20, "-" * 20, style="dim")
|
||||
|
||||
layout["progress"].update(
|
||||
Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2))
|
||||
)
|
||||
|
||||
messages_table = Table(
|
||||
show_header=True,
|
||||
header_style="bold magenta",
|
||||
show_footer=False,
|
||||
expand=True,
|
||||
box=box.MINIMAL,
|
||||
show_lines=True,
|
||||
padding=(0, 1),
|
||||
)
|
||||
messages_table.add_column("Time", style="cyan", width=8, justify="center")
|
||||
messages_table.add_column("Type", style="green", width=10, justify="center")
|
||||
messages_table.add_column("Content", style="white", no_wrap=False, ratio=1)
|
||||
|
||||
all_messages = []
|
||||
|
||||
for timestamp, tool_name, args in message_buffer.tool_calls:
|
||||
if isinstance(args, str) and len(args) > 100:
|
||||
args = args[:97] + "..."
|
||||
all_messages.append((timestamp, "Tool", f"{tool_name}: {args}"))
|
||||
|
||||
for timestamp, msg_type, content in message_buffer.messages:
|
||||
content_str = content
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get('type') == 'text':
|
||||
text_parts.append(item.get('text', ''))
|
||||
elif item.get('type') == 'tool_use':
|
||||
text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content_str = ' '.join(text_parts)
|
||||
elif not isinstance(content_str, str):
|
||||
content_str = str(content)
|
||||
|
||||
if len(content_str) > 200:
|
||||
content_str = content_str[:197] + "..."
|
||||
all_messages.append((timestamp, msg_type, content_str))
|
||||
|
||||
all_messages.sort(key=lambda x: x[0])
|
||||
max_messages = 12
|
||||
recent_messages = all_messages[-max_messages:]
|
||||
|
||||
for timestamp, msg_type, content in recent_messages:
|
||||
wrapped_content = Text(content, overflow="fold")
|
||||
messages_table.add_row(timestamp, msg_type, wrapped_content)
|
||||
|
||||
if spinner_text:
|
||||
messages_table.add_row("", "Spinner", spinner_text)
|
||||
|
||||
if len(all_messages) > max_messages:
|
||||
messages_table.footer = (
|
||||
f"[dim]Showing last {max_messages} of {len(all_messages)} messages[/dim]"
|
||||
)
|
||||
|
||||
layout["messages"].update(
|
||||
Panel(
|
||||
messages_table,
|
||||
title="Messages & Tools",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if message_buffer.current_report:
|
||||
layout["analysis"].update(
|
||||
Panel(
|
||||
Markdown(message_buffer.current_report),
|
||||
title="Current Report",
|
||||
border_style="green",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
else:
|
||||
layout["analysis"].update(
|
||||
Panel(
|
||||
"[italic]Waiting for analysis report...[/italic]",
|
||||
title="Current Report",
|
||||
border_style="green",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
tool_calls_count = len(message_buffer.tool_calls)
|
||||
llm_calls_count = sum(
|
||||
1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning"
|
||||
)
|
||||
reports_count = sum(
|
||||
1 for content in message_buffer.report_sections.values() if content is not None
|
||||
)
|
||||
|
||||
stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
|
||||
stats_table.add_column("Stats", justify="center")
|
||||
stats_table.add_row(
|
||||
f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
|
||||
)
|
||||
|
||||
layout["footer"].update(Panel(stats_table, border_style="grey50"))
|
||||
|
||||
|
||||
def display_complete_report(final_state):
|
||||
console.print("\n[bold green]Complete Analysis Report[/bold green]\n")
|
||||
|
||||
analyst_reports = []
|
||||
|
||||
if final_state.get("market_report"):
|
||||
analyst_reports.append(
|
||||
Panel(
|
||||
Markdown(final_state["market_report"]),
|
||||
title="Market Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if final_state.get("sentiment_report"):
|
||||
analyst_reports.append(
|
||||
Panel(
|
||||
Markdown(final_state["sentiment_report"]),
|
||||
title="Social Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if final_state.get("news_report"):
|
||||
analyst_reports.append(
|
||||
Panel(
|
||||
Markdown(final_state["news_report"]),
|
||||
title="News Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if final_state.get("fundamentals_report"):
|
||||
analyst_reports.append(
|
||||
Panel(
|
||||
Markdown(final_state["fundamentals_report"]),
|
||||
title="Fundamentals Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if analyst_reports:
|
||||
console.print(
|
||||
Panel(
|
||||
Columns(analyst_reports, equal=True, expand=True),
|
||||
title="I. Analyst Team Reports",
|
||||
border_style="cyan",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if final_state.get("investment_debate_state"):
|
||||
research_reports = []
|
||||
debate_state = final_state["investment_debate_state"]
|
||||
|
||||
if debate_state.get("bull_history"):
|
||||
research_reports.append(
|
||||
Panel(
|
||||
Markdown(debate_state["bull_history"]),
|
||||
title="Bull Researcher",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if debate_state.get("bear_history"):
|
||||
research_reports.append(
|
||||
Panel(
|
||||
Markdown(debate_state["bear_history"]),
|
||||
title="Bear Researcher",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if debate_state.get("judge_decision"):
|
||||
research_reports.append(
|
||||
Panel(
|
||||
Markdown(debate_state["judge_decision"]),
|
||||
title="Research Manager",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if research_reports:
|
||||
console.print(
|
||||
Panel(
|
||||
Columns(research_reports, equal=True, expand=True),
|
||||
title="II. Research Team Decision",
|
||||
border_style="magenta",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if final_state.get("trader_investment_plan"):
|
||||
console.print(
|
||||
Panel(
|
||||
Panel(
|
||||
Markdown(final_state["trader_investment_plan"]),
|
||||
title="Trader",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
),
|
||||
title="III. Trading Team Plan",
|
||||
border_style="yellow",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if final_state.get("risk_debate_state"):
|
||||
risk_reports = []
|
||||
risk_state = final_state["risk_debate_state"]
|
||||
|
||||
if risk_state.get("risky_history"):
|
||||
risk_reports.append(
|
||||
Panel(
|
||||
Markdown(risk_state["risky_history"]),
|
||||
title="Aggressive Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if risk_state.get("safe_history"):
|
||||
risk_reports.append(
|
||||
Panel(
|
||||
Markdown(risk_state["safe_history"]),
|
||||
title="Conservative Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if risk_state.get("neutral_history"):
|
||||
risk_reports.append(
|
||||
Panel(
|
||||
Markdown(risk_state["neutral_history"]),
|
||||
title="Neutral Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if risk_reports:
|
||||
console.print(
|
||||
Panel(
|
||||
Columns(risk_reports, equal=True, expand=True),
|
||||
title="IV. Risk Management Team Decision",
|
||||
border_style="red",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
if risk_state.get("judge_decision"):
|
||||
console.print(
|
||||
Panel(
|
||||
Panel(
|
||||
Markdown(risk_state["judge_decision"]),
|
||||
title="Portfolio Manager",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
),
|
||||
title="V. Portfolio Manager Decision",
|
||||
border_style="green",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def update_research_team_status(status):
|
||||
research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"]
|
||||
for agent in research_team:
|
||||
message_buffer.update_agent_status(agent, status)
|
||||
|
||||
|
||||
def extract_content_string(content):
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get('type') == 'text':
|
||||
text_parts.append(item.get('text', ''))
|
||||
elif item.get('type') == 'tool_use':
|
||||
text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
return ' '.join(text_parts)
|
||||
else:
|
||||
return str(content)
|
||||
|
||||
|
||||
def create_question_box(title: str, prompt: str, default: str = None) -> Panel:
|
||||
box_content = f"[bold]{title}[/bold]\n"
|
||||
box_content += f"[dim]{prompt}[/dim]"
|
||||
if default:
|
||||
box_content += f"\n[dim]Default: {default}[/dim]"
|
||||
return Panel(box_content, border_style="blue", padding=(1, 2))
|
||||
1835
cli/main.py
1835
cli/main.py
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,135 @@
|
|||
import datetime
|
||||
from collections import deque
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class MessageBuffer:
|
||||
def __init__(self, max_length=100):
|
||||
self.messages = deque(maxlen=max_length)
|
||||
self.tool_calls = deque(maxlen=max_length)
|
||||
self.current_report = None
|
||||
self.final_report = None
|
||||
self.agent_status = {
|
||||
"Market Analyst": "pending",
|
||||
"Social Analyst": "pending",
|
||||
"News Analyst": "pending",
|
||||
"Fundamentals Analyst": "pending",
|
||||
"Bull Researcher": "pending",
|
||||
"Bear Researcher": "pending",
|
||||
"Research Manager": "pending",
|
||||
"Trader": "pending",
|
||||
"Risky Analyst": "pending",
|
||||
"Neutral Analyst": "pending",
|
||||
"Safe Analyst": "pending",
|
||||
"Portfolio Manager": "pending",
|
||||
}
|
||||
self.current_agent = None
|
||||
self.report_sections = {
|
||||
"market_report": None,
|
||||
"sentiment_report": None,
|
||||
"news_report": None,
|
||||
"fundamentals_report": None,
|
||||
"investment_plan": None,
|
||||
"trader_investment_plan": None,
|
||||
"final_trade_decision": None,
|
||||
}
|
||||
|
||||
def add_message(self, message_type: str, content: str):
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
self.messages.append((timestamp, message_type, content))
|
||||
|
||||
def add_tool_call(self, tool_name: str, args: Dict[str, Any]):
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
self.tool_calls.append((timestamp, tool_name, args))
|
||||
|
||||
def update_agent_status(self, agent: str, status: str):
|
||||
if agent in self.agent_status:
|
||||
self.agent_status[agent] = status
|
||||
self.current_agent = agent
|
||||
|
||||
def update_report_section(self, section_name: str, content: str):
|
||||
if section_name in self.report_sections:
|
||||
self.report_sections[section_name] = content
|
||||
self._update_current_report()
|
||||
|
||||
def _update_current_report(self):
|
||||
latest_section = None
|
||||
latest_content = None
|
||||
|
||||
for section, content in self.report_sections.items():
|
||||
if content is not None:
|
||||
latest_section = section
|
||||
latest_content = content
|
||||
|
||||
if latest_section and latest_content:
|
||||
section_titles = {
|
||||
"market_report": "Market Analysis",
|
||||
"sentiment_report": "Social Sentiment",
|
||||
"news_report": "News Analysis",
|
||||
"fundamentals_report": "Fundamentals Analysis",
|
||||
"investment_plan": "Research Team Decision",
|
||||
"trader_investment_plan": "Trading Team Plan",
|
||||
"final_trade_decision": "Portfolio Management Decision",
|
||||
}
|
||||
self.current_report = (
|
||||
f"### {section_titles[latest_section]}\n{latest_content}"
|
||||
)
|
||||
|
||||
self._update_final_report()
|
||||
|
||||
def _update_final_report(self):
|
||||
report_parts = []
|
||||
|
||||
if any(
|
||||
self.report_sections[section]
|
||||
for section in [
|
||||
"market_report",
|
||||
"sentiment_report",
|
||||
"news_report",
|
||||
"fundamentals_report",
|
||||
]
|
||||
):
|
||||
report_parts.append("## Analyst Team Reports")
|
||||
if self.report_sections["market_report"]:
|
||||
report_parts.append(
|
||||
f"### Market Analysis\n{self.report_sections['market_report']}"
|
||||
)
|
||||
if self.report_sections["sentiment_report"]:
|
||||
report_parts.append(
|
||||
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
|
||||
)
|
||||
if self.report_sections["news_report"]:
|
||||
report_parts.append(
|
||||
f"### News Analysis\n{self.report_sections['news_report']}"
|
||||
)
|
||||
if self.report_sections["fundamentals_report"]:
|
||||
report_parts.append(
|
||||
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
|
||||
)
|
||||
|
||||
if self.report_sections["investment_plan"]:
|
||||
report_parts.append("## Research Team Decision")
|
||||
report_parts.append(f"{self.report_sections['investment_plan']}")
|
||||
|
||||
if self.report_sections["trader_investment_plan"]:
|
||||
report_parts.append("## Trading Team Plan")
|
||||
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
|
||||
|
||||
if self.report_sections["final_trade_decision"]:
|
||||
report_parts.append("## Portfolio Management Decision")
|
||||
report_parts.append(f"{self.report_sections['final_trade_decision']}")
|
||||
|
||||
self.final_report = "\n\n".join(report_parts) if report_parts else None
|
||||
|
||||
def reset(self):
|
||||
for agent in self.agent_status:
|
||||
self.agent_status[agent] = "pending"
|
||||
for section in self.report_sections:
|
||||
self.report_sections[section] = None
|
||||
self.current_report = None
|
||||
self.final_report = None
|
||||
self.messages.clear()
|
||||
self.tool_calls.clear()
|
||||
|
||||
|
||||
message_buffer = MessageBuffer()
|
||||
Loading…
Reference in New Issue