100 lines
3.8 KiB
Python
100 lines
3.8 KiB
Python
import json
|
|
from datetime import datetime
|
|
from typing import Dict, Set
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
from ..models.schemas import StreamUpdate
|
|
from ..services.analysis_service import AnalysisService
|
|
|
|
|
|
class ConnectionManager:
    """Manages WebSocket connections for streaming analysis updates.

    Connections are grouped by analysis ID so that one update can be
    broadcast to every client subscribed to the same analysis.
    """

    def __init__(self):
        # analysis_id -> set of sockets currently subscribed to that analysis.
        # An entry is removed as soon as its set becomes empty (see disconnect).
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, analysis_id: str):
        """Accept a WebSocket connection and register it for an analysis.

        Args:
            websocket: The incoming connection; it is accepted here.
            analysis_id: Identifier of the analysis the client subscribes to.
        """
        await websocket.accept()
        self.active_connections.setdefault(analysis_id, set()).add(websocket)

    def disconnect(self, websocket: WebSocket, analysis_id: str):
        """Remove a WebSocket connection from an analysis.

        Unknown analysis IDs and sockets that were never registered are
        ignored, so calling this more than once for the same socket is safe.

        Args:
            websocket: The connection to forget.
            analysis_id: Identifier of the analysis it was registered under.
        """
        connections = self.active_connections.get(analysis_id)
        if connections is None:
            return

        connections.discard(websocket)
        if not connections:
            # Drop the empty bucket so the mapping does not grow without bound.
            del self.active_connections[analysis_id]

    async def send_update(self, analysis_id: str, update: StreamUpdate):
        """Send an update to all connected clients for an analysis.

        Clients whose send fails are treated as gone and unregistered once
        the broadcast is finished.

        Args:
            analysis_id: Identifier of the analysis whose clients are notified.
            update: The update to serialise (via ``model_dump_json``) and send.
        """
        connections = self.active_connections.get(analysis_id)
        if not connections:
            return

        message = update.model_dump_json()
        failed = []

        # Iterate over a snapshot: each ``await`` below yields to the event
        # loop, during which connect()/disconnect() may mutate the live set.
        # Iterating the live set would then raise
        # "RuntimeError: Set changed size during iteration".
        for connection in tuple(connections):
            try:
                await connection.send_text(message)
            except Exception:
                # Best effort: a failed send means the client is unusable.
                failed.append(connection)

        # Remove disconnected clients
        for connection in failed:
            self.disconnect(connection, analysis_id)
|
|
|
|
|
|
class StreamHandler:
    """Handles WebSocket streaming for analysis updates.

    Owns a ConnectionManager and uses the analysis service to replay the
    current agent statuses to each newly connected client.
    """

    def __init__(self, analysis_service: AnalysisService):
        # Read in handle_stream via its ``agent_statuses`` mapping.
        self.analysis_service = analysis_service
        self.connection_manager = ConnectionManager()

    async def handle_stream(
        self,
        websocket: WebSocket,
        analysis_id: str
    ):
        """Handle WebSocket connection for streaming updates.

        Registers the socket, sends a connection confirmation followed by one
        ``agent_status`` message per entry recorded for ``analysis_id``, then
        reads incoming text (answering ``"ping"`` with ``"pong"``) until the
        client disconnects or a receive/send error occurs. The socket is
        always unregistered on exit.

        Args:
            websocket: The client connection to serve.
            analysis_id: Identifier of the analysis the client subscribes to.
        """
        await self.connection_manager.connect(websocket, analysis_id)

        try:
            # Send initial connection confirmation. It carries a real
            # timestamp, like every other message sent from this method.
            await websocket.send_json({
                "type": "status",
                "data": {"message": "Connected", "analysis_id": analysis_id},
                "timestamp": datetime.now().isoformat()
            })

            # Send current agent statuses so a late joiner sees present state.
            current_statuses = self.analysis_service.agent_statuses.get(analysis_id, {})
            for agent, status in current_statuses.items():
                await websocket.send_json({
                    "type": "agent_status",
                    "data": {"agent": agent, "status": status},
                    "timestamp": datetime.now().isoformat()
                })

            # Keep the connection open. Updates are pushed to the socket by
            # send_update; this loop only services client messages and
            # notices when the client goes away.
            while True:
                try:
                    data = await websocket.receive_text()
                    # Handle ping/pong if needed
                    if data == "ping":
                        await websocket.send_text("pong")
                except WebSocketDisconnect:
                    break
                except Exception:
                    # Connection closed or error: deliberately best effort,
                    # end the session instead of propagating.
                    break

        except WebSocketDisconnect:
            # Client went away during the initial sends; nothing to report.
            pass
        finally:
            self.connection_manager.disconnect(websocket, analysis_id)

    async def send_update(self, analysis_id: str, update: StreamUpdate):
        """Send an update to all connected clients.

        Args:
            analysis_id: Identifier of the analysis whose clients are notified.
            update: The update to broadcast; delegated to the connection manager.
        """
        await self.connection_manager.send_update(analysis_id, update)
|
|
|