149 lines
5.1 KiB
Python
149 lines
5.1 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import threading
|
|
from collections import defaultdict
|
|
from typing import Dict, List, Optional, Set
|
|
|
|
import websocket
|
|
|
|
from tradingagents.services.autopilot_worker import AutopilotWorker
|
|
from tradingagents.services.hypothesis_store import HypothesisStore, HypothesisRecord
|
|
|
|
|
|
class RealtimeNewsBroker:
    """Subscribe to Alpaca news WebSocket and forward events to the autopilot worker.

    The broker maintains ``watchers``, a mapping of upper-cased ticker symbol
    to the set of hypothesis ids interested in news for that symbol. Incoming
    news payloads are matched against that map and handed to
    ``worker.enqueue_event`` once per interested hypothesis.

    ``watchers`` is guarded by an internal lock because it is rebuilt from the
    caller's thread while the WebSocket callbacks read it from the background
    thread created by :meth:`start`.
    """

    DEFAULT_URL = "wss://stream.data.alpaca.markets/v1beta1/news"

    def __init__(
        self,
        store: HypothesisStore,
        worker: AutopilotWorker,
        api_key: str,
        secret_key: str,
        *,
        url: Optional[str] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """Store collaborators and credentials; no connection is opened here.

        Args:
            store: Source of hypothesis records (its ``list()`` is used to
                build the watcher map).
            worker: Receiver of news events via ``enqueue_event``.
            api_key: Value sent in the ``APCA-API-KEY-ID`` header.
            secret_key: Value sent in the ``APCA-API-SECRET-KEY`` header.
            url: WebSocket endpoint; falls back to ``DEFAULT_URL`` when
                omitted or empty.
            logger: Logger to use; defaults to this module's logger.
        """
        self.store = store
        self.worker = worker
        self.api_key = api_key
        self.secret_key = secret_key
        self.url = url or self.DEFAULT_URL
        self.logger = logger or logging.getLogger(__name__)
        # symbol (upper-case) -> ids of hypotheses watching that symbol
        self.watchers: Dict[str, Set[str]] = defaultdict(set)
        # Non-None while a stream is (believed to be) running.
        self.ws: Optional[websocket.WebSocketApp] = None
        self._thread: Optional[threading.Thread] = None
        self._lock = threading.Lock()

    def bootstrap_watchers(self) -> None:
        """Populate the watcher map from the store (alias of a full refresh)."""
        self.refresh_watchers()

    def refresh_watchers(self, records: Optional[List[HypothesisRecord]] = None) -> int:
        """Rebuild the news watcher map from stored hypotheses.

        Args:
            records: Records to build the map from. When ``None`` the records
                are loaded with ``self.store.list()``. An explicitly empty
                list is honoured and simply clears the map.

        Returns:
            The number of distinct (symbol, hypothesis id) pairs registered.
        """
        # ``is None`` rather than ``or``: an empty list passed on purpose must
        # not silently trigger a reload from the store.
        dataset = self.store.list() if records is None else records
        registered = 0
        with self._lock:
            self.watchers.clear()
            for record in dataset:
                ticker = record.ticker.upper()
                registered += self._register_symbol_unlocked(ticker, record.id)
                for symbol in self._extract_symbols_from_triggers(record):
                    registered += self._register_symbol_unlocked(symbol, record.id)
        return registered

    def _extract_symbols_from_triggers(self, record: HypothesisRecord) -> List[str]:
        """Return the upper-cased symbols named by a record's news triggers.

        A trigger counts as a news trigger when its stripped, lower-cased text
        starts with ``"news"``. ``"news:<symbol>"`` yields ``<symbol>``; a news
        trigger without a non-empty symbol after the colon yields the record's
        own ticker. Other triggers are ignored. A missing (``None``) trigger
        collection is treated as empty.
        """
        symbols: List[str] = []
        for raw in record.triggers or ():
            text = str(raw).strip().lower()
            if not text.startswith("news"):
                continue
            _, _, explicit = text.partition(":")
            explicit = explicit.strip()
            symbols.append(explicit.upper() if explicit else record.ticker.upper())
        return symbols

    def _register_symbol(self, symbol: str, hypothesis_id: str) -> int:
        """Register one watcher under the lock.

        Returns:
            1 if the (symbol, hypothesis id) pair was new, otherwise 0
            (including when ``symbol`` is empty).
        """
        if not symbol:
            return 0
        symbol_key = symbol.upper()
        with self._lock:
            return self._register_symbol_unlocked(symbol_key, hypothesis_id)

    def _register_symbol_unlocked(self, symbol: str, hypothesis_id: str) -> int:
        """Register one watcher; the caller must already hold ``self._lock``.

        ``symbol`` is used as the map key as given (callers upper-case it).

        Returns:
            1 if the hypothesis id was newly added for ``symbol``, otherwise 0.
        """
        if not symbol:
            return 0
        bucket = self.watchers[symbol]
        if hypothesis_id in bucket:
            return 0
        bucket.add(hypothesis_id)
        return 1

    def start(self) -> None:
        """Build the watcher map and run the news stream in a daemon thread.

        Does nothing (beyond logging) when a stream object already exists.
        """
        if self.ws is not None:
            self.logger.info("News broker already running")
            return
        self.bootstrap_watchers()

        headers = [
            f"APCA-API-KEY-ID: {self.api_key}",
            f"APCA-API-SECRET-KEY: {self.secret_key}",
        ]

        ws = websocket.WebSocketApp(
            self.url,
            header=headers,
            on_open=self._on_open,
            on_message=self._on_message,
            on_error=self._on_error,
            on_close=self._on_close,
        )
        self.ws = ws

        def _run() -> None:  # pragma: no cover - network behavior
            self.logger.info("Connecting to Alpaca news stream …")
            try:
                # Use the captured ``ws``: ``self.ws`` may be reset by
                # ``_on_close`` while this thread is still alive.
                ws.run_forever()
            except Exception:
                self.logger.exception("News stream terminated unexpectedly")
            finally:
                # Make sure a dead stream never blocks a later ``start()``;
                # the identity check avoids clobbering a newer connection.
                if self.ws is ws:
                    self.ws = None

        self._thread = threading.Thread(target=_run, daemon=True)
        self._thread.start()

    def _on_open(self, ws):  # pragma: no cover - network callback
        """Send the subscription request for all news once connected."""
        self.logger.info("News stream connected; subscribing to all news…")
        ws.send(json.dumps({"action": "subscribe", "news": ["*"]}))

    def _on_message(self, ws, message):  # pragma: no cover - network callback
        """Decode a frame and dispatch each contained item to ``_handle_news``.

        A frame may hold a single JSON value or a JSON array of them. Frames
        that cannot be decoded are dropped (logged at debug level).
        """
        try:
            data = json.loads(message)
        except (ValueError, TypeError):
            # ValueError covers json.JSONDecodeError and bad byte encodings.
            self.logger.debug("Ignoring undecodable news frame")
            return
        items = data if isinstance(data, list) else [data]
        for item in items:
            self._handle_news(item)

    def _handle_news(self, payload: Dict[str, object]) -> None:
        """Forward one news payload to every hypothesis watching its symbols.

        Only dict payloads whose ``"T"`` field equals ``"n"`` and whose
        ``"symbols"`` field is a list are processed; anything else is ignored.
        Each interested hypothesis receives the payload exactly once, even if
        it watches several of the listed symbols or a symbol is repeated.
        """
        # Frames may carry non-object items; skip them instead of raising and
        # aborting the rest of the batch.
        if not isinstance(payload, dict):
            return
        if payload.get("T") != "n":
            return
        symbols = payload.get("symbols") or []
        if not isinstance(symbols, list):
            return

        # A dict is used as an insertion-ordered set of hypothesis ids.
        targets: Dict[str, None] = {}
        with self._lock:
            for symbol in symbols:
                symbol_key = str(symbol or "").upper()
                # ``.get`` avoids creating empty buckets in the defaultdict.
                for hypothesis_id in self.watchers.get(symbol_key, ()):
                    targets.setdefault(hypothesis_id, None)

        # Enqueue outside the lock so the worker cannot stall watcher updates.
        for hypothesis_id in targets:
            self.worker.enqueue_event(
                hypothesis_id,
                event_type="news",
                payload=payload,
            )

    def _on_error(self, ws, error):  # pragma: no cover - network callback
        """Log an error reported by the WebSocket client."""
        self.logger.error("News stream error: %s", error)

    def _on_close(self, ws, code, msg):  # pragma: no cover - network callback
        """Log the close event and mark the broker as stopped."""
        self.logger.info("News stream closed: %s %s", code, msg)
        self.ws = None
|