304 lines
11 KiB
Python
304 lines
11 KiB
Python
"""FastAPI SSE backend for TradingAgents."""
|
|
|
|
import asyncio
import copy
import json
import logging
import os
import re
import time
import uuid
from datetime import date
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
from sse_starlette.sse import EventSourceResponse
|
|
|
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
from tradingagents.default_config import DEFAULT_CONFIG
|
|
from cli.stats_handler import StatsCallbackHandler
|
|
from cli.main import (
|
|
MessageBuffer,
|
|
classify_message_type,
|
|
update_analyst_statuses,
|
|
update_research_team_status,
|
|
)
|
|
|
|
app = FastAPI(title="TradingAgents API")

# CORS is fully open: any origin, any method, any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# In-flight analyses keyed by analysis id. Each queue carries the event dicts
# produced by run_analysis, terminated by a None sentinel.
analyses: dict[str, asyncio.Queue] = {}
|
|
|
|
|
|
class AnalyzeRequest(BaseModel):
    """JSON body accepted by ``POST /analyze``."""

    # Ticker symbol as sent by the client; the endpoint normalises and validates it.
    ticker: str
    # Optional trade-date string; when omitted (or empty) the endpoint uses today's date.
    date: str | None = None
|
|
|
|
|
|
def build_config():
    """Build a fresh TradingAgents config for Anthropic/Claude with yfinance data.

    The config starts from a *deep* copy of ``DEFAULT_CONFIG``. A shallow copy
    would share every nested dict with the module-level defaults, so any
    downstream mutation of such a dict would leak into the defaults and into
    concurrently running analyses.

    Environment overrides:
        DEEP_THINK_MODEL: model for deep-thinking agents (default ``claude-sonnet-4-6``).
        QUICK_THINK_MODEL: model for quick-thinking agents
            (default ``claude-haiku-4-5-20251001``).

    Returns:
        dict: a new config dict owned by the caller.
    """
    config = copy.deepcopy(DEFAULT_CONFIG)
    config["llm_provider"] = "anthropic"
    config["deep_think_llm"] = os.getenv("DEEP_THINK_MODEL", "claude-sonnet-4-6")
    config["quick_think_llm"] = os.getenv("QUICK_THINK_MODEL", "claude-haiku-4-5-20251001")
    # No custom endpoint: let the provider client use its default URL.
    config["backend_url"] = None
    # Single round for both debates keeps each analysis short.
    config["max_debate_rounds"] = 1
    config["max_risk_discuss_rounds"] = 1
    config["data_vendors"] = {
        "core_stock_apis": "yfinance",
        "technical_indicators": "yfinance",
        "fundamental_data": "yfinance",
        "news_data": "yfinance",
    }
    return config
|
|
|
|
|
|
def get_stats_dict(stats_handler, buf, start_time):
    """Return a progress snapshot for the ``stats`` field of an SSE event.

    Combines the LLM/tool/token counters from ``stats_handler.get_stats()`` with
    agent and report progress from the message buffer. It also includes the
    wall-clock seconds elapsed since *start_time*, rounded to one decimal.
    """
    counters = stats_handler.get_stats()
    statuses = list(buf.agent_status.values())
    completed_agents = len([status for status in statuses if status == "completed"])
    seconds = time.time() - start_time
    return dict(
        agents_done=completed_agents,
        agents_total=len(statuses),
        llm_calls=counters["llm_calls"],
        tool_calls=counters["tool_calls"],
        tokens_in=counters["tokens_in"],
        tokens_out=counters["tokens_out"],
        reports_done=buf.get_completed_reports_count(),
        reports_total=len(buf.report_sections),
        elapsed=round(seconds, 1),
    )
|
|
|
|
|
|
def _agent_stage(agent_name):
|
|
"""Map agent name to pipeline stage."""
|
|
if agent_name in ("Market Analyst", "Social Analyst", "News Analyst", "Fundamentals Analyst"):
|
|
return "analysts"
|
|
if agent_name in ("Bull Researcher", "Bear Researcher", "Research Manager"):
|
|
return "research"
|
|
if agent_name == "Trader":
|
|
return "trading"
|
|
if agent_name in ("Aggressive Analyst", "Conservative Analyst", "Neutral Analyst"):
|
|
return "risk"
|
|
if agent_name == "Portfolio Manager":
|
|
return "decision"
|
|
return "unknown"
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# State fields holding the four analyst reports, mapped to (agent name, stage).
_REPORT_FIELDS = {
    "market_report": ("Market Analyst", "analysts"),
    "sentiment_report": ("Social Analyst", "analysts"),
    "news_report": ("News Analyst", "analysts"),
    "fundamentals_report": ("Fundamentals Analyst", "analysts"),
}

# Agents whose status is driven by the investment (bull/bear) debate.
_RESEARCH_TEAM = ("Bull Researcher", "Bear Researcher", "Research Manager")

# Agents whose status is driven by the risk debate, paired with their history fields.
_RISK_TEAM = ("Aggressive Analyst", "Conservative Analyst", "Neutral Analyst")
_RISK_HISTORY_FIELDS = ("aggressive_history", "conservative_history", "neutral_history")


