fix: add SSE replay buffer + fix research agent status tracking

- Replace global update_research_team_status() with local buf calls
  (was updating CLI's global buffer, not analysis-specific one)
- Add replay buffer: all events stored in memory per analysis
- Support ?last_event=N query param for reconnection replay
- Send event IDs so browser can track position
- Mark analysis as done so replay works after completion

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dtarkent2-sys 2026-02-20 03:48:22 +00:00
parent 777226722a
commit f5519b9efe
1 changed file with 54 additions and 24 deletions

78
app.py
View File

@@ -19,7 +19,6 @@ from cli.main import (
MessageBuffer, MessageBuffer,
classify_message_type, classify_message_type,
update_analyst_statuses, update_analyst_statuses,
update_research_team_status,
) )
app = FastAPI(title="TradingAgents API") app = FastAPI(title="TradingAgents API")
@@ -30,8 +29,8 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# Active analysis queues: id -> asyncio.Queue # Active analysis state: id -> {queue, events (replay buffer), done}
analyses: dict[str, asyncio.Queue] = {} analyses: dict[str, dict] = {}
class AnalyzeRequest(BaseModel): class AnalyzeRequest(BaseModel):
@@ -92,7 +91,8 @@ def _agent_stage(agent_name):
async def run_analysis(analysis_id: str, ticker: str, trade_date: str): async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
"""Background task that runs the TradingAgents pipeline and pushes SSE events.""" """Background task that runs the TradingAgents pipeline and pushes SSE events."""
q = analyses[analysis_id] state = analyses[analysis_id]
q = state["queue"]
config = build_config() config = build_config()
stats_handler = StatsCallbackHandler() stats_handler = StatsCallbackHandler()
selected_analysts = ["market", "social", "news", "fundamentals"] selected_analysts = ["market", "social", "news", "fundamentals"]
@@ -149,13 +149,15 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
for agent, status in buf.agent_status.items(): for agent, status in buf.agent_status.items():
if prev_statuses.get(agent) != status: if prev_statuses.get(agent) != status:
prev_statuses[agent] = status prev_statuses[agent] = status
await q.put({ evt = {
"type": "agent_update", "type": "agent_update",
"agent": agent, "agent": agent,
"stage": _agent_stage(agent), "stage": _agent_stage(agent),
"status": status, "status": status,
"stats": st, "stats": st,
}) }
state["events"].append(evt)
await q.put(evt)
# Analyst reports # Analyst reports
report_map = { report_map = {
@@ -167,14 +169,16 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
for field, (agent_name, stage) in report_map.items(): for field, (agent_name, stage) in report_map.items():
if field not in emitted_reports and chunk.get(field): if field not in emitted_reports and chunk.get(field):
emitted_reports.add(field) emitted_reports.add(field)
await q.put({ evt = {
"type": "report", "type": "report",
"agent": agent_name, "agent": agent_name,
"stage": stage, "stage": stage,
"field": field, "field": field,
"report": chunk[field], "report": chunk[field],
"stats": st, "stats": st,
}) }
state["events"].append(evt)
await q.put(evt)
# Research debate # Research debate
if chunk.get("investment_debate_state"): if chunk.get("investment_debate_state"):
@@ -184,32 +188,38 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
judge = debate.get("judge_decision", "").strip() judge = debate.get("judge_decision", "").strip()
if bull or bear: if bull or bear:
update_research_team_status("in_progress") for a in ("Bull Researcher", "Bear Researcher", "Research Manager"):
buf.update_agent_status(a, "in_progress")
if judge and not research_emitted: if judge and not research_emitted:
research_emitted = True research_emitted = True
update_research_team_status("completed") for a in ("Bull Researcher", "Bear Researcher", "Research Manager"):
buf.update_agent_status(a, "completed")
buf.update_agent_status("Trader", "in_progress") buf.update_agent_status("Trader", "in_progress")
await q.put({ evt = {
"type": "debate", "type": "debate",
"stage": "research", "stage": "research",
"bull": bull, "bull": bull,
"bear": bear, "bear": bear,
"judge": judge, "judge": judge,
"stats": get_stats_dict(stats_handler, buf, start_time), "stats": get_stats_dict(stats_handler, buf, start_time),
}) }
state["events"].append(evt)
await q.put(evt)
# Trader plan # Trader plan
if chunk.get("trader_investment_plan") and not trader_emitted: if chunk.get("trader_investment_plan") and not trader_emitted:
trader_emitted = True trader_emitted = True
buf.update_agent_status("Trader", "completed") buf.update_agent_status("Trader", "completed")
buf.update_agent_status("Aggressive Analyst", "in_progress") buf.update_agent_status("Aggressive Analyst", "in_progress")
await q.put({ evt = {
"type": "trader", "type": "trader",
"stage": "trading", "stage": "trading",
"plan": chunk["trader_investment_plan"], "plan": chunk["trader_investment_plan"],
"stats": get_stats_dict(stats_handler, buf, start_time), "stats": get_stats_dict(stats_handler, buf, start_time),
}) }
state["events"].append(evt)
await q.put(evt)
# Risk debate # Risk debate
if chunk.get("risk_debate_state"): if chunk.get("risk_debate_state"):
@@ -232,7 +242,7 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
buf.update_agent_status("Conservative Analyst", "completed") buf.update_agent_status("Conservative Analyst", "completed")
buf.update_agent_status("Neutral Analyst", "completed") buf.update_agent_status("Neutral Analyst", "completed")
buf.update_agent_status("Portfolio Manager", "completed") buf.update_agent_status("Portfolio Manager", "completed")
await q.put({ evt = {
"type": "risk", "type": "risk",
"stage": "risk", "stage": "risk",
"aggressive": agg, "aggressive": agg,
@@ -240,10 +250,15 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
"neutral": neu, "neutral": neu,
"judge": judge, "judge": judge,
"stats": get_stats_dict(stats_handler, buf, start_time), "stats": get_stats_dict(stats_handler, buf, start_time),
}) }
state["events"].append(evt)
await q.put(evt)
except Exception as e: except Exception as e:
await q.put({"type": "error", "message": str(e)}) evt = {"type": "error", "message": str(e)}
state["events"].append(evt)
await q.put(evt)
state["done"] = True
await q.put(None) await q.put(None)
return return
@@ -254,14 +269,17 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
for agent in buf.agent_status: for agent in buf.agent_status:
buf.update_agent_status(agent, "completed") buf.update_agent_status(agent, "completed")
st = get_stats_dict(stats_handler, buf, start_time) st = get_stats_dict(stats_handler, buf, start_time)
await q.put({ evt = {
"type": "decision", "type": "decision",
"stage": "decision", "stage": "decision",
"signal": signal, "signal": signal,
"decision_text": decision_text, "decision_text": decision_text,
"stats": st, "stats": st,
}) }
state["events"].append(evt)
await q.put(evt)
state["done"] = True
await q.put(None) # sentinel — stream done await q.put(None) # sentinel — stream done
@@ -272,18 +290,30 @@ async def start_analysis(req: AnalyzeRequest):
raise HTTPException(400, "Invalid ticker") raise HTTPException(400, "Invalid ticker")
trade_date = req.date or str(date.today()) trade_date = req.date or str(date.today())
analysis_id = str(uuid.uuid4()) analysis_id = str(uuid.uuid4())
analyses[analysis_id] = asyncio.Queue() analyses[analysis_id] = {"queue": asyncio.Queue(), "events": [], "done": False}
asyncio.create_task(run_analysis(analysis_id, ticker, trade_date)) asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
return {"id": analysis_id, "ticker": ticker, "date": trade_date} return {"id": analysis_id, "ticker": ticker, "date": trade_date}
@app.get("/analyze/{analysis_id}/stream") @app.get("/analyze/{analysis_id}/stream")
async def stream_analysis(analysis_id: str): async def stream_analysis(analysis_id: str, last_event: int = 0):
"""Stream SSE events. Supports reconnection via ?last_event=N to replay missed events."""
if analysis_id not in analyses: if analysis_id not in analyses:
raise HTTPException(404, "Analysis not found") raise HTTPException(404, "Analysis not found")
q = analyses[analysis_id] state = analyses[analysis_id]
async def event_generator(): async def event_generator():
idx = last_event
# Replay any events the client missed
while idx < len(state["events"]):
evt = state["events"][idx]
idx += 1
yield {"id": str(idx), "data": json.dumps(evt)}
# If analysis already done after replay, stop
if state["done"]:
return
# Stream new events from queue
q = state["queue"]
while True: while True:
try: try:
event = await asyncio.wait_for(q.get(), timeout=15) event = await asyncio.wait_for(q.get(), timeout=15)
@@ -292,8 +322,8 @@ async def stream_analysis(analysis_id: str):
continue continue
if event is None: if event is None:
break break
yield {"data": json.dumps(event)} idx += 1
analyses.pop(analysis_id, None) yield {"id": str(idx), "data": json.dumps(event)}
return EventSourceResponse(event_generator()) return EventSourceResponse(event_generator())