fix: add SSE replay buffer + fix research agent status tracking
- Replace global update_research_team_status() with local buf calls (was updating CLI's global buffer, not analysis-specific one) - Add replay buffer: all events stored in memory per analysis - Support ?last_event=N query param for reconnection replay - Send event IDs so browser can track position - Mark analysis as done so replay works after completion Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
777226722a
commit
f5519b9efe
78
app.py
78
app.py
|
|
@ -19,7 +19,6 @@ from cli.main import (
|
||||||
MessageBuffer,
|
MessageBuffer,
|
||||||
classify_message_type,
|
classify_message_type,
|
||||||
update_analyst_statuses,
|
update_analyst_statuses,
|
||||||
update_research_team_status,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
app = FastAPI(title="TradingAgents API")
|
app = FastAPI(title="TradingAgents API")
|
||||||
|
|
@ -30,8 +29,8 @@ app.add_middleware(
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Active analysis queues: id -> asyncio.Queue
|
# Active analysis state: id -> {queue, events (replay buffer), done}
|
||||||
analyses: dict[str, asyncio.Queue] = {}
|
analyses: dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
class AnalyzeRequest(BaseModel):
|
class AnalyzeRequest(BaseModel):
|
||||||
|
|
@ -92,7 +91,8 @@ def _agent_stage(agent_name):
|
||||||
|
|
||||||
async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
"""Background task that runs the TradingAgents pipeline and pushes SSE events."""
|
"""Background task that runs the TradingAgents pipeline and pushes SSE events."""
|
||||||
q = analyses[analysis_id]
|
state = analyses[analysis_id]
|
||||||
|
q = state["queue"]
|
||||||
config = build_config()
|
config = build_config()
|
||||||
stats_handler = StatsCallbackHandler()
|
stats_handler = StatsCallbackHandler()
|
||||||
selected_analysts = ["market", "social", "news", "fundamentals"]
|
selected_analysts = ["market", "social", "news", "fundamentals"]
|
||||||
|
|
@ -149,13 +149,15 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
for agent, status in buf.agent_status.items():
|
for agent, status in buf.agent_status.items():
|
||||||
if prev_statuses.get(agent) != status:
|
if prev_statuses.get(agent) != status:
|
||||||
prev_statuses[agent] = status
|
prev_statuses[agent] = status
|
||||||
await q.put({
|
evt = {
|
||||||
"type": "agent_update",
|
"type": "agent_update",
|
||||||
"agent": agent,
|
"agent": agent,
|
||||||
"stage": _agent_stage(agent),
|
"stage": _agent_stage(agent),
|
||||||
"status": status,
|
"status": status,
|
||||||
"stats": st,
|
"stats": st,
|
||||||
})
|
}
|
||||||
|
state["events"].append(evt)
|
||||||
|
await q.put(evt)
|
||||||
|
|
||||||
# Analyst reports
|
# Analyst reports
|
||||||
report_map = {
|
report_map = {
|
||||||
|
|
@ -167,14 +169,16 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
for field, (agent_name, stage) in report_map.items():
|
for field, (agent_name, stage) in report_map.items():
|
||||||
if field not in emitted_reports and chunk.get(field):
|
if field not in emitted_reports and chunk.get(field):
|
||||||
emitted_reports.add(field)
|
emitted_reports.add(field)
|
||||||
await q.put({
|
evt = {
|
||||||
"type": "report",
|
"type": "report",
|
||||||
"agent": agent_name,
|
"agent": agent_name,
|
||||||
"stage": stage,
|
"stage": stage,
|
||||||
"field": field,
|
"field": field,
|
||||||
"report": chunk[field],
|
"report": chunk[field],
|
||||||
"stats": st,
|
"stats": st,
|
||||||
})
|
}
|
||||||
|
state["events"].append(evt)
|
||||||
|
await q.put(evt)
|
||||||
|
|
||||||
# Research debate
|
# Research debate
|
||||||
if chunk.get("investment_debate_state"):
|
if chunk.get("investment_debate_state"):
|
||||||
|
|
@ -184,32 +188,38 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
judge = debate.get("judge_decision", "").strip()
|
judge = debate.get("judge_decision", "").strip()
|
||||||
|
|
||||||
if bull or bear:
|
if bull or bear:
|
||||||
update_research_team_status("in_progress")
|
for a in ("Bull Researcher", "Bear Researcher", "Research Manager"):
|
||||||
|
buf.update_agent_status(a, "in_progress")
|
||||||
|
|
||||||
if judge and not research_emitted:
|
if judge and not research_emitted:
|
||||||
research_emitted = True
|
research_emitted = True
|
||||||
update_research_team_status("completed")
|
for a in ("Bull Researcher", "Bear Researcher", "Research Manager"):
|
||||||
|
buf.update_agent_status(a, "completed")
|
||||||
buf.update_agent_status("Trader", "in_progress")
|
buf.update_agent_status("Trader", "in_progress")
|
||||||
await q.put({
|
evt = {
|
||||||
"type": "debate",
|
"type": "debate",
|
||||||
"stage": "research",
|
"stage": "research",
|
||||||
"bull": bull,
|
"bull": bull,
|
||||||
"bear": bear,
|
"bear": bear,
|
||||||
"judge": judge,
|
"judge": judge,
|
||||||
"stats": get_stats_dict(stats_handler, buf, start_time),
|
"stats": get_stats_dict(stats_handler, buf, start_time),
|
||||||
})
|
}
|
||||||
|
state["events"].append(evt)
|
||||||
|
await q.put(evt)
|
||||||
|
|
||||||
# Trader plan
|
# Trader plan
|
||||||
if chunk.get("trader_investment_plan") and not trader_emitted:
|
if chunk.get("trader_investment_plan") and not trader_emitted:
|
||||||
trader_emitted = True
|
trader_emitted = True
|
||||||
buf.update_agent_status("Trader", "completed")
|
buf.update_agent_status("Trader", "completed")
|
||||||
buf.update_agent_status("Aggressive Analyst", "in_progress")
|
buf.update_agent_status("Aggressive Analyst", "in_progress")
|
||||||
await q.put({
|
evt = {
|
||||||
"type": "trader",
|
"type": "trader",
|
||||||
"stage": "trading",
|
"stage": "trading",
|
||||||
"plan": chunk["trader_investment_plan"],
|
"plan": chunk["trader_investment_plan"],
|
||||||
"stats": get_stats_dict(stats_handler, buf, start_time),
|
"stats": get_stats_dict(stats_handler, buf, start_time),
|
||||||
})
|
}
|
||||||
|
state["events"].append(evt)
|
||||||
|
await q.put(evt)
|
||||||
|
|
||||||
# Risk debate
|
# Risk debate
|
||||||
if chunk.get("risk_debate_state"):
|
if chunk.get("risk_debate_state"):
|
||||||
|
|
@ -232,7 +242,7 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
buf.update_agent_status("Conservative Analyst", "completed")
|
buf.update_agent_status("Conservative Analyst", "completed")
|
||||||
buf.update_agent_status("Neutral Analyst", "completed")
|
buf.update_agent_status("Neutral Analyst", "completed")
|
||||||
buf.update_agent_status("Portfolio Manager", "completed")
|
buf.update_agent_status("Portfolio Manager", "completed")
|
||||||
await q.put({
|
evt = {
|
||||||
"type": "risk",
|
"type": "risk",
|
||||||
"stage": "risk",
|
"stage": "risk",
|
||||||
"aggressive": agg,
|
"aggressive": agg,
|
||||||
|
|
@ -240,10 +250,15 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
"neutral": neu,
|
"neutral": neu,
|
||||||
"judge": judge,
|
"judge": judge,
|
||||||
"stats": get_stats_dict(stats_handler, buf, start_time),
|
"stats": get_stats_dict(stats_handler, buf, start_time),
|
||||||
})
|
}
|
||||||
|
state["events"].append(evt)
|
||||||
|
await q.put(evt)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await q.put({"type": "error", "message": str(e)})
|
evt = {"type": "error", "message": str(e)}
|
||||||
|
state["events"].append(evt)
|
||||||
|
await q.put(evt)
|
||||||
|
state["done"] = True
|
||||||
await q.put(None)
|
await q.put(None)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -254,14 +269,17 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
for agent in buf.agent_status:
|
for agent in buf.agent_status:
|
||||||
buf.update_agent_status(agent, "completed")
|
buf.update_agent_status(agent, "completed")
|
||||||
st = get_stats_dict(stats_handler, buf, start_time)
|
st = get_stats_dict(stats_handler, buf, start_time)
|
||||||
await q.put({
|
evt = {
|
||||||
"type": "decision",
|
"type": "decision",
|
||||||
"stage": "decision",
|
"stage": "decision",
|
||||||
"signal": signal,
|
"signal": signal,
|
||||||
"decision_text": decision_text,
|
"decision_text": decision_text,
|
||||||
"stats": st,
|
"stats": st,
|
||||||
})
|
}
|
||||||
|
state["events"].append(evt)
|
||||||
|
await q.put(evt)
|
||||||
|
|
||||||
|
state["done"] = True
|
||||||
await q.put(None) # sentinel — stream done
|
await q.put(None) # sentinel — stream done
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -272,18 +290,30 @@ async def start_analysis(req: AnalyzeRequest):
|
||||||
raise HTTPException(400, "Invalid ticker")
|
raise HTTPException(400, "Invalid ticker")
|
||||||
trade_date = req.date or str(date.today())
|
trade_date = req.date or str(date.today())
|
||||||
analysis_id = str(uuid.uuid4())
|
analysis_id = str(uuid.uuid4())
|
||||||
analyses[analysis_id] = asyncio.Queue()
|
analyses[analysis_id] = {"queue": asyncio.Queue(), "events": [], "done": False}
|
||||||
asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
|
asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
|
||||||
return {"id": analysis_id, "ticker": ticker, "date": trade_date}
|
return {"id": analysis_id, "ticker": ticker, "date": trade_date}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/analyze/{analysis_id}/stream")
|
@app.get("/analyze/{analysis_id}/stream")
|
||||||
async def stream_analysis(analysis_id: str):
|
async def stream_analysis(analysis_id: str, last_event: int = 0):
|
||||||
|
"""Stream SSE events. Supports reconnection via ?last_event=N to replay missed events."""
|
||||||
if analysis_id not in analyses:
|
if analysis_id not in analyses:
|
||||||
raise HTTPException(404, "Analysis not found")
|
raise HTTPException(404, "Analysis not found")
|
||||||
q = analyses[analysis_id]
|
state = analyses[analysis_id]
|
||||||
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
|
idx = last_event
|
||||||
|
# Replay any events the client missed
|
||||||
|
while idx < len(state["events"]):
|
||||||
|
evt = state["events"][idx]
|
||||||
|
idx += 1
|
||||||
|
yield {"id": str(idx), "data": json.dumps(evt)}
|
||||||
|
# If analysis already done after replay, stop
|
||||||
|
if state["done"]:
|
||||||
|
return
|
||||||
|
# Stream new events from queue
|
||||||
|
q = state["queue"]
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
event = await asyncio.wait_for(q.get(), timeout=15)
|
event = await asyncio.wait_for(q.get(), timeout=15)
|
||||||
|
|
@ -292,8 +322,8 @@ async def stream_analysis(analysis_id: str):
|
||||||
continue
|
continue
|
||||||
if event is None:
|
if event is None:
|
||||||
break
|
break
|
||||||
yield {"data": json.dumps(event)}
|
idx += 1
|
||||||
analyses.pop(analysis_id, None)
|
yield {"id": str(idx), "data": json.dumps(event)}
|
||||||
|
|
||||||
return EventSourceResponse(event_generator())
|
return EventSourceResponse(event_generator())
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue