TradingAgents/backend/api/websocket/stream_handler.py

100 lines
3.8 KiB
Python

import json
from datetime import datetime
from typing import Dict, Set
from fastapi import WebSocket, WebSocketDisconnect
from ..models.schemas import StreamUpdate
from ..services.analysis_service import AnalysisService
class ConnectionManager:
    """Manages WebSocket connections for streaming analysis updates.

    Connections are grouped by analysis ID so that a single update can be
    broadcast to every client subscribed to the same analysis.
    """

    def __init__(self):
        # Maps analysis_id -> set of WebSockets currently registered for it.
        # An ID is removed from the mapping as soon as its set becomes empty.
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, analysis_id: str):
        """Accept a WebSocket connection and register it for an analysis.

        Args:
            websocket: The incoming connection; it is accepted here.
            analysis_id: ID of the analysis the client wants updates for.
        """
        await websocket.accept()
        self.active_connections.setdefault(analysis_id, set()).add(websocket)

    def disconnect(self, websocket: WebSocket, analysis_id: str):
        """Remove a WebSocket connection.

        Safe to call for a connection or analysis ID that is not (or is no
        longer) registered; in that case nothing happens.

        Args:
            websocket: The connection to unregister.
            analysis_id: ID of the analysis the connection was registered for.
        """
        connections = self.active_connections.get(analysis_id)
        if connections is None:
            return
        connections.discard(websocket)
        if not connections:
            # Drop the empty bucket so the mapping does not grow unbounded.
            del self.active_connections[analysis_id]

    async def send_update(self, analysis_id: str, update: StreamUpdate):
        """Send an update to all connected clients for an analysis.

        Delivery is best-effort: a client whose send fails is treated as
        disconnected and unregistered, and the remaining clients still
        receive the update.

        Args:
            analysis_id: ID of the analysis whose subscribers are notified.
            update: The update to broadcast; serialized once to JSON.
        """
        connections = self.active_connections.get(analysis_id)
        if connections is None:
            return

        message = update.model_dump_json()
        disconnected = set()
        # Iterate over a snapshot: every ``await`` below yields to the event
        # loop, so connect()/disconnect() may mutate the live set while we
        # are in the middle of the broadcast. Iterating the live set would
        # then raise "RuntimeError: Set changed size during iteration".
        for connection in tuple(connections):
            try:
                await connection.send_text(message)
            except Exception:
                # Deliberately broad: any send failure means this client is
                # unusable; collect it and keep serving the others.
                disconnected.add(connection)

        # Remove disconnected clients
        for conn in disconnected:
            self.disconnect(conn, analysis_id)
class StreamHandler:
    """Handles WebSocket streaming for analysis updates."""

    def __init__(self, analysis_service: AnalysisService):
        """Create a handler bound to an analysis service.

        Args:
            analysis_service: Service whose ``agent_statuses`` mapping is
                read to replay the current state to newly connected clients.
        """
        self.analysis_service = analysis_service
        self.connection_manager = ConnectionManager()

    async def handle_stream(
        self,
        websocket: WebSocket,
        analysis_id: str
    ):
        """Handle WebSocket connection for streaming updates.

        Registers the connection, sends a connection confirmation followed
        by the currently known agent statuses for ``analysis_id``, then
        stays in a receive loop (answering ``"ping"`` with ``"pong"``) until
        the client disconnects or receiving fails. The connection is always
        unregistered on exit.

        Args:
            websocket: The client connection to serve.
            analysis_id: ID of the analysis the client subscribes to.
        """
        await self.connection_manager.connect(websocket, analysis_id)
        try:
            # Send initial connection confirmation. The timestamp is stamped
            # the same way as the agent_status messages below so clients can
            # parse every message uniformly.
            await websocket.send_json({
                "type": "status",
                "data": {"message": "Connected", "analysis_id": analysis_id},
                "timestamp": datetime.now().isoformat()
            })

            # Send current agent statuses
            current_statuses = self.analysis_service.agent_statuses.get(
                analysis_id, {}
            )
            for agent, status in current_statuses.items():
                await websocket.send_json({
                    "type": "agent_status",
                    "data": {"agent": agent, "status": status},
                    "timestamp": datetime.now().isoformat()
                })

            # Keep the connection open. Updates are not forwarded from this
            # loop; they are pushed through send_update(). The loop only
            # answers pings and notices when the client goes away.
            while True:
                # Wait for any incoming messages (ping/pong or close)
                try:
                    data = await websocket.receive_text()
                    # Handle ping/pong if needed
                    if data == "ping":
                        await websocket.send_text("pong")
                except WebSocketDisconnect:
                    break
                except Exception:
                    # Connection closed or error
                    break
        except WebSocketDisconnect:
            # Client went away during the initial sends; nothing to report.
            pass
        finally:
            self.connection_manager.disconnect(websocket, analysis_id)

    async def send_update(self, analysis_id: str, update: StreamUpdate):
        """Send an update to all connected clients.

        Args:
            analysis_id: ID of the analysis whose subscribers are notified.
            update: The update to broadcast.
        """
        await self.connection_manager.send_update(analysis_id, update)