# TradingAgents/app.py
"""FastAPI SSE backend for TradingAgents."""
import os
import time
import uuid
import asyncio
import json
from datetime import date
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from cli.stats_handler import StatsCallbackHandler
from cli.main import (
MessageBuffer,
classify_message_type,
update_analyst_statuses,
)
app = FastAPI(title="TradingAgents API")
# NOTE(review): wildcard CORS (any origin, method, header) — acceptable for
# local development, but tighten allow_origins before a public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Registry of in-flight/finished analyses, keyed by analysis id.
# Each value holds: "queue" (asyncio.Queue feeding the SSE stream),
# "events" (replay buffer for reconnecting clients), "done" (bool flag).
analyses: dict[str, dict] = {}
class AnalyzeRequest(BaseModel):
    """Request body for POST /analyze."""

    # Stock ticker symbol (e.g. "AAPL"); normalized/validated in the endpoint.
    ticker: str
    # Trade date as an ISO string (YYYY-MM-DD); defaults to today when omitted.
    date: str | None = None
def build_config():
    """Build TradingAgents config for Anthropic/Claude.

    Starts from a shallow copy of ``DEFAULT_CONFIG`` and overrides the
    provider, models (env-overridable), debate depth, and data vendors.
    """
    cfg = dict(DEFAULT_CONFIG)
    cfg.update(
        llm_provider="anthropic",
        # Model ids can be overridden via environment variables.
        deep_think_llm=os.getenv("DEEP_THINK_MODEL", "claude-sonnet-4-6"),
        quick_think_llm=os.getenv("QUICK_THINK_MODEL", "claude-haiku-4-5-20251001"),
        backend_url=None,
        max_debate_rounds=1,
        max_risk_discuss_rounds=1,
        data_vendors={
            "core_stock_apis": "yfinance",
            "technical_indicators": "yfinance",
            "fundamental_data": "yfinance",
            "news_data": "yfinance",
        },
        parallel_analysts=True,
    )
    return cfg
def get_stats_dict(stats_handler, buf, start_time):
    """Assemble the progress/statistics payload attached to every SSE event.

    Args:
        stats_handler: Object exposing ``get_stats()`` with llm/tool/token counters.
        buf: Message buffer exposing ``agent_status``, ``report_sections`` and
            ``get_completed_reports_count()``.
        start_time: Epoch seconds when the analysis started.
    """
    counters = stats_handler.get_stats()
    done_count = len([s for s in buf.agent_status.values() if s == "completed"])
    return {
        "agents_done": done_count,
        "agents_total": len(buf.agent_status),
        "llm_calls": counters["llm_calls"],
        "tool_calls": counters["tool_calls"],
        "tokens_in": counters["tokens_in"],
        "tokens_out": counters["tokens_out"],
        "reports_done": buf.get_completed_reports_count(),
        "reports_total": len(buf.report_sections),
        "elapsed": round(time.time() - start_time, 1),
    }
