TradingAgents/app.py

304 lines
11 KiB
Python

"""FastAPI SSE backend for TradingAgents."""
import os
import time
import uuid
import asyncio
import json
from datetime import date
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from cli.stats_handler import StatsCallbackHandler
from cli.main import (
MessageBuffer,
classify_message_type,
update_analyst_statuses,
update_research_team_status,
)
# The ASGI application; every route in this module registers on it.
app = FastAPI(title="TradingAgents API")

# CORS is configured with wildcards for origins, methods and headers.
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
)

# Registry of analyses by id. Each value is the asyncio.Queue that the
# background run fills with events and that the stream endpoint drains.
analyses: dict[str, asyncio.Queue] = {}
class AnalyzeRequest(BaseModel):
    """Request body accepted by ``POST /analyze``."""

    # Ticker symbol to analyse; the handler upper-cases and strips it.
    ticker: str

    # Trade date as text. Optional: the handler falls back to today's date
    # when it is missing or empty.
    date: str | None = None
def build_config():
    """Return a copy of ``DEFAULT_CONFIG`` set up for Anthropic/Claude.

    The two model names can be overridden through the ``DEEP_THINK_MODEL``
    and ``QUICK_THINK_MODEL`` environment variables. Debate and risk
    discussions are limited to one round, and every data vendor category is
    routed to ``yfinance``.
    """
    overrides = {
        "llm_provider": "anthropic",
        "deep_think_llm": os.getenv("DEEP_THINK_MODEL", "claude-sonnet-4-6"),
        "quick_think_llm": os.getenv(
            "QUICK_THINK_MODEL", "claude-haiku-4-5-20251001"
        ),
        "backend_url": None,
        "max_debate_rounds": 1,
        "max_risk_discuss_rounds": 1,
        "data_vendors": {
            "core_stock_apis": "yfinance",
            "technical_indicators": "yfinance",
            "fundamental_data": "yfinance",
            "news_data": "yfinance",
        },
    }
    # Shallow copy is enough: every override replaces a top-level key, so the
    # shared defaults are never mutated.
    settings = DEFAULT_CONFIG.copy()
    settings.update(overrides)
    return settings
def get_stats_dict(stats_handler, buf, start_time):
    """Assemble the progress/usage snapshot attached to SSE events.

    Args:
        stats_handler: Object whose ``get_stats()`` returns a mapping with
            ``llm_calls``, ``tool_calls``, ``tokens_in`` and ``tokens_out``.
        buf: Buffer exposing ``agent_status``, ``report_sections`` and
            ``get_completed_reports_count()``.
        start_time: ``time.time()`` value taken when the run started.

    Returns:
        A dict with agent/report progress, LLM usage counters and the elapsed
        wall-clock seconds rounded to one decimal.
    """
    usage = stats_handler.get_stats()
    finished_agents = sum(
        status == "completed" for status in buf.agent_status.values()
    )
    seconds_running = time.time() - start_time

    snapshot = {
        "agents_done": finished_agents,
        "agents_total": len(buf.agent_status),
    }
    for counter in ("llm_calls", "tool_calls", "tokens_in", "tokens_out"):
        snapshot[counter] = usage[counter]
    snapshot["reports_done"] = buf.get_completed_reports_count()
    snapshot["reports_total"] = len(buf.report_sections)
    snapshot["elapsed"] = round(seconds_running, 1)
    return snapshot
def _agent_stage(agent_name):
"""Map agent name to pipeline stage."""
if agent_name in ("Market Analyst", "Social Analyst", "News Analyst", "Fundamentals Analyst"):
return "analysts"
if agent_name in ("Bull Researcher", "Bear Researcher", "Research Manager"):
return "research"
if agent_name == "Trader":
return "trading"
if agent_name in ("Aggressive Analyst", "Conservative Analyst", "Neutral Analyst"):
return "risk"
if agent_name == "Portfolio Manager":
return "decision"
return "unknown"
# Agents whose status follows the bull/bear investment debate.
_RESEARCH_TEAM = ("Bull Researcher", "Bear Researcher", "Research Manager")

# Agents marked completed once the risk debate has a judge decision.
_RISK_TEAM = (
    "Aggressive Analyst",
    "Conservative Analyst",
    "Neutral Analyst",
    "Portfolio Manager",
)

# Graph-state report field -> (agent name, pipeline stage).
_ANALYST_REPORTS = {
    "market_report": ("Market Analyst", "analysts"),
    "sentiment_report": ("Social Analyst", "analysts"),
    "news_report": ("News Analyst", "analysts"),
    "fundamentals_report": ("Fundamentals Analyst", "analysts"),
}


def _state_text(state, key):
    """Return ``state[key]`` stripped; a missing or ``None`` value counts as empty."""
    return (state.get(key) or "").strip()


def _set_statuses(buf, agents, status):
    """Set every agent in ``agents`` to ``status`` on this request's buffer."""
    for agent in agents:
        buf.update_agent_status(agent, status)


def _record_last_message(buf, chunk):
    """Copy the newest message of ``chunk`` (and its tool calls) into ``buf`` once."""
    messages = chunk.get("messages")
    if not messages:
        return
    last_msg = messages[-1]
    msg_id = getattr(last_msg, "id", None)
    if msg_id == buf._last_message_id:
        return  # Same message as on the previous chunk; already recorded.
    buf._last_message_id = msg_id

    msg_type, content = classify_message_type(last_msg)
    if content and content.strip():
        buf.add_message(msg_type, content)

    for tc in getattr(last_msg, "tool_calls", None) or ():
        if isinstance(tc, dict):
            buf.add_tool_call(tc["name"], tc["args"])
        else:
            buf.add_tool_call(tc.name, tc.args)


async def _emit_status_changes(q, buf, prev_statuses, stats):
    """Queue one ``agent_update`` event per agent whose status changed."""
    for agent, status in buf.agent_status.items():
        if prev_statuses.get(agent) != status:
            prev_statuses[agent] = status
            await q.put({
                "type": "agent_update",
                "agent": agent,
                "stage": _agent_stage(agent),
                "status": status,
                "stats": stats,
            })


async def _run_pipeline(q, ticker, trade_date):
    """Run the agent graph for one request and queue an event per milestone."""
    config = build_config()
    stats_handler = StatsCallbackHandler()
    selected_analysts = ["market", "social", "news", "fundamentals"]
    try:
        graph = TradingAgentsGraph(
            selected_analysts=selected_analysts,
            debug=False,
            config=config,
            callbacks=[stats_handler],
        )
    except Exception as e:
        await q.put({"type": "error", "message": f"Init failed: {e}"})
        return

    buf = MessageBuffer()
    buf.init_for_analysis(selected_analysts)
    init_state = graph.propagator.create_initial_state(ticker, trade_date)
    args = graph.propagator.get_graph_args(callbacks=[stats_handler])
    start_time = time.time()

    def snapshot():
        return get_stats_dict(stats_handler, buf, start_time)

    emitted_reports = set()
    research_emitted = False
    trader_emitted = False
    risk_emitted = False
    final_state = None
    prev_statuses = {}

    async for chunk in graph.graph.astream(init_state, **args):
        final_state = chunk
        _record_last_message(buf, chunk)
        update_analyst_statuses(buf, chunk)
        st = snapshot()

        # Emit agent status changes only (avoid flooding).
        await _emit_status_changes(q, buf, prev_statuses, st)

        # Analyst reports: each field is sent once, the first time it has content.
        for field, (agent_name, stage) in _ANALYST_REPORTS.items():
            if field not in emitted_reports and chunk.get(field):
                emitted_reports.add(field)
                await q.put({
                    "type": "report",
                    "agent": agent_name,
                    "stage": stage,
                    "field": field,
                    "report": chunk[field],
                    "stats": st,
                })

        # Research debate. Statuses are written to this request's own buffer:
        # update_research_team_status() is called with only a status, so it has
        # no way to reach `buf`.
        debate = chunk.get("investment_debate_state")
        if debate:
            bull = _state_text(debate, "bull_history")
            bear = _state_text(debate, "bear_history")
            judge = _state_text(debate, "judge_decision")
            if judge:
                if not research_emitted:
                    research_emitted = True
                    _set_statuses(buf, _RESEARCH_TEAM, "completed")
                    buf.update_agent_status("Trader", "in_progress")
                    await q.put({
                        "type": "debate",
                        "stage": "research",
                        "bull": bull,
                        "bear": bear,
                        "judge": judge,
                        "stats": snapshot(),
                    })
            elif bull or bear:
                # Only while undecided, so later chunks that still carry the
                # debate history cannot flip a finished team back.
                _set_statuses(buf, _RESEARCH_TEAM, "in_progress")

        # Trader plan
        plan = chunk.get("trader_investment_plan")
        if plan and not trader_emitted:
            trader_emitted = True
            buf.update_agent_status("Trader", "completed")
            buf.update_agent_status("Aggressive Analyst", "in_progress")
            await q.put({
                "type": "trader",
                "stage": "trading",
                "plan": plan,
                "stats": snapshot(),
            })

        # Risk debate
        risk = chunk.get("risk_debate_state")
        if risk:
            agg = _state_text(risk, "aggressive_history")
            con = _state_text(risk, "conservative_history")
            neu = _state_text(risk, "neutral_history")
            judge = _state_text(risk, "judge_decision")
            if judge:
                if not risk_emitted:
                    risk_emitted = True
                    _set_statuses(buf, _RISK_TEAM, "completed")
                    await q.put({
                        "type": "risk",
                        "stage": "risk",
                        "aggressive": agg,
                        "conservative": con,
                        "neutral": neu,
                        "judge": judge,
                        "stats": snapshot(),
                    })
            else:
                if agg:
                    buf.update_agent_status("Aggressive Analyst", "in_progress")
                if con:
                    buf.update_agent_status("Conservative Analyst", "in_progress")
                if neu:
                    buf.update_agent_status("Neutral Analyst", "in_progress")

        # Second pass: the debate/trader/risk sections above change statuses
        # after the first pass ran. Without this the client would see them one
        # chunk late, or never for the last chunk.
        await _emit_status_changes(q, buf, prev_statuses, snapshot())

    # Final decision
    if not final_state:
        return
    decision_text = final_state.get("final_trade_decision", "No decision reached.")
    # NOTE(review): process_signal is called synchronously; if it performs a
    # blocking LLM request it stalls the event loop. Consider asyncio.to_thread.
    signal = graph.process_signal(decision_text)
    _set_statuses(buf, list(buf.agent_status), "completed")
    st = snapshot()
    await _emit_status_changes(q, buf, prev_statuses, st)
    await q.put({
        "type": "decision",
        "stage": "decision",
        "signal": signal,
        "decision_text": decision_text,
        "stats": st,
    })


