diff --git a/orchestrator/__init__.py b/orchestrator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestrator/backtest_mode.py b/orchestrator/backtest_mode.py new file mode 100644 index 00000000..604b81d2 --- /dev/null +++ b/orchestrator/backtest_mode.py @@ -0,0 +1,64 @@ +import logging +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import List + +from orchestrator.signals import FinalSignal + +logger = logging.getLogger(__name__) + + +@dataclass +class BacktestResult: + records: List[dict] = field(default_factory=list) + summary: dict = field(default_factory=dict) + + +class BacktestMode: + def __init__(self, orchestrator): + self._orchestrator = orchestrator + + def run(self, tickers: List[str], start_date: str, end_date: str) -> BacktestResult: + start = datetime.strptime(start_date, "%Y-%m-%d") + end = datetime.strptime(end_date, "%Y-%m-%d") + + records = [] + current = start + while current <= end: + if current.weekday() < 5: # skip weekends + date_str = current.strftime("%Y-%m-%d") + for ticker in tickers: + try: + sig = self._orchestrator.get_combined_signal(ticker, date_str) + records.append({ + "ticker": ticker, + "date": date_str, + "direction": sig.direction, + "confidence": sig.confidence, + "quant_direction": sig.quant_signal.direction if sig.quant_signal else None, + "llm_direction": sig.llm_signal.direction if sig.llm_signal else None, + }) + except Exception as e: + logger.error("BacktestMode: failed for %s %s: %s", ticker, date_str, e) + current += timedelta(days=1) + + summary = self._compute_summary(records, tickers) + return BacktestResult(records=records, summary=summary) + + def _compute_summary(self, records: List[dict], tickers: List[str]) -> dict: + summary = {} + for ticker in tickers: + ticker_records = [r for r in records if r["ticker"] == ticker] + if not ticker_records: + summary[ticker] = {"total_days": 0} + continue + directions = [r["direction"] for r in ticker_records] + confidences = [r["confidence"] for r in ticker_records] + summary[ticker] = { + "total_days": len(ticker_records), + "buy_days": directions.count(1), + "sell_days": directions.count(-1), + "hold_days": directions.count(0), + "avg_confidence": sum(confidences) / len(confidences), + } + return summary diff --git a/orchestrator/config.py b/orchestrator/config.py new file mode 100644 index 00000000..9d3eaea5 --- /dev/null +++ b/orchestrator/config.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass, field + + +@dataclass +class OrchestratorConfig: + # Must be set to the local quant backtest output directory before use + quant_backtest_path: str = "" + trading_agents_config: dict = field(default_factory=dict) + quant_weight_cap: float = 0.8 # quant 置信度上限 + llm_weight_cap: float = 0.9 # llm 置信度上限 + llm_batch_days: int = 7 # LLM 每隔几天运行一次(节省 API) + cache_dir: str = "orchestrator/cache" # LLM 信号缓存目录 + llm_solo_penalty: float = 0.7 # LLM 单轨时的置信度折扣 + quant_solo_penalty: float = 0.8 # Quant 单轨时的置信度折扣 diff --git a/orchestrator/examples/__init__.py b/orchestrator/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestrator/examples/run_backtest.py b/orchestrator/examples/run_backtest.py new file mode 100644 index 00000000..f7f5bc49 --- /dev/null +++ b/orchestrator/examples/run_backtest.py @@ -0,0 +1,41 @@ +""" +Example: Run orchestrator backtest for 宁德时代 (300750.SZ) over 2023. + +Usage: + cd /path/to/TradingAgents + QUANT_BACKTEST_PATH=/path/to/quant_backtest python orchestrator/examples/run_backtest.py +""" +import json +import logging +import os +import sys + +# Add repo root to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +from orchestrator.config import OrchestratorConfig +from orchestrator.orchestrator import TradingOrchestrator +from orchestrator.backtest_mode import BacktestMode + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") + +config = OrchestratorConfig( + quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""), + cache_dir="orchestrator/cache", +) + +orchestrator = TradingOrchestrator(config) +backtest = BacktestMode(orchestrator) + +result = backtest.run( + tickers=["300750.SZ"], + start_date="2023-01-01", + end_date="2023-12-31", +) + +print(f"\n=== Backtest Summary ===") +print(json.dumps(result.summary, indent=2, ensure_ascii=False)) +print(f"\nTotal records: {len(result.records)}") +if result.records: + print(f"First record: {result.records[0]}") + print(f"Last record: {result.records[-1]}") diff --git a/orchestrator/examples/run_live.py b/orchestrator/examples/run_live.py new file mode 100644 index 00000000..4a652bfa --- /dev/null +++ b/orchestrator/examples/run_live.py @@ -0,0 +1,41 @@ +""" +Example: Run orchestrator live mode for a list of tickers. + +Usage: + cd /path/to/TradingAgents + QUANT_BACKTEST_PATH=/path/to/quant_backtest python orchestrator/examples/run_live.py +""" +import asyncio +import json +import logging +import os +import sys +from datetime import datetime, timezone + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +from orchestrator.config import OrchestratorConfig +from orchestrator.orchestrator import TradingOrchestrator +from orchestrator.live_mode import LiveMode + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") + +TICKERS = ["300750.SZ", "603259.SS"] + +config = OrchestratorConfig( + quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""), + cache_dir="orchestrator/cache", +) + +orchestrator = TradingOrchestrator(config) +live = LiveMode(orchestrator) + + +async def main(): + today = datetime.now(timezone.utc).strftime("%Y-%m-%d") + print(f"\n=== Live Signals for {today} ===") + results = await live.run_once(TICKERS, date=today) + print(json.dumps(results, indent=2, ensure_ascii=False)) + + +asyncio.run(main()) diff --git a/orchestrator/live_mode.py b/orchestrator/live_mode.py new file mode 100644 index 00000000..76c04c51 --- /dev/null +++ b/orchestrator/live_mode.py @@ -0,0 +1,48 @@ +import asyncio +import logging +from datetime import datetime, timezone +from typing import List, Optional + +logger = logging.getLogger(__name__) + + +class LiveMode: + """ + Triggers signal computation for a list of tickers and broadcasts + results via a callback (e.g., WebSocket send). + """ + + def __init__(self, orchestrator): + self._orchestrator = orchestrator + + async def run_once(self, tickers: List[str], date: Optional[str] = None) -> List[dict]: + """ + Compute combined signals for all tickers on the given date (default: today). + Returns list of signal dicts. + """ + if date is None: + date = datetime.now(timezone.utc).strftime("%Y-%m-%d") + + results = [] + for ticker in tickers: + try: + sig = await asyncio.to_thread( + self._orchestrator.get_combined_signal, ticker, date + ) + results.append({ + "ticker": ticker, + "date": date, + "direction": sig.direction, + "confidence": sig.confidence, + "quant_direction": sig.quant_signal.direction if sig.quant_signal else None, + "llm_direction": sig.llm_signal.direction if sig.llm_signal else None, + "timestamp": sig.timestamp.isoformat(), + }) + except Exception as e: + logger.error("LiveMode: failed for %s %s: %s", ticker, date, e) + results.append({ + "ticker": ticker, + "date": date, + "error": str(e), + }) + return results diff --git a/orchestrator/llm_runner.py b/orchestrator/llm_runner.py new file mode 100644 index 00000000..8dcb3c46 --- /dev/null +++ b/orchestrator/llm_runner.py @@ -0,0 +1,95 @@ +import json +import logging +import os +from datetime import datetime, timezone + +from orchestrator.config import OrchestratorConfig +from orchestrator.signals import Signal + +logger = logging.getLogger(__name__) + + +class LLMRunner: + def __init__(self, config: OrchestratorConfig): + self._config = config + self._graph = None # Lazy-initialized on first get_signal() call (requires API key) + self.cache_dir = config.cache_dir + os.makedirs(self.cache_dir, exist_ok=True) + + def _get_graph(self): + """Lazy-initialize TradingAgentsGraph (heavy, requires API key at init time).""" + if self._graph is None: + from tradingagents.graph.trading_graph import TradingAgentsGraph + trading_cfg = self._config.trading_agents_config if self._config.trading_agents_config else None + self._graph = TradingAgentsGraph(config=trading_cfg) + return self._graph + + def get_signal(self, ticker: str, date: str) -> Signal: + """获取指定股票在指定日期的 LLM 信号,带缓存。""" + safe_ticker = ticker.replace("/", "_") # sanitize for filesystem (e.g. BRK/B) + cache_path = os.path.join(self.cache_dir, f"{safe_ticker}_{date}.json") + + if os.path.exists(cache_path): + logger.info("LLMRunner: cache hit for %s %s", ticker, date) + with open(cache_path, "r", encoding="utf-8") as f: + data = json.load(f) + # Use stored direction/confidence directly to avoid re-mapping drift + return Signal( + ticker=ticker, + direction=data["direction"], + confidence=data["confidence"], + source="llm", + timestamp=datetime.fromisoformat(data["timestamp"]), + metadata=data, + ) + + try: + _final_state, processed_signal = self._get_graph().propagate(ticker, date) + rating = processed_signal if isinstance(processed_signal, str) else str(processed_signal) + direction, confidence = self._map_rating(rating) + now = datetime.now(timezone.utc) + + cache_data = { + "rating": rating, + "direction": direction, + "confidence": confidence, + "timestamp": now.isoformat(), + "ticker": ticker, + "date": date, + } + with open(cache_path, "w", encoding="utf-8") as f: + json.dump(cache_data, f, ensure_ascii=False, indent=2) + + return Signal( + ticker=ticker, + direction=direction, + confidence=confidence, + source="llm", + timestamp=now, + metadata=cache_data, + ) + except Exception as e: + logger.error("LLMRunner: propagate failed for %s %s: %s", ticker, date, e) + return Signal( + ticker=ticker, + direction=0, + confidence=0.0, + source="llm", + timestamp=datetime.now(timezone.utc), + metadata={"error": str(e)}, + ) + + def _map_rating(self, rating: str) -> tuple[int, float]: + """将 5 级评级映射为 (direction, confidence)。""" + mapping = { + "BUY": (1, 0.9), + "OVERWEIGHT": (1, 0.6), + "HOLD": (0, 0.5), + "UNDERWEIGHT": (-1, 0.6), + "SELL": (-1, 0.9), + } + result = mapping.get(rating.upper() if rating else "", None) + if result is None: + logger.warning("LLMRunner: unknown rating %r, falling back to HOLD", rating) + return (0, 0.5) + return result diff --git a/orchestrator/orchestrator.py b/orchestrator/orchestrator.py new file mode 100644 index 00000000..baf042eb --- /dev/null +++ b/orchestrator/orchestrator.py @@ -0,0 +1,63 @@ +import logging +from datetime import datetime, timezone +from typing import Optional + +from orchestrator.config import OrchestratorConfig +from orchestrator.signals import Signal, FinalSignal, SignalMerger +from orchestrator.quant_runner import QuantRunner +from orchestrator.llm_runner import LLMRunner + +logger = logging.getLogger(__name__) + + +class TradingOrchestrator: + def __init__(self, config: OrchestratorConfig): + self._config = config + self._merger = SignalMerger(config) + self._quant: Optional[QuantRunner] = None + self._llm: Optional[LLMRunner] = None + + # Initialize runners (quant requires quant_backtest_path) + if config.quant_backtest_path: + try: + self._quant = QuantRunner(config) + except Exception as e: + logger.warning("TradingOrchestrator: QuantRunner init failed: %s", e) + + self._llm = LLMRunner(config) + + def get_combined_signal(self, ticker: str, date: str) -> FinalSignal: + """ + Get merged signal for ticker on date. + Degradation: + - quant fails (error signal): use llm only with llm_solo_penalty + - llm fails (error signal): use quant only with quant_solo_penalty + - both fail: raises ValueError + """ + quant_sig: Optional[Signal] = None + llm_sig: Optional[Signal] = None + + # Get quant signal + if self._quant is not None: + try: + quant_sig = self._quant.get_signal(ticker, date) + # Treat error signals (confidence=0, direction=0 with error metadata) as None + if quant_sig.metadata.get("error") or quant_sig.metadata.get("reason") == "no_data": + logger.warning("TradingOrchestrator: quant signal degraded for %s %s", ticker, date) + quant_sig = None + except Exception as e: + logger.error("TradingOrchestrator: quant get_signal failed: %s", e) + quant_sig = None + + # Get llm signal + try: + llm_sig = self._llm.get_signal(ticker, date) + if llm_sig.metadata.get("error"): + logger.warning("TradingOrchestrator: llm signal degraded for %s %s", ticker, date) + llm_sig = None + except Exception as e: + logger.error("TradingOrchestrator: llm get_signal failed: %s", e) + llm_sig = None + + # merge raises ValueError if both None + return self._merger.merge(quant_sig, llm_sig) diff --git a/orchestrator/quant_runner.py b/orchestrator/quant_runner.py new file mode 100644 index 00000000..42d2b8b1 --- /dev/null +++ b/orchestrator/quant_runner.py @@ -0,0 +1,162 @@ +import json +import logging +import sqlite3 +import sys +from datetime import datetime, timezone, timedelta +from typing import Any + +import yfinance as yf + +from orchestrator.config import OrchestratorConfig +from orchestrator.signals import Signal + +logger = logging.getLogger(__name__) + + +class QuantRunner: + def __init__(self, config: OrchestratorConfig): + if not config.quant_backtest_path: + raise ValueError("OrchestratorConfig.quant_backtest_path must be set") + self._config = config + path = config.quant_backtest_path + if path not in sys.path: + sys.path.insert(0, path) + self._db_path = f"{path}/research_results/runs.db" + + def get_signal(self, ticker: str, date: str) -> Signal: + """ + 获取指定股票在指定日期的量化信号。 + date 格式:'YYYY-MM-DD' + 返回 Signal(source="quant") + """ + result = self._load_best_params() + params: dict = result["params"] + sharpe: float = result["sharpe_ratio"] + + # 获取 date 前 60 天的历史数据 + end_dt = datetime.strptime(date, "%Y-%m-%d") + start_dt = end_dt - timedelta(days=60) + start_str = start_dt.strftime("%Y-%m-%d") + + df = yf.download(ticker, start=start_str, end=date, progress=False, auto_adjust=True) + if df.empty: + logger.warning("No price data for %s between %s and %s", ticker, start_str, date) + return Signal( + ticker=ticker, + direction=0, + confidence=0.0, + source="quant", + timestamp=datetime.now(timezone.utc), + metadata={"reason": "no_data"}, + ) + + # 标准化列名为小写 + df.columns = [c[0].lower() if isinstance(c, tuple) else c.lower() for c in df.columns] + + # 用最佳参数创建 BollingerStrategy 实例 + # Lazy import: requires quant_backtest_path to be in sys.path (set in __init__) + from strategies.momentum import BollingerStrategy + from core.data_models import Bar, OrderDirection + + strategy = BollingerStrategy( + period=params.get("period", 20), + num_std=params.get("num_std", 2.0), + position_pct=params.get("position_pct", 0.20), + stop_loss_pct=params.get("stop_loss_pct", 0.05), + take_profit_pct=params.get("take_profit_pct", 0.15), + ) + + # 逐 bar 喂给策略,模拟历史回放 + direction = 0 + orders: list = [] + context: dict[str, Any] = {"positions": {}} + + for ts, row in df.iterrows(): + bar = Bar( + symbol=ticker, + timestamp=ts.to_pydatetime() if hasattr(ts, "to_pydatetime") else ts, + open=float(row["open"]), + high=float(row["high"]), + low=float(row["low"]), + close=float(row["close"]), + volume=float(row.get("volume", 0)), + ) + orders = strategy.on_bar(bar, context) + # 更新模拟持仓 + for order in orders: + if order.direction == OrderDirection.BUY: + context["positions"][ticker] = order.volume + elif order.direction == OrderDirection.SELL: + context["positions"][ticker] = 0 + + # 最后一个 bar 的信号 + last_orders = orders if df.shape[0] > 0 else [] + for order in last_orders: + if order.direction == OrderDirection.BUY: + direction = 1 + break + elif order.direction == OrderDirection.SELL: + direction = -1 + break + + # 计算 max_sharpe(从 DB 中取全局最大值) + try: + with sqlite3.connect(self._db_path) as conn: + cur = conn.cursor() + cur.execute("SELECT MAX(sharpe_ratio) FROM backtest_results") + row = cur.fetchone() + max_sharpe = float(row[0]) if row and row[0] is not None else sharpe + except Exception: + max_sharpe = sharpe + + confidence = self._calc_confidence(sharpe, max_sharpe) + + return Signal( + ticker=ticker, + direction=direction, + confidence=confidence, + source="quant", + timestamp=datetime.now(timezone.utc), + metadata={"params": params, "sharpe_ratio": sharpe, "max_sharpe": max_sharpe}, + ) + + def _load_best_params(self) -> dict: + """ + 直接查 SQLite 获取 BollingerStrategy 最佳参数。 + 参数是全局最优,不区分股票(backtest_results 表无 ticker 列,优化是全局的)。 + strategy_type 支持 'BollingerStrategy' 和 'bollinger'(兼容两种写法)。 + """ + with sqlite3.connect(self._db_path) as conn: + cur = conn.cursor() + # 先按规格查 'BollingerStrategy',再 fallback 到 'bollinger' + cur.execute( + """ + SELECT params, sharpe_ratio + FROM backtest_results + WHERE strategy_type IN ('BollingerStrategy', 'bollinger') + ORDER BY sharpe_ratio DESC + LIMIT 1 + """, + ) + row = cur.fetchone() + + if row is None: + raise ValueError( + "No BollingerStrategy results found in ResultStore. " + "Run optimization first: python quant_backtest/run_research.py" + ) + + params = json.loads(row[0]) if isinstance(row[0], str) else row[0] + return {"params": params, "sharpe_ratio": float(row[1])} + + def _calc_confidence(self, sharpe: float, max_sharpe: float) -> float: + """ + Sharpe 归一化为置信度。 + - max_sharpe=0 时返回 0.5(默认值,避免除零) + - sharpe/max_sharpe 上限截断到 1.0 + - 下限截断到 0.0(负 Sharpe 不产生负置信度) + """ + if max_sharpe == 0: + return 0.5 + ratio = sharpe / max_sharpe + return max(0.0, min(1.0, ratio)) diff --git a/orchestrator/signals.py b/orchestrator/signals.py new file mode 100644 index 00000000..7283c725 --- /dev/null +++ b/orchestrator/signals.py @@ -0,0 +1,101 @@ +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Optional + +from orchestrator.config import OrchestratorConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class Signal: + ticker: str + direction: int # +1 买入, -1 卖出, 0 持有 + confidence: float # 0.0 ~ 1.0 + source: str # "quant" | "llm" + timestamp: datetime + metadata: dict = field(default_factory=dict) # 原始输出,用于调试 + + +@dataclass +class FinalSignal: + ticker: str + direction: int # sign(quant_dir×quant_conf + llm_dir×llm_conf),sign(0)→0(HOLD) + confidence: float # abs(weighted_sum) / total_conf + quant_signal: Optional[Signal] + llm_signal: Optional[Signal] + timestamp: datetime + + +def _sign(x: float) -> int: + """Return +1, -1, or 0.""" + if x > 0: + return 1 + elif x < 0: + return -1 + return 0 + + +class SignalMerger: + def __init__(self, config: OrchestratorConfig) -> None: + self._config = config + + def merge(self, quant: Optional[Signal], llm: Optional[Signal]) -> FinalSignal: + now = datetime.now(timezone.utc) + + # 两者均失败 + if quant is None and llm is None: + raise ValueError("both quant and llm signals are None") + + ticker = (quant or llm).ticker # type: ignore[union-attr] + + # 只有 LLM(quant 失败) + if quant is None: + return FinalSignal( + ticker=ticker, + direction=llm.direction, + confidence=min(llm.confidence * self._config.llm_solo_penalty, + self._config.llm_weight_cap), + quant_signal=None, + llm_signal=llm, + timestamp=now, + ) + + # 只有 Quant(llm 失败) + if llm is None: + return FinalSignal( + ticker=ticker, + direction=quant.direction, + confidence=min(quant.confidence * self._config.quant_solo_penalty, + self._config.quant_weight_cap), + quant_signal=quant, + llm_signal=None, + timestamp=now, + ) + + # 两者都有:加权合并 + # Cap each signal's contribution before merging + quant_conf = min(quant.confidence, self._config.quant_weight_cap) + llm_conf = min(llm.confidence, self._config.llm_weight_cap) + weighted_sum = ( + quant.direction * quant_conf + + llm.direction * llm_conf + ) + final_direction = _sign(weighted_sum) + if final_direction == 0: + logger.info( + "SignalMerger: weighted_sum=0 for %s — signals cancel out, HOLD", + ticker, + ) + total_conf = quant_conf + llm_conf + final_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0 + + return FinalSignal( + ticker=ticker, + direction=final_direction, + confidence=final_confidence, + quant_signal=quant, + llm_signal=llm, + timestamp=now, + ) diff --git a/orchestrator/tests/__init__.py b/orchestrator/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestrator/tests/test_llm_runner.py b/orchestrator/tests/test_llm_runner.py new file mode 100644 index 00000000..a4b7bbeb --- /dev/null +++ b/orchestrator/tests/test_llm_runner.py @@ -0,0 +1,41 @@ +"""Tests for LLMRunner._map_rating().""" +import tempfile +import pytest + +from orchestrator.config import OrchestratorConfig +from orchestrator.llm_runner import LLMRunner + + +@pytest.fixture +def runner(tmp_path): + cfg = OrchestratorConfig(cache_dir=str(tmp_path)) + return LLMRunner(cfg) + + +# All 5 known ratings +@pytest.mark.parametrize("rating,expected", [ + ("BUY", (1, 0.9)), + ("OVERWEIGHT", (1, 0.6)), + ("HOLD", (0, 0.5)), + ("UNDERWEIGHT", (-1, 0.6)), + ("SELL", (-1, 0.9)), +]) +def test_map_rating_known(runner, rating, expected): + assert runner._map_rating(rating) == expected + + +# Unknown rating → (0, 0.5) +def test_map_rating_unknown(runner): + assert runner._map_rating("STRONG_BUY") == (0, 0.5) + + +# Case-insensitive +def test_map_rating_lowercase(runner): + assert runner._map_rating("buy") == (1, 0.9) + assert runner._map_rating("sell") == (-1, 0.9) + assert runner._map_rating("hold") == (0, 0.5) + + +# Empty string → (0, 0.5) +def test_map_rating_empty_string(runner): + assert runner._map_rating("") == (0, 0.5) diff --git a/orchestrator/tests/test_quant_runner.py b/orchestrator/tests/test_quant_runner.py new file mode 100644 index 00000000..73b95da5 --- /dev/null +++ b/orchestrator/tests/test_quant_runner.py @@ -0,0 +1,65 @@ +"""Tests for QuantRunner._calc_confidence().""" +import json +import sqlite3 +import tempfile +import os +import pytest + +from orchestrator.config import OrchestratorConfig +from orchestrator.quant_runner import QuantRunner + + +def _make_runner(tmp_path): + """Create a QuantRunner with a minimal SQLite DB so __init__ succeeds.""" + db_dir = tmp_path / "research_results" + db_dir.mkdir(parents=True) + db_path = db_dir / "runs.db" + + with sqlite3.connect(str(db_path)) as conn: + conn.execute( + """CREATE TABLE backtest_results ( + id INTEGER PRIMARY KEY, + strategy_type TEXT, + params TEXT, + sharpe_ratio REAL + )""" + ) + conn.execute( + "INSERT INTO backtest_results (strategy_type, params, sharpe_ratio) VALUES (?, ?, ?)", + ("BollingerStrategy", json.dumps({"period": 20, "num_std": 2.0, + "position_pct": 0.2, + "stop_loss_pct": 0.05, + "take_profit_pct": 0.15}), 1.5), + ) + + cfg = OrchestratorConfig(quant_backtest_path=str(tmp_path)) + return QuantRunner(cfg) + + +@pytest.fixture +def runner(tmp_path): + return _make_runner(tmp_path) + + +def test_calc_confidence_max_sharpe_zero(runner): + assert runner._calc_confidence(1.0, 0) == 0.5 + + +def test_calc_confidence_half(runner): + result = runner._calc_confidence(1.0, 2.0) + assert result == pytest.approx(0.5) + + +def test_calc_confidence_full(runner): + result = runner._calc_confidence(2.0, 2.0) + assert result == pytest.approx(1.0) + + +def test_calc_confidence_clamped_above(runner): + result = runner._calc_confidence(3.0, 2.0) + assert result == pytest.approx(1.0) + + +def test_calc_confidence_clamped_below(runner): + result = runner._calc_confidence(-1.0, 2.0) + assert result == pytest.approx(0.0) diff --git a/orchestrator/tests/test_signals.py b/orchestrator/tests/test_signals.py new file mode 100644 index 00000000..bbd5b2aa --- /dev/null +++ b/orchestrator/tests/test_signals.py @@ -0,0 +1,120 @@ +"""Tests for SignalMerger in orchestrator/signals.py.""" +import math +import pytest +from datetime import datetime, timezone + +from orchestrator.config import OrchestratorConfig +from orchestrator.signals import Signal, SignalMerger + + +def _make_signal(ticker="AAPL", direction=1, confidence=0.8, source="quant"): + return Signal( + ticker=ticker, + direction=direction, + confidence=confidence, + source=source, + timestamp=datetime.now(timezone.utc), + ) + + +@pytest.fixture +def merger(): + return SignalMerger(OrchestratorConfig()) + + +# Branch 1: both None → ValueError +def test_merge_both_none_raises(merger): + with pytest.raises(ValueError): + merger.merge(None, None) + + +# Branch 2: quant only +def test_merge_quant_only(merger): + cfg = OrchestratorConfig() + q = _make_signal(direction=1, confidence=0.8, source="quant") + result = merger.merge(q, None) + assert result.direction == 1 + expected_conf = min(0.8 * cfg.quant_solo_penalty, cfg.quant_weight_cap) + assert math.isclose(result.confidence, expected_conf) + assert result.quant_signal is q + assert result.llm_signal is None + + +def test_merge_quant_only_capped(merger): + cfg = OrchestratorConfig() + # confidence=1.0 * quant_solo_penalty=0.8 → 0.8 == quant_weight_cap=0.8, no clamp needed + q = _make_signal(direction=-1, confidence=1.0, source="quant") + result = merger.merge(q, None) + expected_conf = min(1.0 * cfg.quant_solo_penalty, cfg.quant_weight_cap) + assert math.isclose(result.confidence, expected_conf) + assert result.direction == -1 + + +# Branch 3: llm only +def test_merge_llm_only(merger): + cfg = OrchestratorConfig() + l = _make_signal(direction=-1, confidence=0.9, source="llm") + result = merger.merge(None, l) + assert result.direction == -1 + expected_conf = min(0.9 * cfg.llm_solo_penalty, cfg.llm_weight_cap) + assert math.isclose(result.confidence, expected_conf) + assert result.llm_signal is l + assert result.quant_signal is None + + +def test_merge_llm_only_capped(merger): + cfg = OrchestratorConfig() + # Force cap: confidence=1.0, llm_solo_penalty=0.7 → 0.7 < llm_weight_cap=0.9, no cap + l = _make_signal(direction=1, confidence=1.0, source="llm") + result = merger.merge(None, l) + expected_conf = min(1.0 * cfg.llm_solo_penalty, cfg.llm_weight_cap) + assert math.isclose(result.confidence, expected_conf) + + +# Branch 4: both present, same direction +def test_merge_both_same_direction(merger): + cfg = OrchestratorConfig() + q = _make_signal(direction=1, confidence=0.6, source="quant") + l = _make_signal(direction=1, confidence=0.8, source="llm") + result = merger.merge(q, l) + assert result.direction == 1 + # caps applied per-signal before merging + quant_conf = min(0.6, cfg.quant_weight_cap) # 0.6 + llm_conf = min(0.8, cfg.llm_weight_cap) # 0.8 + weighted_sum = 1 * quant_conf + 1 * llm_conf # 1.4 + total_conf = quant_conf + llm_conf # 1.4 + expected_conf = abs(weighted_sum) / total_conf # 1.0 + assert math.isclose(result.confidence, expected_conf) + + +# Branch 5: both present, opposite direction +def test_merge_both_opposite_direction_quant_wins(merger): + cfg = OrchestratorConfig() + # quant stronger: direction should be quant's + q = _make_signal(direction=1, confidence=0.9, source="quant") + l = _make_signal(direction=-1, confidence=0.3, source="llm") + result = merger.merge(q, l) + assert result.direction == 1 + # caps applied per-signal before merging + quant_conf = min(0.9, cfg.quant_weight_cap) # 0.8 + llm_conf = min(0.3, cfg.llm_weight_cap) # 0.3 + weighted_sum = 1 * quant_conf + (-1) * llm_conf # 0.5 + total_conf = quant_conf + llm_conf # 1.1 + expected_conf = abs(weighted_sum) / total_conf + assert math.isclose(result.confidence, expected_conf) + + +def test_merge_both_opposite_direction_llm_wins(merger): + q = _make_signal(direction=1, confidence=0.2, source="quant") + l = _make_signal(direction=-1, confidence=0.8, source="llm") + result = merger.merge(q, l) + assert result.direction == -1 + + +# weighted_sum=0 → direction=HOLD +def test_merge_weighted_sum_zero(merger): + q = _make_signal(direction=1, confidence=0.5, source="quant") + l = _make_signal(direction=-1, confidence=0.5, source="llm") + result = merger.merge(q, l) + assert result.direction == 0 + assert math.isclose(result.confidence, 0.0) diff --git a/web_dashboard/backend/main.py b/web_dashboard/backend/main.py index 05c70daa..2e859540 100644 --- a/web_dashboard/backend/main.py +++ b/web_dashboard/backend/main.py @@ -4,6 +4,7 @@ FastAPI REST API + WebSocket for real-time analysis progress """ import asyncio import fcntl +import hmac import json import os import subprocess @@ -105,7 +106,9 @@ def _check_api_key(api_key: Optional[str]) -> bool: required = _get_api_key() if not required: return True - return api_key == required + if not api_key: + return False + return hmac.compare_digest(api_key, required) def _auth_error(): raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required") @@ -363,7 +366,7 @@ async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Head # Use clean environment - don't inherit parent env clean_env = {k: v for k, v in os.environ.items() if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))} - clean_env["ANTHROPIC_API_KEY"] = api_key + clean_env["ANTHROPIC_API_KEY"] = anthropic_key clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic" proc = await asyncio.create_subprocess_exec( @@ -1100,6 +1103,45 @@ async def root(): return {"message": "TradingAgents Web Dashboard API", "version": "0.1.0"} +@app.websocket("/ws/orchestrator") +async def ws_orchestrator(websocket: WebSocket, api_key: Optional[str] = None): + """WebSocket endpoint for orchestrator live signals.""" + # Auth check before accepting — reject unauthenticated connections + if not _check_api_key(api_key): + await websocket.close(code=4401) + return + + import sys + sys.path.insert(0, str(REPO_ROOT)) + from orchestrator.config import OrchestratorConfig + from orchestrator.orchestrator import TradingOrchestrator + from orchestrator.live_mode import LiveMode + + config = OrchestratorConfig( + quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""), + ) + orchestrator = TradingOrchestrator(config) + live = LiveMode(orchestrator) + + await websocket.accept() + try: + while True: + data = await websocket.receive_text() + payload = json.loads(data) + tickers = payload.get("tickers", []) + date = payload.get("date") + + results = await live.run_once(tickers, date) + await websocket.send_text(json.dumps({"signals": results})) + except WebSocketDisconnect: + pass + except Exception as e: + try: + await websocket.send_text(json.dumps({"error": str(e)})) + except Exception: + pass + + if __name__ == "__main__": import uvicorn # Run with: cd web_dashboard && ../env312/bin/python -m uvicorn main:app --reload