115 lines
4.6 KiB
Python
115 lines
4.6 KiB
Python
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from typing import Dict, Any
|
|
from agent_os.backend.dependencies import get_current_user
|
|
from agent_os.backend.store import runs
|
|
from agent_os.backend.services.langgraph_engine import LangGraphEngine
|
|
|
|
logger = logging.getLogger("agent_os.websocket")
|
|
|
|
router = APIRouter(prefix="/ws", tags=["websocket"])
|
|
|
|
# Polling interval when streaming cached events from a background-task-driven run
|
|
_EVENT_POLL_INTERVAL_SECONDS = 0.05
|
|
|
|
engine = LangGraphEngine()
|
|
|
|
@router.websocket("/stream/{run_id}")
|
|
async def websocket_endpoint(
|
|
websocket: WebSocket,
|
|
run_id: str,
|
|
):
|
|
await websocket.accept()
|
|
logger.info("WebSocket connected run=%s", run_id)
|
|
|
|
if run_id not in runs:
|
|
logger.warning("Run not found run=%s", run_id)
|
|
await websocket.send_json({"type": "system", "message": f"Error: Run {run_id} not found."})
|
|
await websocket.close()
|
|
return
|
|
|
|
run_info = runs[run_id]
|
|
run_type = run_info["type"]
|
|
params = run_info.get("params", {})
|
|
|
|
try:
|
|
status = run_info.get("status", "queued")
|
|
|
|
if status in ("running", "completed", "failed"):
|
|
# Background task is already executing (or finished) — stream its cached events
|
|
# then wait for completion if still running.
|
|
logger.info(
|
|
"WebSocket streaming from cache run=%s status=%s", run_id, status
|
|
)
|
|
sent = 0
|
|
while True:
|
|
cached = run_info.get("events") or []
|
|
while sent < len(cached):
|
|
payload = cached[sent]
|
|
if "timestamp" not in payload:
|
|
payload["timestamp"] = time.strftime("%H:%M:%S")
|
|
await websocket.send_json(payload)
|
|
sent += 1
|
|
current_status = run_info.get("status")
|
|
if current_status in ("completed", "failed"):
|
|
break
|
|
# Yield to the event loop so the background task can produce more events
|
|
await asyncio.sleep(_EVENT_POLL_INTERVAL_SECONDS)
|
|
|
|
if run_info.get("status") == "failed":
|
|
await websocket.send_json(
|
|
{"type": "system", "message": f"Run failed: {run_info.get('error', 'unknown error')}"}
|
|
)
|
|
else:
|
|
# status == "queued" — WebSocket is the executor (background task didn't start yet)
|
|
stream_gen = None
|
|
if run_type == "scan":
|
|
stream_gen = engine.run_scan(run_id, params)
|
|
elif run_type == "pipeline":
|
|
stream_gen = engine.run_pipeline(run_id, params)
|
|
elif run_type == "portfolio":
|
|
stream_gen = engine.run_portfolio(run_id, params)
|
|
elif run_type == "auto":
|
|
stream_gen = engine.run_auto(run_id, params)
|
|
|
|
if stream_gen:
|
|
run_info["status"] = "running"
|
|
run_info.setdefault("events", [])
|
|
try:
|
|
async for payload in stream_gen:
|
|
run_info["events"].append(payload)
|
|
if "timestamp" not in payload:
|
|
payload["timestamp"] = time.strftime("%H:%M:%S")
|
|
await websocket.send_json(payload)
|
|
logger.debug(
|
|
"Sent event type=%s node=%s run=%s",
|
|
payload.get("type"),
|
|
payload.get("node_id"),
|
|
run_id,
|
|
)
|
|
run_info["status"] = "completed"
|
|
except Exception as exc:
|
|
run_info["status"] = "failed"
|
|
run_info["error"] = str(exc)
|
|
raise
|
|
else:
|
|
msg = f"Run type '{run_type}' streaming not yet implemented."
|
|
logger.warning(msg)
|
|
await websocket.send_json({"type": "system", "message": f"Error: {msg}"})
|
|
|
|
await websocket.send_json({"type": "system", "message": "Run completed."})
|
|
logger.info("Run completed run=%s type=%s", run_id, run_type)
|
|
|
|
except WebSocketDisconnect:
|
|
logger.info("WebSocket client disconnected run=%s", run_id)
|
|
except Exception as e:
|
|
logger.exception("Error during streaming run=%s", run_id)
|
|
try:
|
|
await websocket.send_json({"type": "system", "message": f"Error: {str(e)}"})
|
|
await websocket.close()
|
|
except Exception:
|
|
pass # client already gone
|