def _agent_stage(agent_name):
"""Map agent name to pipeline stage."""
if agent_name in ("Market Analyst", "Social Analyst", "News Analyst", "Fundamentals Analyst"):
return "analysts"
if agent_name in ("Bull Researcher", "Bear Researcher", "Research Manager"):
return "research"
if agent_name == "Trader":
return "trading"
if agent_name in ("Aggressive Analyst", "Conservative Analyst", "Neutral Analyst"):
return "risk"
if agent_name == "Portfolio Manager":
return "decision"
return "unknown"
async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
    """Background task that runs the TradingAgents pipeline and pushes SSE events.

    Every emitted event is appended to ``state["events"]`` (the replay buffer
    used by reconnecting clients) and pushed onto ``state["queue"]``. A ``None``
    sentinel on the queue marks the end of the stream.

    Args:
        analysis_id: Key into the module-level ``analyses`` registry.
        ticker: Upper-cased stock symbol to analyze.
        trade_date: ISO date string for the analysis.
    """
    import traceback as _tb

    state = analyses[analysis_id]
    q = state["queue"]
    config = build_config()
    stats_handler = StatsCallbackHandler()
    selected_analysts = ["market", "social", "news", "fundamentals"]
    try:
        graph = TradingAgentsGraph(
            selected_analysts=selected_analysts,
            debug=False,
            config=config,
            callbacks=[stats_handler],
        )
    except Exception as e:
        print(f"[ANALYSIS] Init failed: {e}\n{_tb.format_exc()}", flush=True)
        evt = {"type": "error", "message": f"Init failed: {e}"}
        # FIX: record the failure in the replay buffer and mark the analysis
        # done. Previously the error was only put on the queue, so a client
        # that connected (or reconnected) after the failure replayed nothing
        # and waited on an empty queue forever.
        state["events"].append(evt)
        state["done"] = True
        await q.put(evt)
        await q.put(None)
        return

    buf = MessageBuffer()
    buf.init_for_analysis(selected_analysts)
    init_state = graph.propagator.create_initial_state(ticker, trade_date)
    args = graph.propagator.get_graph_args(callbacks=[stats_handler])
    start_time = time.time()
    emitted_reports = set()   # report fields already streamed (emit-once)
    research_emitted = False  # research judge decision already emitted
    trader_emitted = False    # trader plan already emitted
    risk_emitted = False      # risk judge decision already emitted
    final_state = None
    prev_statuses = {}        # last status emitted per agent (dedup)

    # Emit all analysts as "in_progress" immediately (they run in parallel).
    analyst_name_map = {
        "market": "Market Analyst",
        "social": "Social Analyst",
        "news": "News Analyst",
        "fundamentals": "Fundamentals Analyst",
    }
    for analyst_type in selected_analysts:
        agent_name = analyst_name_map[analyst_type]
        buf.update_agent_status(agent_name, "in_progress")
        st = get_stats_dict(stats_handler, buf, start_time)
        evt = {
            "type": "agent_update",
            "agent": agent_name,
            "stage": "analysts",
            "status": "in_progress",
            "stats": st,
        }
        state["events"].append(evt)
        await q.put(evt)
        prev_statuses[agent_name] = "in_progress"

    # State fields that carry analyst reports (loop-invariant, hoisted).
    report_map = {
        "market_report": ("Market Analyst", "analysts"),
        "sentiment_report": ("Social Analyst", "analysts"),
        "news_report": ("News Analyst", "analysts"),
        "fundamentals_report": ("Fundamentals Analyst", "analysts"),
    }

    try:
        async for chunk in graph.graph.astream(init_state, **args):
            final_state = chunk
            # Process messages (same logic as Chainlit app); dedup by message id.
            if chunk.get("messages") and len(chunk["messages"]) > 0:
                last_msg = chunk["messages"][-1]
                msg_id = getattr(last_msg, "id", None)
                if msg_id != buf._last_message_id:
                    buf._last_message_id = msg_id
                    msg_type, content = classify_message_type(last_msg)
                    if content and content.strip():
                        buf.add_message(msg_type, content)
                    # Tool calls may be dicts or objects depending on provider.
                    if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
                        for tc in last_msg.tool_calls:
                            if isinstance(tc, dict):
                                buf.add_tool_call(tc["name"], tc["args"])
                            else:
                                buf.add_tool_call(tc.name, tc.args)
            update_analyst_statuses(buf, chunk)
            st = get_stats_dict(stats_handler, buf, start_time)
            # Emit agent status changes only (avoid flooding).
            for agent, status in buf.agent_status.items():
                if prev_statuses.get(agent) != status:
                    prev_statuses[agent] = status
                    evt = {
                        "type": "agent_update",
                        "agent": agent,
                        "stage": _agent_stage(agent),
                        "status": status,
                        "stats": st,
                    }
                    state["events"].append(evt)
                    await q.put(evt)
            # Analyst reports: emit each exactly once, when it first appears.
            for field, (agent_name, stage) in report_map.items():
                if field not in emitted_reports and chunk.get(field):
                    emitted_reports.add(field)
                    evt = {
                        "type": "report",
                        "agent": agent_name,
                        "stage": stage,
                        "field": field,
                        "report": chunk[field],
                        "stats": st,
                    }
                    state["events"].append(evt)
                    await q.put(evt)
            # Research debate (guard with research_emitted to avoid resetting
            # statuses on subsequent chunks in stream_mode="values").
            if chunk.get("investment_debate_state") and not research_emitted:
                debate = chunk["investment_debate_state"]
                bull = debate.get("bull_history", "").strip()
                bear = debate.get("bear_history", "").strip()
                judge = debate.get("judge_decision", "").strip()
                if bull or bear:
                    for a in ("Bull Researcher", "Bear Researcher", "Research Manager"):
                        buf.update_agent_status(a, "in_progress")
                if judge:
                    # Judge decision closes the research stage; trading begins.
                    research_emitted = True
                    for a in ("Bull Researcher", "Bear Researcher", "Research Manager"):
                        buf.update_agent_status(a, "completed")
                    buf.update_agent_status("Trader", "in_progress")
                    buf.update_report_section("investment_plan", judge)
                evt = {
                    "type": "debate",
                    "stage": "research",
                    "bull": bull,
                    "bear": bear,
                    "judge": judge,
                    "stats": get_stats_dict(stats_handler, buf, start_time),
                }
                state["events"].append(evt)
                await q.put(evt)
            # Trader plan: emitted once; kicks off the risk analysts.
            if chunk.get("trader_investment_plan") and not trader_emitted:
                trader_emitted = True
                buf.update_agent_status("Trader", "completed")
                buf.update_agent_status("Aggressive Analyst", "in_progress")
                buf.update_agent_status("Conservative Analyst", "in_progress")
                buf.update_agent_status("Neutral Analyst", "in_progress")
                buf.update_report_section("trader_investment_plan", chunk["trader_investment_plan"])
                evt = {
                    "type": "trader",
                    "stage": "trading",
                    "plan": chunk["trader_investment_plan"],
                    "stats": get_stats_dict(stats_handler, buf, start_time),
                }
                state["events"].append(evt)
                await q.put(evt)
            # Risk debate (guard with risk_emitted to avoid resetting
            # statuses on subsequent chunks in stream_mode="values").
            if chunk.get("risk_debate_state") and not risk_emitted:
                risk = chunk["risk_debate_state"]
                agg = risk.get("aggressive_history", "").strip()
                con = risk.get("conservative_history", "").strip()
                neu = risk.get("neutral_history", "").strip()
                judge = risk.get("judge_decision", "").strip()
                if agg:
                    buf.update_agent_status("Aggressive Analyst", "in_progress")
                if con:
                    buf.update_agent_status("Conservative Analyst", "in_progress")
                if neu:
                    buf.update_agent_status("Neutral Analyst", "in_progress")
                if judge:
                    risk_emitted = True
                    buf.update_agent_status("Aggressive Analyst", "completed")
                    buf.update_agent_status("Conservative Analyst", "completed")
                    buf.update_agent_status("Neutral Analyst", "completed")
                    buf.update_agent_status("Portfolio Manager", "completed")
                evt = {
                    "type": "risk",
                    "stage": "risk",
                    "aggressive": agg,
                    "conservative": con,
                    "neutral": neu,
                    "judge": judge,
                    "stats": get_stats_dict(stats_handler, buf, start_time),
                }
                state["events"].append(evt)
                await q.put(evt)
    except Exception as e:
        print(f"[ANALYSIS] Stream error: {e}\n{_tb.format_exc()}", flush=True)
        evt = {"type": "error", "message": str(e)}
        state["events"].append(evt)
        await q.put(evt)
        state["done"] = True
        await q.put(None)
        return

    # Final decision (only reached when the stream finished without error).
    if final_state:
        decision_text = final_state.get("final_trade_decision", "No decision reached.")
        signal = graph.process_signal(decision_text)
        buf.update_report_section("final_trade_decision", decision_text)
        for agent in buf.agent_status:
            buf.update_agent_status(agent, "completed")
        st = get_stats_dict(stats_handler, buf, start_time)
        # Emit agent_update for any agents not yet shown as completed on the client.
        for agent in buf.agent_status:
            if prev_statuses.get(agent) != "completed":
                prev_statuses[agent] = "completed"
                evt = {
                    "type": "agent_update",
                    "agent": agent,
                    "stage": _agent_stage(agent),
                    "status": "completed",
                    "stats": st,
                }
                state["events"].append(evt)
                await q.put(evt)
        evt = {
            "type": "decision",
            "stage": "decision",
            "signal": signal,
            "decision_text": decision_text,
            "stats": st,
        }
        state["events"].append(evt)
        await q.put(evt)
    state["done"] = True
    await q.put(None)  # sentinel — stream done