def _stripped(mapping, key):
    """Return ``mapping[key]`` stripped; a missing or ``None`` value yields ``""``."""
    return (mapping.get(key) or "").strip()


def _set_statuses(buf, agents, status):
    """Set the same *status* on every agent in *agents*."""
    for agent in agents:
        buf.update_agent_status(agent, status)


def _record_latest_message(buf, chunk):
    """Add the newest message of *chunk*, and its tool calls, to *buf* once per message id."""
    messages = chunk.get("messages")
    if not messages:
        return
    last_msg = messages[-1]
    msg_id = getattr(last_msg, "id", None)
    if msg_id == buf._last_message_id:
        return
    buf._last_message_id = msg_id

    msg_type, content = classify_message_type(last_msg)
    if content and content.strip():
        buf.add_message(msg_type, content)

    # Tool calls may be plain dicts or objects exposing .name / .args.
    for tc in getattr(last_msg, "tool_calls", None) or ():
        if isinstance(tc, dict):
            buf.add_tool_call(tc["name"], tc["args"])
        else:
            buf.add_tool_call(tc.name, tc.args)


async def _emit_status_changes(q, buf, prev_statuses, stats):
    """Queue one ``agent_update`` event per agent whose status differs from *prev_statuses*.

    *prev_statuses* is updated in place, so each change is reported once. This
    avoids flooding the stream with unchanged statuses.
    """
    for agent, status in buf.agent_status.items():
        if prev_statuses.get(agent) != status:
            prev_statuses[agent] = status
            await q.put({
                "type": "agent_update",
                "agent": agent,
                "stage": _agent_stage(agent),
                "status": status,
                "stats": stats,
            })


async def _stream_pipeline(q, graph, stats_handler, selected_analysts, ticker, trade_date):
    """Run *graph* for *ticker*/*trade_date*, pushing progress events and then the decision to *q*.

    Exceptions propagate to the caller, which reports them on the stream.
    """
    buf = MessageBuffer()
    buf.init_for_analysis(selected_analysts)
    init_state = graph.propagator.create_initial_state(ticker, trade_date)
    args = graph.propagator.get_graph_args(callbacks=[stats_handler])
    start_time = time.time()

    def stats():
        return get_stats_dict(stats_handler, buf, start_time)

    emitted_reports = set()
    research_emitted = trader_emitted = risk_emitted = False
    prev_statuses = {}
    final_state = None

    async for chunk in graph.graph.astream(init_state, **args):
        final_state = chunk

        _record_latest_message(buf, chunk)
        update_analyst_statuses(buf, chunk)
        st = stats()
        await _emit_status_changes(q, buf, prev_statuses, st)

        # Analyst reports: each is sent once, the first time it is non-empty.
        for field, (agent_name, stage) in _REPORT_FIELDS.items():
            if field not in emitted_reports and chunk.get(field):
                emitted_reports.add(field)
                await q.put({
                    "type": "report",
                    "agent": agent_name,
                    "stage": stage,
                    "field": field,
                    "report": chunk[field],
                    "stats": st,
                })

        # Research debate. update_research_team_status() receives no buffer, so it
        # cannot update this analysis's `buf`; set the team's statuses directly.
        debate = chunk.get("investment_debate_state")
        if debate and not research_emitted:
            bull = _stripped(debate, "bull_history")
            bear = _stripped(debate, "bear_history")
            judge = _stripped(debate, "judge_decision")
            if bull or bear:
                _set_statuses(buf, _RESEARCH_TEAM, "in_progress")
            if judge:
                research_emitted = True
                _set_statuses(buf, _RESEARCH_TEAM, "completed")
                buf.update_agent_status("Trader", "in_progress")
                await q.put({
                    "type": "debate",
                    "stage": "research",
                    "bull": bull,
                    "bear": bear,
                    "judge": judge,
                    "stats": stats(),
                })

        # Trader plan (sent once).
        plan = chunk.get("trader_investment_plan")
        if plan and not trader_emitted:
            trader_emitted = True
            buf.update_agent_status("Trader", "completed")
            buf.update_agent_status("Aggressive Analyst", "in_progress")
            await q.put({
                "type": "trader",
                "stage": "trading",
                "plan": plan,
                "stats": stats(),
            })

        # Risk debate. Once the judge has ruled, later chunks must not flip the
        # analysts back to in_progress.
        risk = chunk.get("risk_debate_state")
        if risk and not risk_emitted:
            agg, con, neu = (_stripped(risk, f) for f in _RISK_HISTORY_FIELDS)
            judge = _stripped(risk, "judge_decision")
            for agent, history in zip(_RISK_TEAM, (agg, con, neu)):
                if history:
                    buf.update_agent_status(agent, "in_progress")
            if judge:
                risk_emitted = True
                _set_statuses(buf, _RISK_TEAM, "completed")
                buf.update_agent_status("Portfolio Manager", "completed")
                await q.put({
                    "type": "risk",
                    "stage": "risk",
                    "aggressive": agg,
                    "conservative": con,
                    "neutral": neu,
                    "judge": judge,
                    "stats": stats(),
                })

    if not final_state:
        return

    # Final decision. process_signal() is synchronous; run it in a worker thread
    # so a slow call cannot stall the event loop (and other clients' streams).
    decision_text = final_state.get("final_trade_decision") or "No decision reached."
    signal = await asyncio.to_thread(graph.process_signal, decision_text)
    _set_statuses(buf, list(buf.agent_status), "completed")
    st = stats()
    # Report status changes made after the last per-chunk flush.
    await _emit_status_changes(q, buf, prev_statuses, st)
    await q.put({
        "type": "decision",
        "stage": "decision",
        "signal": signal,
        "decision_text": decision_text,
        "stats": st,
    })


