# Source path: TradingAgents/backend/analysis/application/websocket_manager.py
# (Viewer metadata from the original listing: 65 lines, 2.4 KiB, Python)

from typing import Dict, Set
from fastapi import WebSocket
import json
from datetime import datetime
class WebSocketManager:
    """Track WebSocket connections per member and push analysis updates.

    A member may hold several connections at once (they are kept in a set
    keyed by ``member_id``).  Each analysis is mapped to the member that
    registered it so updates can be routed by ``analysis_id`` alone.
    """

    def __init__(self):
        # Active connections keyed by member_id.
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # analysis_id -> member_id of the member that registered it.
        self.analysis_member_map: Dict[str, str] = {}

    async def connect(self, websocket: WebSocket, member_id: str):
        """Accept *websocket* and record it as a connection of *member_id*."""
        await websocket.accept()
        self.active_connections.setdefault(member_id, set()).add(websocket)

    def disconnect(self, websocket: WebSocket, member_id: str):
        """Forget *websocket* for *member_id*.

        Unknown members and unknown sockets are ignored.  The member's
        entry is removed once its last connection is gone.
        """
        connections = self.active_connections.get(member_id)
        if connections is None:
            return
        connections.discard(websocket)
        if not connections:
            del self.active_connections[member_id]

    def register_analysis(self, analysis_id: str, member_id: str):
        """Register which member owns which analysis."""
        self.analysis_member_map[analysis_id] = member_id

    def unregister_analysis(self, analysis_id: str):
        """Drop the ownership record for *analysis_id*, if one exists.

        Nothing else in this class removes entries from
        ``analysis_member_map``, so callers can use this to release
        analyses they no longer need to route updates for.
        """
        self.analysis_member_map.pop(analysis_id, None)

    async def send_analysis_update(self, analysis_id: str, update_type: str, data: dict):
        """Send an analysis update to the member who owns the analysis.

        Does nothing when *analysis_id* has no registered owner.  The
        ``timestamp`` field is ``datetime.now().isoformat()``, i.e. a
        naive local-time ISO-8601 string.
        """
        member_id = self.analysis_member_map.get(analysis_id)
        if not member_id:
            return
        message = {
            "type": "analysis_update",
            "analysis_id": analysis_id,
            "update_type": update_type,
            "data": data,
            "timestamp": datetime.now().isoformat(),
        }
        await self.send_to_member(member_id, message)

    async def send_to_member(self, member_id: str, message: dict | str):
        """Send *message* to every connection of *member_id*.

        A ``dict`` is sent with ``send_json``; anything else is sent with
        ``send_text``.  Connections whose send raises are removed
        afterwards.  Does nothing when the member has no connections.
        """
        connections = self.active_connections.get(member_id)
        if not connections:
            return
        dead_connections = set()
        # Iterate over a snapshot: every send awaits, and a connect() or
        # disconnect() that runs in the meantime mutates the live set, which
        # would otherwise raise "Set changed size during iteration".
        for connection in tuple(connections):
            try:
                if isinstance(message, dict):
                    await connection.send_json(message)
                else:
                    await connection.send_text(message)
            except Exception:
                # Best effort: any send failure marks the connection as dead
                # so one broken socket cannot block the remaining ones.
                dead_connections.add(connection)
        # Clean up dead connections only after the loop has finished.
        for connection in dead_connections:
            self.disconnect(connection, member_id)