@app.post("/analyze")
async def start_analysis(req: AnalyzeRequest):
    """Start a background analysis and return its id, ticker, and date.

    Raises:
        HTTPException(400): If the ticker is empty or longer than 5 characters.
    """
    ticker = req.ticker.upper().strip()
    if not ticker or len(ticker) > 5:
        raise HTTPException(400, "Invalid ticker")
    trade_date = req.date or str(date.today())
    analysis_id = str(uuid.uuid4())
    state = {"queue": asyncio.Queue(), "events": [], "done": False}
    analyses[analysis_id] = state
    # FIX: keep a strong reference to the task. The event loop holds only a
    # weak reference to running tasks, so a task whose create_task() result
    # is discarded can be garbage-collected before it finishes.
    state["task"] = asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
    return {"id": analysis_id, "ticker": ticker, "date": trade_date}
@app.get("/analyze/{analysis_id}/stream")
async def stream_analysis(analysis_id: str, last_event: int = 0):
    """Stream SSE events. Supports reconnection via ?last_event=N to replay missed events.

    FIX: the producer appends every event to the replay buffer AND puts it on
    the queue. The old implementation replayed from the buffer and then read
    the queue, so a client connecting mid-run received buffered events twice
    (once from replay, once from the queue). Delivery is now driven entirely
    by the replay buffer; the queue is used only as a wake-up signal.
    """
    if analysis_id not in analyses:
        raise HTTPException(404, "Analysis not found")
    state = analyses[analysis_id]

    async def event_generator():
        idx = last_event
        q = state["queue"]
        while True:
            # Deliver everything in the replay buffer we have not sent yet.
            while idx < len(state["events"]):
                evt = state["events"][idx]
                idx += 1
                yield {"id": str(idx), "data": json.dumps(evt)}
            if state["done"]:
                return
            try:
                signal = await asyncio.wait_for(q.get(), timeout=15)
            except asyncio.TimeoutError:
                # Keep the connection alive through proxies during quiet spells.
                yield {"event": "heartbeat", "data": ""}
                continue
            if signal is None:
                # End-of-stream sentinel: flush any stragglers, then stop.
                while idx < len(state["events"]):
                    evt = state["events"][idx]
                    idx += 1
                    yield {"id": str(idx), "data": json.dumps(evt)}
                return
            if idx >= len(state["events"]):
                # Payload never reached the replay buffer (defensive fallback,
                # e.g. an early error) — emit it directly so it is not lost.
                idx += 1
                yield {"id": str(idx), "data": json.dumps(signal)}

    return EventSourceResponse(event_generator())
@app.get("/health")
async def health():
    """Liveness probe: always reports the service as healthy."""
    payload = {"status": "ok"}
    return payload