async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
    """Background task: run the TradingAgents pipeline and push SSE events to the analysis queue.

    Queued events are dicts whose ``type`` is one of ``agent_update``, ``report``,
    ``debate``, ``trader``, ``risk``, ``decision`` or ``error``. A ``None``
    sentinel is always queued last. This holds even when setup, the pipeline or
    signal processing raises, or the task is cancelled, so the stream consumer
    never waits forever.
    """
    q = analyses[analysis_id]
    selected_analysts = ["market", "social", "news", "fundamentals"]
    try:
        try:
            config = build_config()
            stats_handler = StatsCallbackHandler()
            graph = TradingAgentsGraph(
                selected_analysts=selected_analysts,
                debug=False,
                config=config,
                callbacks=[stats_handler],
            )
        except Exception as e:
            logger.exception("Analysis %s: initialisation failed", analysis_id)
            await q.put({"type": "error", "message": f"Init failed: {e}"})
            return

        try:
            await _stream_pipeline(q, graph, stats_handler, selected_analysts, ticker, trade_date)
        except Exception as e:
            logger.exception("Analysis %s failed", analysis_id)
            await q.put({"type": "error", "message": str(e)})
    finally:
        # The queue is created without maxsize (unbounded), so put_nowait never
        # suspends and the sentinel is delivered even during cancellation.
        q.put_nowait(None)
|
|
|
|
|
|
# Strong references to running analysis tasks. The event loop keeps only weak
# references, so an otherwise unreferenced task can be garbage-collected mid-run.
_background_tasks: set[asyncio.Task] = set()

# Accepted (upper-cased) ticker: 1-5 characters drawn from letters, digits and
# the punctuation used in ticker symbols (optional leading "^" for indices).
_TICKER_RE = re.compile(r"[A-Z0-9^][A-Z0-9.\-=]{0,4}")


@app.post("/analyze")
async def start_analysis(req: AnalyzeRequest):
    """Validate the request, start a background analysis, and return its id.

    The ticker is upper-cased and stripped. The date must be ISO ``YYYY-MM-DD``
    and defaults to today when omitted or empty.

    Returns:
        dict: ``{"id", "ticker", "date"}``. Stream the events from
        ``/analyze/{id}/stream``.

    Raises:
        HTTPException: 400 for an invalid ticker or date.
    """
    ticker = req.ticker.upper().strip()
    # The ticker is untrusted client input passed on to the data vendors;
    # only allow characters that occur in ticker symbols.
    if not _TICKER_RE.fullmatch(ticker):
        raise HTTPException(400, "Invalid ticker")

    if req.date:
        try:
            trade_date = date.fromisoformat(req.date).isoformat()
        except ValueError:
            raise HTTPException(400, "Invalid date; expected YYYY-MM-DD") from None
    else:
        trade_date = str(date.today())

    analysis_id = str(uuid.uuid4())
    analyses[analysis_id] = asyncio.Queue()
    task = asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return {"id": analysis_id, "ticker": ticker, "date": trade_date}
|
|
|
|
|
|
@app.get("/analyze/{analysis_id}/stream")
async def stream_analysis(analysis_id: str):
    """Stream the queued events of one analysis as Server-Sent Events.

    Each event dict is sent JSON-encoded in the ``data`` field. If nothing
    arrives within 15 seconds, a ``heartbeat`` event with empty data is sent
    instead. The stream ends at the ``None`` sentinel, after which the analysis
    is removed from the registry.

    Raises:
        HTTPException: 404 when *analysis_id* is unknown.
    """
    queue = analyses.get(analysis_id)
    if queue is None:
        raise HTTPException(404, "Analysis not found")

    async def _pump():
        finished = False
        while not finished:
            try:
                item = await asyncio.wait_for(queue.get(), timeout=15)
            except asyncio.TimeoutError:
                yield {"event": "heartbeat", "data": ""}
                continue
            if item is None:
                finished = True
            else:
                yield {"data": json.dumps(item)}
        # NOTE(review): only reached once the sentinel arrives; if the client
        # disconnects earlier, the entry stays in `analyses`.
        analyses.pop(analysis_id, None)

    return EventSourceResponse(_pump())
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: always reports the service as up."""
    payload = {"status": "ok"}
    return payload
|