"""FastAPI SSE backend for TradingAgents.

Routes:

* ``POST /analyze`` starts an analysis in the background and returns its id.
* ``GET /analyze/{analysis_id}/stream`` streams the analysis events as
  Server-Sent Events; ``?last_event=N`` replays everything after event N.
* ``GET /health`` is a liveness probe.

Every event of an analysis is appended to an in-memory, append-only log.
Stream connections read from that log by index, so the SSE ``id`` of an event
is always its 1-based position in the log, whatever the moment the client
connected and however many clients are connected.
"""

import asyncio
import hmac
import json
import os
import time
import traceback as _tb
import uuid
from datetime import date

from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from cli.stats_handler import StatsCallbackHandler
from cli.main import (
    MessageBuffer,
    classify_message_type,
    update_analyst_statuses,
)

app = FastAPI(title="TradingAgents API")

# --- CORS ---
_cors_env = os.getenv("CORS_ORIGINS", "")
_cors_origins = (
    [o.strip() for o in _cors_env.split(",") if o.strip()] if _cors_env else ["*"]
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Auth dependency ---
_API_KEY = os.getenv("AGENTS_API_KEY", "")


async def verify_api_key(request: Request):
    """Reject the request unless it carries ``Authorization: Bearer <key>``.

    When ``AGENTS_API_KEY`` is unset or empty, authentication is disabled
    (dev mode) and every request is accepted.

    Raises:
        HTTPException: 401 when the header is missing or does not match.
    """
    if not _API_KEY:
        return  # dev mode — no auth
    supplied = request.headers.get("Authorization", "")
    expected = f"Bearer {_API_KEY}"
    # Constant-time comparison so response timing does not leak how much of
    # the key matched. Compared as bytes because compare_digest() rejects
    # non-ASCII str, and a header value is attacker-controlled.
    if not hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(401, "Invalid or missing API key")


# --- Concurrency ---
MAX_CONCURRENT = int(os.getenv("MAX_CONCURRENT_ANALYSES", "3"))
_semaphore = asyncio.Semaphore(MAX_CONCURRENT)

_ANALYSIS_TIMEOUT_SECONDS = 600  # 10 minutes, counted once the semaphore is held
_HEARTBEAT_SECONDS = 15
_CLEANUP_INTERVAL_SECONDS = 300
_ANALYSIS_TTL_SECONDS = 1800

# Active analysis state:
#   id -> {events (append-only log, also the replay buffer), done, created_at,
#          changed (asyncio.Event set whenever events/done change)}
analyses: dict[str, dict] = {}

# asyncio only keeps weak references to tasks; hold strong ones here so a
# running analysis or the cleanup loop cannot be garbage-collected mid-flight.
_background_tasks: set = set()

_ANALYST_NAMES = {
    "market": "Market Analyst",
    "social": "Social Analyst",
    "news": "News Analyst",
    "fundamentals": "Fundamentals Analyst",
}

# Graph state field -> (agent, stage) for the per-analyst reports.
_REPORT_FIELDS = {
    "market_report": ("Market Analyst", "analysts"),
    "sentiment_report": ("Social Analyst", "analysts"),
    "news_report": ("News Analyst", "analysts"),
    "fundamentals_report": ("Fundamentals Analyst", "analysts"),
}

_RESEARCH_TEAM = ("Bull Researcher", "Bear Researcher", "Research Manager")
_RISK_TEAM = ("Aggressive Analyst", "Conservative Analyst", "Neutral Analyst")

_AGENT_STAGES = {
    **{name: "analysts" for name in _ANALYST_NAMES.values()},
    **{name: "research" for name in _RESEARCH_TEAM},
    "Trader": "trading",
    **{name: "risk" for name in _RISK_TEAM},
    "Portfolio Manager": "decision",
}


def _spawn(coro):
    """Schedule *coro* as a task and keep a strong reference until it ends."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


def _new_analysis_state() -> dict:
    """Return the bookkeeping dict stored in ``analyses`` for one analysis."""
    return {
        "events": [],
        "done": False,
        "created_at": time.time(),
        "changed": asyncio.Event(),
    }


def _notify(state: dict) -> None:
    """Wake every stream currently waiting on this analysis.

    The Event is swapped for a fresh one before being set, so a waiter that
    grabs ``state["changed"]`` after this call waits for the *next* change,
    while waiters holding the old Event are released.
    """
    changed, state["changed"] = state["changed"], asyncio.Event()
    changed.set()


def _emit(state: dict, evt: dict) -> None:
    """Append *evt* to the analysis log and wake waiting streams."""
    state["events"].append(evt)
    _notify(state)


def _finish(state: dict) -> None:
    """Mark the analysis finished and wake waiting streams. Idempotent."""
    state["done"] = True
    _notify(state)


class AnalyzeRequest(BaseModel):
    """Body of ``POST /analyze``."""

    ticker: str
    date: str | None = None


def build_config():
    """Build TradingAgents config for Anthropic/Claude.

    Starts from a shallow copy of ``DEFAULT_CONFIG`` and overrides provider,
    models (overridable via ``DEEP_THINK_MODEL`` / ``QUICK_THINK_MODEL``),
    debate round counts and data vendors.
    """
    config = DEFAULT_CONFIG.copy()
    config["llm_provider"] = "anthropic"
    config["deep_think_llm"] = os.getenv("DEEP_THINK_MODEL", "claude-sonnet-4-6")
    config["quick_think_llm"] = os.getenv(
        "QUICK_THINK_MODEL", "claude-haiku-4-5-20251001"
    )
    config["backend_url"] = None
    config["max_debate_rounds"] = 1
    config["max_risk_discuss_rounds"] = 1
    config["data_vendors"] = {
        "core_stock_apis": "yfinance",
        "technical_indicators": "yfinance",
        "fundamental_data": "yfinance",
        "news_data": "yfinance",
    }
    config["parallel_analysts"] = True
    return config


def get_stats_dict(stats_handler, buf, start_time):
    """Build stats dict for SSE events.

    Combines the handler's LLM/tool/token counters with agent and report
    progress read from *buf* and the seconds elapsed since *start_time*.
    """
    s = stats_handler.get_stats()
    agents_done = sum(1 for v in buf.agent_status.values() if v == "completed")
    elapsed = time.time() - start_time
    return {
        "agents_done": agents_done,
        "agents_total": len(buf.agent_status),
        "llm_calls": s["llm_calls"],
        "tool_calls": s["tool_calls"],
        "tokens_in": s["tokens_in"],
        "tokens_out": s["tokens_out"],
        "reports_done": buf.get_completed_reports_count(),
        "reports_total": len(buf.report_sections),
        "elapsed": round(elapsed, 1),
    }


def _agent_stage(agent_name):
    """Map agent name to pipeline stage (``"unknown"`` if not recognised)."""
    return _AGENT_STAGES.get(agent_name, "unknown")


async def _run_analysis_inner(analysis_id: str, ticker: str, trade_date: str):
    """Core analysis logic: run the agent graph and log its progress events.

    Events are appended to ``analyses[analysis_id]["events"]`` via ``_emit``.
    This coroutine does NOT mark the analysis as done; ``run_analysis`` does
    that on every exit path, including exceptions escaping from here.
    """
    state = analyses[analysis_id]
    config = build_config()
    stats_handler = StatsCallbackHandler()
    selected_analysts = ["market", "social", "news", "fundamentals"]

    try:
        graph = TradingAgentsGraph(
            selected_analysts=selected_analysts,
            debug=False,
            config=config,
            callbacks=[stats_handler],
        )
    except Exception as e:
        print(f"[ANALYSIS] Init failed: {e}\n{_tb.format_exc()}", flush=True)
        _emit(state, {"type": "error", "message": f"Init failed: {e}"})
        return

    buf = MessageBuffer()
    buf.init_for_analysis(selected_analysts)

    init_state = graph.propagator.create_initial_state(ticker, trade_date)
    args = graph.propagator.get_graph_args(callbacks=[stats_handler])

    start_time = time.time()
    emitted_reports = set()
    research_emitted = False
    trader_emitted = False
    risk_emitted = False
    final_state = None
    prev_statuses = {}

    def stats():
        return get_stats_dict(stats_handler, buf, start_time)

    # Emit all analysts as "in_progress" immediately (they run in parallel)
    for analyst_type in selected_analysts:
        agent_name = _ANALYST_NAMES[analyst_type]
        buf.update_agent_status(agent_name, "in_progress")
        _emit(state, {
            "type": "agent_update",
            "agent": agent_name,
            "stage": "analysts",
            "status": "in_progress",
            "stats": stats(),
        })
        prev_statuses[agent_name] = "in_progress"

    try:
        async for chunk in graph.graph.astream(init_state, **args):
            final_state = chunk

            # Process messages (same logic as Chainlit app)
            messages = chunk.get("messages")
            if messages:
                last_msg = messages[-1]
                msg_id = getattr(last_msg, "id", None)
                if msg_id != buf._last_message_id:
                    buf._last_message_id = msg_id
                    msg_type, content = classify_message_type(last_msg)
                    if content and content.strip():
                        buf.add_message(msg_type, content)
                    for tc in getattr(last_msg, "tool_calls", None) or ():
                        if isinstance(tc, dict):
                            buf.add_tool_call(tc["name"], tc["args"])
                        else:
                            buf.add_tool_call(tc.name, tc.args)

            update_analyst_statuses(buf, chunk)
            st = stats()

            # Emit agent status changes only (avoid flooding)
            for agent, status in buf.agent_status.items():
                if prev_statuses.get(agent) != status:
                    prev_statuses[agent] = status
                    _emit(state, {
                        "type": "agent_update",
                        "agent": agent,
                        "stage": _agent_stage(agent),
                        "status": status,
                        "stats": st,
                    })

            # Analyst reports (each one is emitted once, when first non-empty)
            for field, (agent_name, stage) in _REPORT_FIELDS.items():
                if field not in emitted_reports and chunk.get(field):
                    emitted_reports.add(field)
                    _emit(state, {
                        "type": "report",
                        "agent": agent_name,
                        "stage": stage,
                        "field": field,
                        "report": chunk[field],
                        "stats": st,
                    })

            # Research debate
            if chunk.get("investment_debate_state") and not research_emitted:
                debate = chunk["investment_debate_state"]
                bull = debate.get("bull_history", "").strip()
                bear = debate.get("bear_history", "").strip()
                judge = debate.get("judge_decision", "").strip()

                if bull or bear:
                    for a in _RESEARCH_TEAM:
                        buf.update_agent_status(a, "in_progress")

                if judge:
                    research_emitted = True
                    for a in _RESEARCH_TEAM:
                        buf.update_agent_status(a, "completed")
                    buf.update_agent_status("Trader", "in_progress")
                    buf.update_report_section("investment_plan", judge)
                    _emit(state, {
                        "type": "debate",
                        "stage": "research",
                        "bull": bull,
                        "bear": bear,
                        "judge": judge,
                        "stats": stats(),
                    })

            # Trader plan
            if chunk.get("trader_investment_plan") and not trader_emitted:
                trader_emitted = True
                buf.update_agent_status("Trader", "completed")
                for a in _RISK_TEAM:
                    buf.update_agent_status(a, "in_progress")
                buf.update_report_section(
                    "trader_investment_plan", chunk["trader_investment_plan"]
                )
                _emit(state, {
                    "type": "trader",
                    "stage": "trading",
                    "plan": chunk["trader_investment_plan"],
                    "stats": stats(),
                })

            # Risk debate
            if chunk.get("risk_debate_state") and not risk_emitted:
                risk = chunk["risk_debate_state"]
                agg = risk.get("aggressive_history", "").strip()
                con = risk.get("conservative_history", "").strip()
                neu = risk.get("neutral_history", "").strip()
                judge = risk.get("judge_decision", "").strip()

                if agg:
                    buf.update_agent_status("Aggressive Analyst", "in_progress")
                if con:
                    buf.update_agent_status("Conservative Analyst", "in_progress")
                if neu:
                    buf.update_agent_status("Neutral Analyst", "in_progress")

                if judge:
                    risk_emitted = True
                    for a in _RISK_TEAM:
                        buf.update_agent_status(a, "completed")
                    buf.update_agent_status("Portfolio Manager", "completed")
                    _emit(state, {
                        "type": "risk",
                        "stage": "risk",
                        "aggressive": agg,
                        "conservative": con,
                        "neutral": neu,
                        "judge": judge,
                        "stats": stats(),
                    })

    except Exception as e:
        print(f"[ANALYSIS] Stream error: {e}\n{_tb.format_exc()}", flush=True)
        _emit(state, {"type": "error", "message": str(e)})
        return

    # Final decision (skipped when the graph produced no chunk at all)
    if final_state:
        decision_text = final_state.get("final_trade_decision", "No decision reached.")
        signal = graph.process_signal(decision_text)

        buf.update_report_section("final_trade_decision", decision_text)
        # Iterate over a snapshot: update_agent_status writes to agent_status.
        for agent in list(buf.agent_status):
            buf.update_agent_status(agent, "completed")

        st = stats()

        for agent in buf.agent_status:
            if prev_statuses.get(agent) != "completed":
                prev_statuses[agent] = "completed"
                _emit(state, {
                    "type": "agent_update",
                    "agent": agent,
                    "stage": _agent_stage(agent),
                    "status": "completed",
                    "stats": st,
                })

        _emit(state, {
            "type": "decision",
            "stage": "decision",
            "signal": signal,
            "decision_text": decision_text,
            "stats": st,
        })


async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
    """Background task: acquires semaphore, runs analysis with timeout.

    Whatever happens (success, timeout, unexpected exception, cancellation),
    the analysis is marked done in ``finally`` so that connected streams
    terminate instead of waiting on heartbeats forever. Failures are recorded
    as an ``error`` event in the log, which also makes them replayable.
    """
    state = analyses[analysis_id]
    try:
        async with _semaphore:
            await asyncio.wait_for(
                _run_analysis_inner(analysis_id, ticker, trade_date),
                timeout=_ANALYSIS_TIMEOUT_SECONDS,
            )
    except asyncio.TimeoutError:
        print(f"[ANALYSIS] Timeout for {analysis_id}", flush=True)
        _emit(state, {"type": "error", "message": "Analysis timed out after 10 minutes"})
    except Exception as e:
        # Top-level boundary of the background task: nothing above us would
        # report this, so log it and surface it to the stream.
        print(f"[ANALYSIS] Failed for {analysis_id}: {e}\n{_tb.format_exc()}", flush=True)
        _emit(state, {"type": "error", "message": f"Analysis failed: {e}"})
    finally:
        _finish(state)


# --- Memory cleanup background task ---
async def _cleanup_loop():
    """Every 5 minutes, remove finished analyses older than 30 minutes.

    Unfinished analyses are kept (they may still be queued on the semaphore);
    ``run_analysis`` always marks them done eventually, after which they
    become eligible for removal.
    """
    while True:
        await asyncio.sleep(_CLEANUP_INTERVAL_SECONDS)
        now = time.time()
        expired = [
            aid
            for aid, s in analyses.items()
            if s["done"] and now - s["created_at"] > _ANALYSIS_TTL_SECONDS
        ]
        for aid in expired:
            analyses.pop(aid, None)
        if expired:
            print(f"[CLEANUP] Removed {len(expired)} expired analyses", flush=True)


@app.on_event("startup")
async def _start_cleanup():
    """Start the periodic cleanup task when the application starts."""
    _spawn(_cleanup_loop())


# --- Routes ---
@app.post("/analyze", dependencies=[Depends(verify_api_key)])
async def start_analysis(req: AnalyzeRequest):
    """Validate the request, start the analysis in the background, return its id.

    Raises:
        HTTPException: 400 when the ticker is not 1-5 ASCII letters or the
            date is not an ISO calendar date (YYYY-MM-DD).
    """
    ticker = req.ticker.upper().strip()
    if (
        not ticker
        or len(ticker) > 5
        or not ticker.isascii()  # isalpha() alone also accepts non-ASCII letters
        or not ticker.isalpha()
    ):
        raise HTTPException(400, "Invalid ticker")

    raw_date = (req.date or "").strip()
    if raw_date:
        try:
            trade_date = date.fromisoformat(raw_date).isoformat()
        except ValueError:
            raise HTTPException(400, "Invalid date, expected YYYY-MM-DD") from None
    else:
        trade_date = str(date.today())

    analysis_id = str(uuid.uuid4())
    analyses[analysis_id] = _new_analysis_state()
    _spawn(run_analysis(analysis_id, ticker, trade_date))
    return {"id": analysis_id, "ticker": ticker, "date": trade_date}


@app.get("/analyze/{analysis_id}/stream", dependencies=[Depends(verify_api_key)])
async def stream_analysis(analysis_id: str, last_event: int = 0):
    """Stream SSE events. Supports reconnection via ?last_event=N to replay missed events.

    Events are read from the analysis log by index, so each event is delivered
    exactly once per connection and its SSE ``id`` equals its 1-based position
    in the log. A ``heartbeat`` event is sent after 15 idle seconds.

    Raises:
        HTTPException: 404 when the analysis id is unknown (or was cleaned up).
    """
    state = analyses.get(analysis_id)
    if state is None:
        raise HTTPException(404, "Analysis not found")

    async def event_generator():
        events = state["events"]
        idx = max(last_event, 0)  # a negative value would index from the end

        while True:
            # Deliver everything logged so far (replay and live alike).
            while idx < len(events):
                evt = events[idx]
                idx += 1
                yield {"id": str(idx), "data": json.dumps(evt)}

            if state["done"]:
                return

            # No await between the checks above and grabbing the Event, so a
            # change cannot slip in unnoticed before we start waiting.
            changed = state["changed"]
            try:
                await asyncio.wait_for(changed.wait(), timeout=_HEARTBEAT_SECONDS)
            except asyncio.TimeoutError:
                yield {"event": "heartbeat", "data": ""}

    return EventSourceResponse(event_generator())


@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "ok"}