feat(orchestrator): merge orchestrator module into main

This commit is contained in:
陈少杰 2026-04-10 01:52:00 +08:00
commit 8960fdf321
16 changed files with 899 additions and 2 deletions

0
orchestrator/__init__.py Normal file
View File

View File

@ -0,0 +1,64 @@
import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import List
from orchestrator.signals import FinalSignal
logger = logging.getLogger(__name__)
@dataclass
class BacktestResult:
records: List[dict] = field(default_factory=list)
summary: dict = field(default_factory=dict)
class BacktestMode:
def __init__(self, orchestrator):
self._orchestrator = orchestrator
def run(self, tickers: List[str], start_date: str, end_date: str) -> BacktestResult:
start = datetime.strptime(start_date, "%Y-%m-%d")
end = datetime.strptime(end_date, "%Y-%m-%d")
records = []
current = start
while current <= end:
if current.weekday() < 5: # skip weekends
date_str = current.strftime("%Y-%m-%d")
for ticker in tickers:
try:
sig = self._orchestrator.get_combined_signal(ticker, date_str)
records.append({
"ticker": ticker,
"date": date_str,
"direction": sig.direction,
"confidence": sig.confidence,
"quant_direction": sig.quant_signal.direction if sig.quant_signal else None,
"llm_direction": sig.llm_signal.direction if sig.llm_signal else None,
})
except Exception as e:
logger.error("BacktestMode: failed for %s %s: %s", ticker, date_str, e)
current += timedelta(days=1)
summary = self._compute_summary(records, tickers)
return BacktestResult(records=records, summary=summary)
def _compute_summary(self, records: List[dict], tickers: List[str]) -> dict:
summary = {}
for ticker in tickers:
ticker_records = [r for r in records if r["ticker"] == ticker]
if not ticker_records:
summary[ticker] = {"total_days": 0}
continue
directions = [r["direction"] for r in ticker_records]
confidences = [r["confidence"] for r in ticker_records]
summary[ticker] = {
"total_days": len(ticker_records),
"buy_days": directions.count(1),
"sell_days": directions.count(-1),
"hold_days": directions.count(0),
"avg_confidence": sum(confidences) / len(confidences),
}
return summary

14
orchestrator/config.py Normal file
View File

@ -0,0 +1,14 @@
from dataclasses import dataclass, field
@dataclass
class OrchestratorConfig:
# Must be set to the local quant backtest output directory before use
quant_backtest_path: str = ""
trading_agents_config: dict = field(default_factory=dict)
quant_weight_cap: float = 0.8 # quant 置信度上限
llm_weight_cap: float = 0.9 # llm 置信度上限
llm_batch_days: int = 7 # LLM 每隔几天运行一次(节省 API
cache_dir: str = "orchestrator/cache" # LLM 信号缓存目录
llm_solo_penalty: float = 0.7 # LLM 单轨时的置信度折扣
quant_solo_penalty: float = 0.8 # Quant 单轨时的置信度折扣

View File

View File

@ -0,0 +1,41 @@
"""
Example: Run orchestrator backtest for 宁德时代 (300750.SZ) over 2023.
Usage:
cd /path/to/TradingAgents
QUANT_BACKTEST_PATH=/path/to/quant_backtest python orchestrator/examples/run_backtest.py
"""
import json
import logging
import os
import sys
# Add repo root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from orchestrator.config import OrchestratorConfig
from orchestrator.orchestrator import TradingOrchestrator
from orchestrator.backtest_mode import BacktestMode
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
config = OrchestratorConfig(
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
cache_dir="orchestrator/cache",
)
orchestrator = TradingOrchestrator(config)
backtest = BacktestMode(orchestrator)
result = backtest.run(
tickers=["300750.SZ"],
start_date="2023-01-01",
end_date="2023-12-31",
)
print(f"\n=== Backtest Summary ===")
print(json.dumps(result.summary, indent=2, ensure_ascii=False))
print(f"\nTotal records: {len(result.records)}")
if result.records:
print(f"First record: {result.records[0]}")
print(f"Last record: {result.records[-1]}")

View File

@ -0,0 +1,41 @@
"""
Example: Run orchestrator live mode for a list of tickers.
Usage:
cd /path/to/TradingAgents
QUANT_BACKTEST_PATH=/path/to/quant_backtest python orchestrator/examples/run_live.py
"""
import asyncio
import json
import logging
import os
import sys
from datetime import datetime, timezone
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from orchestrator.config import OrchestratorConfig
from orchestrator.orchestrator import TradingOrchestrator
from orchestrator.live_mode import LiveMode
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
TICKERS = ["300750.SZ", "603259.SS"]
config = OrchestratorConfig(
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
cache_dir="orchestrator/cache",
)
orchestrator = TradingOrchestrator(config)
live = LiveMode(orchestrator)
async def main():
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
print(f"\n=== Live Signals for {today} ===")
results = await live.run_once(TICKERS, date=today)
print(json.dumps(results, indent=2, ensure_ascii=False))
asyncio.run(main())

48
orchestrator/live_mode.py Normal file
View File

@ -0,0 +1,48 @@
import asyncio
import logging
from datetime import datetime, timezone
from typing import List, Optional
logger = logging.getLogger(__name__)
class LiveMode:
"""
Triggers signal computation for a list of tickers and broadcasts
results via a callback (e.g., WebSocket send).
"""
def __init__(self, orchestrator):
self._orchestrator = orchestrator
async def run_once(self, tickers: List[str], date: Optional[str] = None) -> List[dict]:
"""
Compute combined signals for all tickers on the given date (default: today).
Returns list of signal dicts.
"""
if date is None:
date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
results = []
for ticker in tickers:
try:
sig = await asyncio.to_thread(
self._orchestrator.get_combined_signal, ticker, date
)
results.append({
"ticker": ticker,
"date": date,
"direction": sig.direction,
"confidence": sig.confidence,
"quant_direction": sig.quant_signal.direction if sig.quant_signal else None,
"llm_direction": sig.llm_signal.direction if sig.llm_signal else None,
"timestamp": sig.timestamp.isoformat(),
})
except Exception as e:
logger.error("LiveMode: failed for %s %s: %s", ticker, date, e)
results.append({
"ticker": ticker,
"date": date,
"error": str(e),
})
return results

View File

@ -0,0 +1,95 @@
import json
import logging
import os
from datetime import datetime, timezone
from orchestrator.config import OrchestratorConfig
from orchestrator.signals import Signal
logger = logging.getLogger(__name__)
class LLMRunner:
def __init__(self, config: OrchestratorConfig):
self._config = config
self._graph = None # Lazy-initialized on first get_signal() call (requires API key)
self.cache_dir = config.cache_dir
os.makedirs(self.cache_dir, exist_ok=True)
def _get_graph(self):
"""Lazy-initialize TradingAgentsGraph (heavy, requires API key at init time)."""
if self._graph is None:
from tradingagents.graph.trading_graph import TradingAgentsGraph
trading_cfg = self._config.trading_agents_config if self._config.trading_agents_config else None
self._graph = TradingAgentsGraph(config=trading_cfg)
return self._graph
def get_signal(self, ticker: str, date: str) -> Signal:
"""获取指定股票在指定日期的 LLM 信号,带缓存。"""
safe_ticker = ticker.replace("/", "_") # sanitize for filesystem (e.g. BRK/B)
cache_path = os.path.join(self.cache_dir, f"{safe_ticker}_{date}.json")
if os.path.exists(cache_path):
logger.info("LLMRunner: cache hit for %s %s", ticker, date)
with open(cache_path, "r", encoding="utf-8") as f:
data = json.load(f)
# Use stored direction/confidence directly to avoid re-mapping drift
return Signal(
ticker=ticker,
direction=data["direction"],
confidence=data["confidence"],
source="llm",
timestamp=datetime.fromisoformat(data["timestamp"]),
metadata=data,
)
try:
_final_state, processed_signal = self._get_graph().propagate(ticker, date)
rating = processed_signal if isinstance(processed_signal, str) else str(processed_signal)
direction, confidence = self._map_rating(rating)
now = datetime.now(timezone.utc)
cache_data = {
"rating": rating,
"direction": direction,
"confidence": confidence,
"timestamp": now.isoformat(),
"ticker": ticker,
"date": date,
}
with open(cache_path, "w", encoding="utf-8") as f:
json.dump(cache_data, f, ensure_ascii=False, indent=2)
return Signal(
ticker=ticker,
direction=direction,
confidence=confidence,
source="llm",
timestamp=now,
metadata=cache_data,
)
except Exception as e:
logger.error("LLMRunner: propagate failed for %s %s: %s", ticker, date, e)
return Signal(
ticker=ticker,
direction=0,
confidence=0.0,
source="llm",
timestamp=datetime.now(timezone.utc),
metadata={"error": str(e)},
)
def _map_rating(self, rating: str) -> tuple[int, float]:
"""将 5 级评级映射为 (direction, confidence)。"""
mapping = {
"BUY": (1, 0.9),
"OVERWEIGHT": (1, 0.6),
"HOLD": (0, 0.5),
"UNDERWEIGHT": (-1, 0.6),
"SELL": (-1, 0.9),
}
result = mapping.get(rating.upper() if rating else "", None)
if result is None:
logger.warning("LLMRunner: unknown rating %r, falling back to HOLD", rating)
return (0, 0.5)
return result

View File

@ -0,0 +1,63 @@
import logging
from datetime import datetime, timezone
from typing import Optional
from orchestrator.config import OrchestratorConfig
from orchestrator.signals import Signal, FinalSignal, SignalMerger
from orchestrator.quant_runner import QuantRunner
from orchestrator.llm_runner import LLMRunner
logger = logging.getLogger(__name__)
class TradingOrchestrator:
def __init__(self, config: OrchestratorConfig):
self._config = config
self._merger = SignalMerger(config)
self._quant: Optional[QuantRunner] = None
self._llm: Optional[LLMRunner] = None
# Initialize runners (quant requires quant_backtest_path)
if config.quant_backtest_path:
try:
self._quant = QuantRunner(config)
except Exception as e:
logger.warning("TradingOrchestrator: QuantRunner init failed: %s", e)
self._llm = LLMRunner(config)
def get_combined_signal(self, ticker: str, date: str) -> FinalSignal:
"""
Get merged signal for ticker on date.
Degradation:
- quant fails (error signal): use llm only with llm_solo_penalty
- llm fails (error signal): use quant only with quant_solo_penalty
- both fail: raises ValueError
"""
quant_sig: Optional[Signal] = None
llm_sig: Optional[Signal] = None
# Get quant signal
if self._quant is not None:
try:
quant_sig = self._quant.get_signal(ticker, date)
# Treat error signals (confidence=0, direction=0 with error metadata) as None
if quant_sig.metadata.get("error") or quant_sig.metadata.get("reason") == "no_data":
logger.warning("TradingOrchestrator: quant signal degraded for %s %s", ticker, date)
quant_sig = None
except Exception as e:
logger.error("TradingOrchestrator: quant get_signal failed: %s", e)
quant_sig = None
# Get llm signal
try:
llm_sig = self._llm.get_signal(ticker, date)
if llm_sig.metadata.get("error"):
logger.warning("TradingOrchestrator: llm signal degraded for %s %s", ticker, date)
llm_sig = None
except Exception as e:
logger.error("TradingOrchestrator: llm get_signal failed: %s", e)
llm_sig = None
# merge raises ValueError if both None
return self._merger.merge(quant_sig, llm_sig)

View File

@ -0,0 +1,162 @@
import json
import logging
import sqlite3
import sys
from datetime import datetime, timezone, timedelta
from typing import Any
import yfinance as yf
from orchestrator.config import OrchestratorConfig
from orchestrator.signals import Signal
logger = logging.getLogger(__name__)
class QuantRunner:
def __init__(self, config: OrchestratorConfig):
if not config.quant_backtest_path:
raise ValueError("OrchestratorConfig.quant_backtest_path must be set")
self._config = config
path = config.quant_backtest_path
if path not in sys.path:
sys.path.insert(0, path)
self._db_path = f"{path}/research_results/runs.db"
def get_signal(self, ticker: str, date: str) -> Signal:
"""
获取指定股票在指定日期的量化信号
date 格式'YYYY-MM-DD'
返回 Signal(source="quant")
"""
result = self._load_best_params()
params: dict = result["params"]
sharpe: float = result["sharpe_ratio"]
# 获取 date 前 60 天的历史数据
end_dt = datetime.strptime(date, "%Y-%m-%d")
start_dt = end_dt - timedelta(days=60)
start_str = start_dt.strftime("%Y-%m-%d")
df = yf.download(ticker, start=start_str, end=date, progress=False, auto_adjust=True)
if df.empty:
logger.warning("No price data for %s between %s and %s", ticker, start_str, date)
return Signal(
ticker=ticker,
direction=0,
confidence=0.0,
source="quant",
timestamp=datetime.now(timezone.utc),
metadata={"reason": "no_data"},
)
# 标准化列名为小写
df.columns = [c[0].lower() if isinstance(c, tuple) else c.lower() for c in df.columns]
# 用最佳参数创建 BollingerStrategy 实例
# Lazy import: requires quant_backtest_path to be in sys.path (set in __init__)
from strategies.momentum import BollingerStrategy
from core.data_models import Bar, OrderDirection
strategy = BollingerStrategy(
period=params.get("period", 20),
num_std=params.get("num_std", 2.0),
position_pct=params.get("position_pct", 0.20),
stop_loss_pct=params.get("stop_loss_pct", 0.05),
take_profit_pct=params.get("take_profit_pct", 0.15),
)
# 逐 bar 喂给策略,模拟历史回放
direction = 0
orders: list = []
context: dict[str, Any] = {"positions": {}}
for ts, row in df.iterrows():
bar = Bar(
symbol=ticker,
timestamp=ts.to_pydatetime() if hasattr(ts, "to_pydatetime") else ts,
open=float(row["open"]),
high=float(row["high"]),
low=float(row["low"]),
close=float(row["close"]),
volume=float(row.get("volume", 0)),
)
orders = strategy.on_bar(bar, context)
# 更新模拟持仓
for order in orders:
if order.direction == OrderDirection.BUY:
context["positions"][ticker] = order.volume
elif order.direction == OrderDirection.SELL:
context["positions"][ticker] = 0
# 最后一个 bar 的信号
last_orders = orders if df.shape[0] > 0 else []
for order in last_orders:
if order.direction == OrderDirection.BUY:
direction = 1
break
elif order.direction == OrderDirection.SELL:
direction = -1
break
# 计算 max_sharpe从 DB 中取全局最大值)
try:
with sqlite3.connect(self._db_path) as conn:
cur = conn.cursor()
cur.execute("SELECT MAX(sharpe_ratio) FROM backtest_results")
row = cur.fetchone()
max_sharpe = float(row[0]) if row and row[0] is not None else sharpe
except Exception:
max_sharpe = sharpe
confidence = self._calc_confidence(sharpe, max_sharpe)
return Signal(
ticker=ticker,
direction=direction,
confidence=confidence,
source="quant",
timestamp=datetime.now(timezone.utc),
metadata={"params": params, "sharpe_ratio": sharpe, "max_sharpe": max_sharpe},
)
def _load_best_params(self) -> dict:
"""
直接查 SQLite 获取 BollingerStrategy 最佳参数
参数是全局最优不区分股票backtest_results 表无 ticker 优化是全局的
strategy_type 支持 'BollingerStrategy' 'bollinger'兼容两种写法
"""
with sqlite3.connect(self._db_path) as conn:
cur = conn.cursor()
# 先按规格查 'BollingerStrategy',再 fallback 到 'bollinger'
cur.execute(
"""
SELECT params, sharpe_ratio
FROM backtest_results
WHERE strategy_type IN ('BollingerStrategy', 'bollinger')
ORDER BY sharpe_ratio DESC
LIMIT 1
""",
)
row = cur.fetchone()
if row is None:
raise ValueError(
"No BollingerStrategy results found in ResultStore. "
"Run optimization first: python quant_backtest/run_research.py"
)
params = json.loads(row[0]) if isinstance(row[0], str) else row[0]
return {"params": params, "sharpe_ratio": float(row[1])}
def _calc_confidence(self, sharpe: float, max_sharpe: float) -> float:
"""
Sharpe 归一化为置信度
- max_sharpe=0 时返回 0.5默认值避免除零
- sharpe/max_sharpe 上限截断到 1.0
- 下限截断到 0.0 Sharpe 不产生负置信度
"""
if max_sharpe == 0:
return 0.5
ratio = sharpe / max_sharpe
return max(0.0, min(1.0, ratio))

101
orchestrator/signals.py Normal file
View File

@ -0,0 +1,101 @@
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional
from orchestrator.config import OrchestratorConfig
logger = logging.getLogger(__name__)
@dataclass
class Signal:
ticker: str
direction: int # +1 买入, -1 卖出, 0 持有
confidence: float # 0.0 ~ 1.0
source: str # "quant" | "llm"
timestamp: datetime
metadata: dict = field(default_factory=dict) # 原始输出,用于调试
@dataclass
class FinalSignal:
ticker: str
direction: int # sign(quant_dir×quant_conf + llm_dir×llm_conf)sign(0)→0(HOLD)
confidence: float # abs(weighted_sum) / total_conf
quant_signal: Optional[Signal]
llm_signal: Optional[Signal]
timestamp: datetime
def _sign(x: float) -> int:
"""Return +1, -1, or 0."""
if x > 0:
return 1
elif x < 0:
return -1
return 0
class SignalMerger:
def __init__(self, config: OrchestratorConfig) -> None:
self._config = config
def merge(self, quant: Optional[Signal], llm: Optional[Signal]) -> FinalSignal:
now = datetime.now(timezone.utc)
# 两者均失败
if quant is None and llm is None:
raise ValueError("both quant and llm signals are None")
ticker = (quant or llm).ticker # type: ignore[union-attr]
# 只有 LLMquant 失败)
if quant is None:
return FinalSignal(
ticker=ticker,
direction=llm.direction,
confidence=min(llm.confidence * self._config.llm_solo_penalty,
self._config.llm_weight_cap),
quant_signal=None,
llm_signal=llm,
timestamp=now,
)
# 只有 Quantllm 失败)
if llm is None:
return FinalSignal(
ticker=ticker,
direction=quant.direction,
confidence=min(quant.confidence * self._config.quant_solo_penalty,
self._config.quant_weight_cap),
quant_signal=quant,
llm_signal=None,
timestamp=now,
)
# 两者都有:加权合并
# Cap each signal's contribution before merging
quant_conf = min(quant.confidence, self._config.quant_weight_cap)
llm_conf = min(llm.confidence, self._config.llm_weight_cap)
weighted_sum = (
quant.direction * quant_conf
+ llm.direction * llm_conf
)
final_direction = _sign(weighted_sum)
if final_direction == 0:
logger.info(
"SignalMerger: weighted_sum=0 for %s — signals cancel out, HOLD",
ticker,
)
total_conf = quant_conf + llm_conf
final_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0
return FinalSignal(
ticker=ticker,
direction=final_direction,
confidence=final_confidence,
quant_signal=quant,
llm_signal=llm,
timestamp=now,
)

View File

View File

@ -0,0 +1,41 @@
"""Tests for LLMRunner._map_rating()."""
import tempfile
import pytest
from orchestrator.config import OrchestratorConfig
from orchestrator.llm_runner import LLMRunner
@pytest.fixture
def runner(tmp_path):
cfg = OrchestratorConfig(cache_dir=str(tmp_path))
return LLMRunner(cfg)
# All 5 known ratings
@pytest.mark.parametrize("rating,expected", [
("BUY", (1, 0.9)),
("OVERWEIGHT", (1, 0.6)),
("HOLD", (0, 0.5)),
("UNDERWEIGHT", (-1, 0.6)),
("SELL", (-1, 0.9)),
])
def test_map_rating_known(runner, rating, expected):
assert runner._map_rating(rating) == expected
# Unknown rating → (0, 0.5)
def test_map_rating_unknown(runner):
assert runner._map_rating("STRONG_BUY") == (0, 0.5)
# Case-insensitive
def test_map_rating_lowercase(runner):
assert runner._map_rating("buy") == (1, 0.9)
assert runner._map_rating("sell") == (-1, 0.9)
assert runner._map_rating("hold") == (0, 0.5)
# Empty string → (0, 0.5)
def test_map_rating_empty_string(runner):
assert runner._map_rating("") == (0, 0.5)

View File

@ -0,0 +1,65 @@
"""Tests for QuantRunner._calc_confidence()."""
import json
import sqlite3
import tempfile
import os
import pytest
from orchestrator.config import OrchestratorConfig
from orchestrator.quant_runner import QuantRunner
def _make_runner(tmp_path):
"""Create a QuantRunner with a minimal SQLite DB so __init__ succeeds."""
db_dir = tmp_path / "research_results"
db_dir.mkdir(parents=True)
db_path = db_dir / "runs.db"
with sqlite3.connect(str(db_path)) as conn:
conn.execute(
"""CREATE TABLE backtest_results (
id INTEGER PRIMARY KEY,
strategy_type TEXT,
params TEXT,
sharpe_ratio REAL
)"""
)
conn.execute(
"INSERT INTO backtest_results (strategy_type, params, sharpe_ratio) VALUES (?, ?, ?)",
("BollingerStrategy", json.dumps({"period": 20, "num_std": 2.0,
"position_pct": 0.2,
"stop_loss_pct": 0.05,
"take_profit_pct": 0.15}), 1.5),
)
cfg = OrchestratorConfig(quant_backtest_path=str(tmp_path))
return QuantRunner(cfg)
@pytest.fixture
def runner(tmp_path):
return _make_runner(tmp_path)
def test_calc_confidence_max_sharpe_zero(runner):
assert runner._calc_confidence(1.0, 0) == 0.5
def test_calc_confidence_half(runner):
result = runner._calc_confidence(1.0, 2.0)
assert result == pytest.approx(0.5)
def test_calc_confidence_full(runner):
result = runner._calc_confidence(2.0, 2.0)
assert result == pytest.approx(1.0)
def test_calc_confidence_clamped_above(runner):
result = runner._calc_confidence(3.0, 2.0)
assert result == pytest.approx(1.0)
def test_calc_confidence_clamped_below(runner):
result = runner._calc_confidence(-1.0, 2.0)
assert result == pytest.approx(0.0)

View File

@ -0,0 +1,120 @@
"""Tests for SignalMerger in orchestrator/signals.py."""
import math
import pytest
from datetime import datetime, timezone
from orchestrator.config import OrchestratorConfig
from orchestrator.signals import Signal, SignalMerger
def _make_signal(ticker="AAPL", direction=1, confidence=0.8, source="quant"):
return Signal(
ticker=ticker,
direction=direction,
confidence=confidence,
source=source,
timestamp=datetime.now(timezone.utc),
)
@pytest.fixture
def merger():
return SignalMerger(OrchestratorConfig())
# Branch 1: both None → ValueError
def test_merge_both_none_raises(merger):
with pytest.raises(ValueError):
merger.merge(None, None)
# Branch 2: quant only
def test_merge_quant_only(merger):
cfg = OrchestratorConfig()
q = _make_signal(direction=1, confidence=0.8, source="quant")
result = merger.merge(q, None)
assert result.direction == 1
expected_conf = min(0.8 * cfg.quant_solo_penalty, cfg.quant_weight_cap)
assert math.isclose(result.confidence, expected_conf)
assert result.quant_signal is q
assert result.llm_signal is None
def test_merge_quant_only_capped(merger):
cfg = OrchestratorConfig()
# confidence=1.0 * quant_solo_penalty=0.8 → 0.8 == quant_weight_cap=0.8, no clamp needed
q = _make_signal(direction=-1, confidence=1.0, source="quant")
result = merger.merge(q, None)
expected_conf = min(1.0 * cfg.quant_solo_penalty, cfg.quant_weight_cap)
assert math.isclose(result.confidence, expected_conf)
assert result.direction == -1
# Branch 3: llm only
def test_merge_llm_only(merger):
cfg = OrchestratorConfig()
l = _make_signal(direction=-1, confidence=0.9, source="llm")
result = merger.merge(None, l)
assert result.direction == -1
expected_conf = min(0.9 * cfg.llm_solo_penalty, cfg.llm_weight_cap)
assert math.isclose(result.confidence, expected_conf)
assert result.llm_signal is l
assert result.quant_signal is None
def test_merge_llm_only_capped(merger):
cfg = OrchestratorConfig()
# Force cap: confidence=1.0, llm_solo_penalty=0.7 → 0.7 < llm_weight_cap=0.9, no cap
l = _make_signal(direction=1, confidence=1.0, source="llm")
result = merger.merge(None, l)
expected_conf = min(1.0 * cfg.llm_solo_penalty, cfg.llm_weight_cap)
assert math.isclose(result.confidence, expected_conf)
# Branch 4: both present, same direction
def test_merge_both_same_direction(merger):
cfg = OrchestratorConfig()
q = _make_signal(direction=1, confidence=0.6, source="quant")
l = _make_signal(direction=1, confidence=0.8, source="llm")
result = merger.merge(q, l)
assert result.direction == 1
# caps applied per-signal before merging
quant_conf = min(0.6, cfg.quant_weight_cap) # 0.6
llm_conf = min(0.8, cfg.llm_weight_cap) # 0.8
weighted_sum = 1 * quant_conf + 1 * llm_conf # 1.4
total_conf = quant_conf + llm_conf # 1.4
expected_conf = abs(weighted_sum) / total_conf # 1.0
assert math.isclose(result.confidence, expected_conf)
# Branch 5: both present, opposite direction
def test_merge_both_opposite_direction_quant_wins(merger):
cfg = OrchestratorConfig()
# quant stronger: direction should be quant's
q = _make_signal(direction=1, confidence=0.9, source="quant")
l = _make_signal(direction=-1, confidence=0.3, source="llm")
result = merger.merge(q, l)
assert result.direction == 1
# caps applied per-signal before merging
quant_conf = min(0.9, cfg.quant_weight_cap) # 0.8
llm_conf = min(0.3, cfg.llm_weight_cap) # 0.3
weighted_sum = 1 * quant_conf + (-1) * llm_conf # 0.5
total_conf = quant_conf + llm_conf # 1.1
expected_conf = abs(weighted_sum) / total_conf
assert math.isclose(result.confidence, expected_conf)
def test_merge_both_opposite_direction_llm_wins(merger):
q = _make_signal(direction=1, confidence=0.2, source="quant")
l = _make_signal(direction=-1, confidence=0.8, source="llm")
result = merger.merge(q, l)
assert result.direction == -1
# weighted_sum=0 → direction=HOLD
def test_merge_weighted_sum_zero(merger):
q = _make_signal(direction=1, confidence=0.5, source="quant")
l = _make_signal(direction=-1, confidence=0.5, source="llm")
result = merger.merge(q, l)
assert result.direction == 0
assert math.isclose(result.confidence, 0.0)

View File

@ -4,6 +4,7 @@ FastAPI REST API + WebSocket for real-time analysis progress
"""
import asyncio
import fcntl
import hmac
import json
import os
import subprocess
@ -105,7 +106,9 @@ def _check_api_key(api_key: Optional[str]) -> bool:
required = _get_api_key()
if not required:
return True
return api_key == required
if not api_key:
return False
return hmac.compare_digest(api_key, required)
def _auth_error():
raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required")
@ -363,7 +366,7 @@ async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Head
# Use clean environment - don't inherit parent env
clean_env = {k: v for k, v in os.environ.items()
if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))}
clean_env["ANTHROPIC_API_KEY"] = api_key
clean_env["ANTHROPIC_API_KEY"] = anthropic_key
clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
proc = await asyncio.create_subprocess_exec(
@ -1100,6 +1103,45 @@ async def root():
return {"message": "TradingAgents Web Dashboard API", "version": "0.1.0"}
@app.websocket("/ws/orchestrator")
async def ws_orchestrator(websocket: WebSocket, api_key: Optional[str] = None):
"""WebSocket endpoint for orchestrator live signals."""
# Auth check before accepting — reject unauthenticated connections
if not _check_api_key(api_key):
await websocket.close(code=4401)
return
import sys
sys.path.insert(0, str(REPO_ROOT))
from orchestrator.config import OrchestratorConfig
from orchestrator.orchestrator import TradingOrchestrator
from orchestrator.live_mode import LiveMode
config = OrchestratorConfig(
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
)
orchestrator = TradingOrchestrator(config)
live = LiveMode(orchestrator)
await websocket.accept()
try:
while True:
data = await websocket.receive_text()
payload = json.loads(data)
tickers = payload.get("tickers", [])
date = payload.get("date")
results = await live.run_once(tickers, date)
await websocket.send_text(json.dumps({"signals": results}))
except WebSocketDisconnect:
pass
except Exception as e:
try:
await websocket.send_text(json.dumps({"error": str(e)}))
except Exception:
pass
if __name__ == "__main__":
import uvicorn
# Run with: cd web_dashboard && ../env312/bin/python -m uvicorn main:app --reload