async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
    """Background task that runs the TradingAgents pipeline and pushes SSE events.

    Events are dicts put on ``analyses[analysis_id]``. Their ``type`` is one of
    ``agent_update``, ``report``, ``debate``, ``trader``, ``risk``,
    ``decision`` or ``error``. The stream always ends with a ``None``
    sentinel, including on failure and on task cancellation, so the consumer
    never waits on a run that has already stopped.

    Args:
        analysis_id: Key of the already-registered queue in ``analyses``.
        ticker: Ticker symbol handed to the graph's initial state.
        trade_date: Trade date handed to the graph's initial state.

    Raises:
        KeyError: If ``analysis_id`` has no registered queue.
    """
    q = analyses[analysis_id]
    try:
        await _run_pipeline(q, ticker, trade_date)
    except Exception as e:
        await q.put({"type": "error", "message": str(e)})
    finally:
        # put_nowait cannot suspend, so the sentinel is delivered even while
        # the task is being cancelled; the queue is created without a maxsize.
        q.put_nowait(None)
# Strong references to in-flight analysis tasks. The event loop only keeps
# weak references to tasks, so a task nobody else holds could be
# garbage-collected before it finishes.
_analysis_tasks: set[asyncio.Task] = set()


@app.post("/analyze")
async def start_analysis(req: AnalyzeRequest):
    """Validate the request, start the analysis in the background and return its id.

    Args:
        req: Ticker plus an optional ISO-format trade date; a missing or empty
            date means today.

    Returns:
        ``{"id", "ticker", "date"}`` where ``id`` is the value to use with the
        stream endpoint.

    Raises:
        HTTPException: 400 when the ticker is empty or longer than five
            characters, or when the date is not a valid ISO calendar date.
    """
    ticker = req.ticker.upper().strip()
    if not ticker or len(ticker) > 5:
        raise HTTPException(400, "Invalid ticker")

    if req.date:
        # Reject malformed dates here, like the ticker, instead of letting
        # them fail somewhere inside the background run.
        try:
            trade_date = date.fromisoformat(req.date.strip()).isoformat()
        except ValueError:
            raise HTTPException(400, "Invalid date, expected YYYY-MM-DD") from None
    else:
        trade_date = str(date.today())

    analysis_id = str(uuid.uuid4())
    analyses[analysis_id] = asyncio.Queue()

    task = asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
    _analysis_tasks.add(task)
    task.add_done_callback(_analysis_tasks.discard)

    return {"id": analysis_id, "ticker": ticker, "date": trade_date}
@app.get("/analyze/{analysis_id}/stream")
async def stream_analysis(analysis_id: str):
    """Stream the queued events of one analysis as Server-Sent Events.

    Each queued event is sent as JSON in the ``data`` field. A ``heartbeat``
    event with empty data is sent whenever 15 seconds pass without a new
    event. A ``None`` item in the queue ends the stream, after which the
    analysis is removed from the registry.

    Raises:
        HTTPException: 404 when ``analysis_id`` is not registered.
    """
    queue = analyses.get(analysis_id)
    if queue is None:
        raise HTTPException(404, "Analysis not found")

    async def event_generator():
        while True:
            try:
                item = await asyncio.wait_for(queue.get(), timeout=15)
            except asyncio.TimeoutError:
                # Nothing new within the window: emit a heartbeat instead.
                yield {"event": "heartbeat", "data": ""}
            else:
                if item is None:
                    break  # Sentinel: the producer has finished.
                yield {"data": json.dumps(item)}
        analyses.pop(analysis_id, None)

    return EventSourceResponse(event_generator())
@app.get("/health")
async def health():
    """Liveness probe; always answers with a fixed ``ok`` status."""
    payload = {"status": "ok"}
    return payload