fix: add SSE replay buffer + fix research agent status tracking
- Replace global update_research_team_status() with local buf calls (was updating CLI's global buffer, not analysis-specific one) - Add replay buffer: all events stored in memory per analysis - Support ?last_event=N query param for reconnection replay - Send event IDs so browser can track position - Mark analysis as done so replay works after completion Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
777226722a
commit
f5519b9efe
78
app.py
78
app.py
|
|
@ -19,7 +19,6 @@ from cli.main import (
|
|||
MessageBuffer,
|
||||
classify_message_type,
|
||||
update_analyst_statuses,
|
||||
update_research_team_status,
|
||||
)
|
||||
|
||||
app = FastAPI(title="TradingAgents API")
|
||||
|
|
@ -30,8 +29,8 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Active analysis queues: id -> asyncio.Queue
|
||||
analyses: dict[str, asyncio.Queue] = {}
|
||||
# Active analysis state: id -> {queue, events (replay buffer), done}
|
||||
analyses: dict[str, dict] = {}
|
||||
|
||||
|
||||
class AnalyzeRequest(BaseModel):
|
||||
|
|
@ -92,7 +91,8 @@ def _agent_stage(agent_name):
|
|||
|
||||
async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||
"""Background task that runs the TradingAgents pipeline and pushes SSE events."""
|
||||
q = analyses[analysis_id]
|
||||
state = analyses[analysis_id]
|
||||
q = state["queue"]
|
||||
config = build_config()
|
||||
stats_handler = StatsCallbackHandler()
|
||||
selected_analysts = ["market", "social", "news", "fundamentals"]
|
||||
|
|
@ -149,13 +149,15 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
|||
for agent, status in buf.agent_status.items():
|
||||
if prev_statuses.get(agent) != status:
|
||||
prev_statuses[agent] = status
|
||||
await q.put({
|
||||
evt = {
|
||||
"type": "agent_update",
|
||||
"agent": agent,
|
||||
"stage": _agent_stage(agent),
|
||||
"status": status,
|
||||
"stats": st,
|
||||
})
|
||||
}
|
||||
state["events"].append(evt)
|
||||
await q.put(evt)
|
||||
|
||||
# Analyst reports
|
||||
report_map = {
|
||||
|
|
@ -167,14 +169,16 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
|||
for field, (agent_name, stage) in report_map.items():
|
||||
if field not in emitted_reports and chunk.get(field):
|
||||
emitted_reports.add(field)
|
||||
await q.put({
|
||||
evt = {
|
||||
"type": "report",
|
||||
"agent": agent_name,
|
||||
"stage": stage,
|
||||
"field": field,
|
||||
"report": chunk[field],
|
||||
"stats": st,
|
||||
})
|
||||
}
|
||||
state["events"].append(evt)
|
||||
await q.put(evt)
|
||||
|
||||
# Research debate
|
||||
if chunk.get("investment_debate_state"):
|
||||
|
|
@ -184,32 +188,38 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
|||
judge = debate.get("judge_decision", "").strip()
|
||||
|
||||
if bull or bear:
|
||||
update_research_team_status("in_progress")
|
||||
for a in ("Bull Researcher", "Bear Researcher", "Research Manager"):
|
||||
buf.update_agent_status(a, "in_progress")
|
||||
|
||||
if judge and not research_emitted:
|
||||
research_emitted = True
|
||||
update_research_team_status("completed")
|
||||
for a in ("Bull Researcher", "Bear Researcher", "Research Manager"):
|
||||
buf.update_agent_status(a, "completed")
|
||||
buf.update_agent_status("Trader", "in_progress")
|
||||
await q.put({
|
||||
evt = {
|
||||
"type": "debate",
|
||||
"stage": "research",
|
||||
"bull": bull,
|
||||
"bear": bear,
|
||||
"judge": judge,
|
||||
"stats": get_stats_dict(stats_handler, buf, start_time),
|
||||
})
|
||||
}
|
||||
state["events"].append(evt)
|
||||
await q.put(evt)
|
||||
|
||||
# Trader plan
|
||||
if chunk.get("trader_investment_plan") and not trader_emitted:
|
||||
trader_emitted = True
|
||||
buf.update_agent_status("Trader", "completed")
|
||||
buf.update_agent_status("Aggressive Analyst", "in_progress")
|
||||
await q.put({
|
||||
evt = {
|
||||
"type": "trader",
|
||||
"stage": "trading",
|
||||
"plan": chunk["trader_investment_plan"],
|
||||
"stats": get_stats_dict(stats_handler, buf, start_time),
|
||||
})
|
||||
}
|
||||
state["events"].append(evt)
|
||||
await q.put(evt)
|
||||
|
||||
# Risk debate
|
||||
if chunk.get("risk_debate_state"):
|
||||
|
|
@ -232,7 +242,7 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
|||
buf.update_agent_status("Conservative Analyst", "completed")
|
||||
buf.update_agent_status("Neutral Analyst", "completed")
|
||||
buf.update_agent_status("Portfolio Manager", "completed")
|
||||
await q.put({
|
||||
evt = {
|
||||
"type": "risk",
|
||||
"stage": "risk",
|
||||
"aggressive": agg,
|
||||
|
|
@ -240,10 +250,15 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
|||
"neutral": neu,
|
||||
"judge": judge,
|
||||
"stats": get_stats_dict(stats_handler, buf, start_time),
|
||||
})
|
||||
}
|
||||
state["events"].append(evt)
|
||||
await q.put(evt)
|
||||
|
||||
except Exception as e:
|
||||
await q.put({"type": "error", "message": str(e)})
|
||||
evt = {"type": "error", "message": str(e)}
|
||||
state["events"].append(evt)
|
||||
await q.put(evt)
|
||||
state["done"] = True
|
||||
await q.put(None)
|
||||
return
|
||||
|
||||
|
|
@ -254,14 +269,17 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
|||
for agent in buf.agent_status:
|
||||
buf.update_agent_status(agent, "completed")
|
||||
st = get_stats_dict(stats_handler, buf, start_time)
|
||||
await q.put({
|
||||
evt = {
|
||||
"type": "decision",
|
||||
"stage": "decision",
|
||||
"signal": signal,
|
||||
"decision_text": decision_text,
|
||||
"stats": st,
|
||||
})
|
||||
}
|
||||
state["events"].append(evt)
|
||||
await q.put(evt)
|
||||
|
||||
state["done"] = True
|
||||
await q.put(None) # sentinel — stream done
|
||||
|
||||
|
||||
|
|
@ -272,18 +290,30 @@ async def start_analysis(req: AnalyzeRequest):
|
|||
raise HTTPException(400, "Invalid ticker")
|
||||
trade_date = req.date or str(date.today())
|
||||
analysis_id = str(uuid.uuid4())
|
||||
analyses[analysis_id] = asyncio.Queue()
|
||||
analyses[analysis_id] = {"queue": asyncio.Queue(), "events": [], "done": False}
|
||||
asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
|
||||
return {"id": analysis_id, "ticker": ticker, "date": trade_date}
|
||||
|
||||
|
||||
@app.get("/analyze/{analysis_id}/stream")
|
||||
async def stream_analysis(analysis_id: str):
|
||||
async def stream_analysis(analysis_id: str, last_event: int = 0):
|
||||
"""Stream SSE events. Supports reconnection via ?last_event=N to replay missed events."""
|
||||
if analysis_id not in analyses:
|
||||
raise HTTPException(404, "Analysis not found")
|
||||
q = analyses[analysis_id]
|
||||
state = analyses[analysis_id]
|
||||
|
||||
async def event_generator():
|
||||
idx = last_event
|
||||
# Replay any events the client missed
|
||||
while idx < len(state["events"]):
|
||||
evt = state["events"][idx]
|
||||
idx += 1
|
||||
yield {"id": str(idx), "data": json.dumps(evt)}
|
||||
# If analysis already done after replay, stop
|
||||
if state["done"]:
|
||||
return
|
||||
# Stream new events from queue
|
||||
q = state["queue"]
|
||||
while True:
|
||||
try:
|
||||
event = await asyncio.wait_for(q.get(), timeout=15)
|
||||
|
|
@ -292,8 +322,8 @@ async def stream_analysis(analysis_id: str):
|
|||
continue
|
||||
if event is None:
|
||||
break
|
||||
yield {"data": json.dumps(event)}
|
||||
analyses.pop(analysis_id, None)
|
||||
idx += 1
|
||||
yield {"id": str(idx), "data": json.dumps(event)}
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue