feat(orchestrator): LiveMode + /ws/orchestrator WebSocket endpoint
This commit is contained in:
parent
e6ff53ddea
commit
77d8e87675
|
|
@ -0,0 +1,47 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LiveMode:
|
||||||
|
"""
|
||||||
|
Triggers signal computation for a list of tickers and broadcasts
|
||||||
|
results via a callback (e.g., WebSocket send).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, orchestrator):
|
||||||
|
self._orchestrator = orchestrator
|
||||||
|
|
||||||
|
async def run_once(self, tickers: List[str], date: Optional[str] = None) -> List[dict]:
|
||||||
|
"""
|
||||||
|
Compute combined signals for all tickers on the given date (default: today).
|
||||||
|
Returns list of signal dicts.
|
||||||
|
"""
|
||||||
|
if date is None:
|
||||||
|
date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for ticker in tickers:
|
||||||
|
try:
|
||||||
|
sig = self._orchestrator.get_combined_signal(ticker, date)
|
||||||
|
results.append({
|
||||||
|
"ticker": ticker,
|
||||||
|
"date": date,
|
||||||
|
"direction": sig.direction,
|
||||||
|
"confidence": sig.confidence,
|
||||||
|
"quant_direction": sig.quant_signal.direction if sig.quant_signal else None,
|
||||||
|
"llm_direction": sig.llm_signal.direction if sig.llm_signal else None,
|
||||||
|
"timestamp": sig.timestamp.isoformat(),
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("LiveMode: failed for %s %s: %s", ticker, date, e)
|
||||||
|
results.append({
|
||||||
|
"ticker": ticker,
|
||||||
|
"date": date,
|
||||||
|
"error": str(e),
|
||||||
|
})
|
||||||
|
return results
|
||||||
|
|
@ -1100,6 +1100,40 @@ async def root():
|
||||||
return {"message": "TradingAgents Web Dashboard API", "version": "0.1.0"}
|
return {"message": "TradingAgents Web Dashboard API", "version": "0.1.0"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/ws/orchestrator")
|
||||||
|
async def ws_orchestrator(websocket: WebSocket):
|
||||||
|
"""WebSocket endpoint for orchestrator live signals."""
|
||||||
|
await websocket.accept()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
data = await websocket.receive_text()
|
||||||
|
payload = json.loads(data)
|
||||||
|
tickers = payload.get("tickers", [])
|
||||||
|
date = payload.get("date")
|
||||||
|
|
||||||
|
# Lazy import to avoid loading heavy deps at startup
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
from orchestrator.orchestrator import TradingOrchestrator
|
||||||
|
from orchestrator.live_mode import LiveMode
|
||||||
|
|
||||||
|
config = OrchestratorConfig(
|
||||||
|
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
|
||||||
|
)
|
||||||
|
orchestrator = TradingOrchestrator(config)
|
||||||
|
live = LiveMode(orchestrator)
|
||||||
|
results = await live.run_once(tickers, date)
|
||||||
|
await websocket.send_text(json.dumps({"signals": results}))
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
try:
|
||||||
|
await websocket.send_text(json.dumps({"error": str(e)}))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
# Run with: cd web_dashboard && ../env312/bin/python -m uvicorn main:app --reload
|
# Run with: cd web_dashboard && ../env312/bin/python -m uvicorn main:app --reload
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue