diff --git a/app.py b/app.py index 981849b9..84d44d72 100644 --- a/app.py +++ b/app.py @@ -19,7 +19,6 @@ from cli.main import ( MessageBuffer, classify_message_type, update_analyst_statuses, - update_research_team_status, ) app = FastAPI(title="TradingAgents API") @@ -30,8 +29,8 @@ app.add_middleware( allow_headers=["*"], ) -# Active analysis queues: id -> asyncio.Queue -analyses: dict[str, asyncio.Queue] = {} +# Active analysis state: id -> {queue, events (replay buffer), done} +analyses: dict[str, dict] = {} class AnalyzeRequest(BaseModel): @@ -92,7 +91,8 @@ def _agent_stage(agent_name): async def run_analysis(analysis_id: str, ticker: str, trade_date: str): """Background task that runs the TradingAgents pipeline and pushes SSE events.""" - q = analyses[analysis_id] + state = analyses[analysis_id] + q = state["queue"] config = build_config() stats_handler = StatsCallbackHandler() selected_analysts = ["market", "social", "news", "fundamentals"] @@ -149,13 +149,15 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str): for agent, status in buf.agent_status.items(): if prev_statuses.get(agent) != status: prev_statuses[agent] = status - await q.put({ + evt = { "type": "agent_update", "agent": agent, "stage": _agent_stage(agent), "status": status, "stats": st, - }) + } + state["events"].append(evt) + await q.put(evt) # Analyst reports report_map = { @@ -167,14 +169,16 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str): for field, (agent_name, stage) in report_map.items(): if field not in emitted_reports and chunk.get(field): emitted_reports.add(field) - await q.put({ + evt = { "type": "report", "agent": agent_name, "stage": stage, "field": field, "report": chunk[field], "stats": st, - }) + } + state["events"].append(evt) + await q.put(evt) # Research debate if chunk.get("investment_debate_state"): @@ -184,32 +188,38 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str): judge = debate.get("judge_decision", "").strip() if bull or bear: - update_research_team_status("in_progress") + for a in ("Bull Researcher", "Bear Researcher", "Research Manager"): + buf.update_agent_status(a, "in_progress") if judge and not research_emitted: research_emitted = True - update_research_team_status("completed") + for a in ("Bull Researcher", "Bear Researcher", "Research Manager"): + buf.update_agent_status(a, "completed") buf.update_agent_status("Trader", "in_progress") - await q.put({ + evt = { "type": "debate", "stage": "research", "bull": bull, "bear": bear, "judge": judge, "stats": get_stats_dict(stats_handler, buf, start_time), - }) + } + state["events"].append(evt) + await q.put(evt) # Trader plan if chunk.get("trader_investment_plan") and not trader_emitted: trader_emitted = True buf.update_agent_status("Trader", "completed") buf.update_agent_status("Aggressive Analyst", "in_progress") - await q.put({ + evt = { "type": "trader", "stage": "trading", "plan": chunk["trader_investment_plan"], "stats": get_stats_dict(stats_handler, buf, start_time), - }) + } + state["events"].append(evt) + await q.put(evt) # Risk debate if chunk.get("risk_debate_state"): @@ -232,7 +242,7 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str): buf.update_agent_status("Conservative Analyst", "completed") buf.update_agent_status("Neutral Analyst", "completed") buf.update_agent_status("Portfolio Manager", "completed") - await q.put({ + evt = { "type": "risk", "stage": "risk", "aggressive": agg, @@ -240,10 +250,15 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str): "neutral": neu, "judge": judge, "stats": get_stats_dict(stats_handler, buf, start_time), - }) + } + state["events"].append(evt) + await q.put(evt) except Exception as e: - await q.put({"type": "error", "message": str(e)}) + evt = {"type": "error", "message": str(e)} + state["events"].append(evt) + await q.put(evt) + state["done"] = True await q.put(None) return @@ -254,14 +269,17 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str): for agent in buf.agent_status: buf.update_agent_status(agent, "completed") st = get_stats_dict(stats_handler, buf, start_time) - await q.put({ + evt = { "type": "decision", "stage": "decision", "signal": signal, "decision_text": decision_text, "stats": st, - }) + } + state["events"].append(evt) + await q.put(evt) + state["done"] = True await q.put(None) # sentinel — stream done @@ -272,18 +290,30 @@ async def start_analysis(req: AnalyzeRequest): raise HTTPException(400, "Invalid ticker") trade_date = req.date or str(date.today()) analysis_id = str(uuid.uuid4()) - analyses[analysis_id] = asyncio.Queue() + analyses[analysis_id] = {"queue": asyncio.Queue(), "events": [], "done": False} asyncio.create_task(run_analysis(analysis_id, ticker, trade_date)) return {"id": analysis_id, "ticker": ticker, "date": trade_date} @app.get("/analyze/{analysis_id}/stream") -async def stream_analysis(analysis_id: str): +async def stream_analysis(analysis_id: str, last_event: int = 0): + """Stream SSE events. Supports reconnection via ?last_event=N to replay missed events.""" if analysis_id not in analyses: raise HTTPException(404, "Analysis not found") - q = analyses[analysis_id] + state = analyses[analysis_id] async def event_generator(): + idx = last_event + # Replay any events the client missed + while idx < len(state["events"]): + evt = state["events"][idx] + idx += 1 + yield {"id": str(idx), "data": json.dumps(evt)} + # If analysis already done after replay, stop + if state["done"]: + return + # Stream new events from queue + q = state["queue"] while True: try: event = await asyncio.wait_for(q.get(), timeout=15) @@ -292,8 +322,8 @@ async def stream_analysis(analysis_id: str): continue if event is None: break - yield {"data": json.dumps(event)} - analyses.pop(analysis_id, None) + idx += 1 + yield {"id": str(idx), "data": json.dumps(event)} return EventSourceResponse(event_generator())