From 56dc76d44a6cc3a4b67ea9d80cb3401000262d87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 21:35:31 +0800 Subject: [PATCH 01/14] feat(orchestrator): add signals.py and config.py - Signal / FinalSignal dataclasses - SignalMerger with weighted merge, single-track fallbacks, and cancel-out HOLD - OrchestratorConfig with all required fields --- orchestrator/__init__.py | 0 orchestrator/config.py | 11 +++++ orchestrator/signals.py | 100 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+) create mode 100644 orchestrator/__init__.py create mode 100644 orchestrator/config.py create mode 100644 orchestrator/signals.py diff --git a/orchestrator/__init__.py b/orchestrator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestrator/config.py b/orchestrator/config.py new file mode 100644 index 00000000..0beb6c5e --- /dev/null +++ b/orchestrator/config.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass, field + + +@dataclass +class OrchestratorConfig: + quant_backtest_path: str = "/Users/chenshaojie/Downloads/quant_backtest" + trading_agents_config: dict = field(default_factory=dict) + quant_weight_cap: float = 0.8 # quant 置信度上限 + llm_weight_cap: float = 0.9 # llm 置信度上限 + llm_batch_days: int = 7 # LLM 每隔几天运行一次(节省 API) + cache_dir: str = "orchestrator/cache" # LLM 信号缓存目录 diff --git a/orchestrator/signals.py b/orchestrator/signals.py new file mode 100644 index 00000000..f9549dcc --- /dev/null +++ b/orchestrator/signals.py @@ -0,0 +1,100 @@ +import logging +from dataclasses import dataclass, field +from datetime import datetime +from typing import Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class Signal: + ticker: str + direction: int # +1 买入, -1 卖出, 0 持有 + confidence: float # 0.0 ~ 1.0 + source: str # "quant" | "llm" + timestamp: datetime + metadata: dict = field(default_factory=dict) # 原始输出,用于调试 + + +@dataclass +class FinalSignal: + ticker: str + direction: int # sign(quant_dir×quant_conf + llm_dir×llm_conf),sign(0)→0(HOLD) + confidence: float # abs(weighted_sum) / total_conf + quant_signal: Optional[Signal] + llm_signal: Optional[Signal] + timestamp: datetime + + +def _sign(x: float) -> int: + """Return +1, -1, or 0.""" + if x > 0: + return 1 + elif x < 0: + return -1 + return 0 + + +class SignalMerger: + def merge(self, quant: Optional[Signal], llm: Optional[Signal]) -> FinalSignal: + now = datetime.utcnow() + + # 两者均失败 + if quant is None and llm is None: + ticker = "" + return FinalSignal( + ticker=ticker, + direction=0, + confidence=0.0, + quant_signal=None, + llm_signal=None, + timestamp=now, + ) + + ticker = (quant or llm).ticker # type: ignore[union-attr] + + # 只有 LLM(quant 失败) + if quant is None: + assert llm is not None + return FinalSignal( + ticker=ticker, + direction=llm.direction, + confidence=llm.confidence * 0.7, + quant_signal=None, + llm_signal=llm, + timestamp=now, + ) + + # 只有 Quant(llm 失败) + if llm is None: + return FinalSignal( + ticker=ticker, + direction=quant.direction, + confidence=quant.confidence * 0.8, + quant_signal=quant, + llm_signal=None, + timestamp=now, + ) + + # 两者都有:加权合并 + weighted_sum = ( + quant.direction * quant.confidence + + llm.direction * llm.confidence + ) + final_direction = _sign(weighted_sum) + if final_direction == 0: + logger.info( + "SignalMerger: weighted_sum=0 for %s — signals cancel out, HOLD", + ticker, + ) + total_conf = quant.confidence + llm.confidence + final_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0 + + return FinalSignal( + ticker=ticker, + direction=final_direction, + confidence=final_confidence, + quant_signal=quant, + llm_signal=llm, + timestamp=now, + ) From dacb3316fa5bdc03ca3fc14d39171b9d190a5607 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 21:39:23 +0800 Subject: [PATCH 02/14] fix(orchestrator): code quality fixes in config and signals - config: remove hardcoded absolute path for quant_backtest_path (now empty string) - config: add llm_solo_penalty (0.7) and quant_solo_penalty (0.8) fields - signals: SignalMerger now accepts OrchestratorConfig in __init__ - signals: use config.llm_solo_penalty / quant_solo_penalty instead of magic numbers - signals: apply quant_weight_cap / llm_weight_cap as confidence upper bounds - signals: both-None branch raises ValueError instead of returning ticker="" - signals: replace assert with explicit ValueError for llm-None-when-quant-None - signals: replace datetime.utcnow() with datetime.now(timezone.utc) --- orchestrator/config.py | 5 ++++- orchestrator/signals.py | 32 +++++++++++++++++--------------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/orchestrator/config.py b/orchestrator/config.py index 0beb6c5e..9d3eaea5 100644 --- a/orchestrator/config.py +++ b/orchestrator/config.py @@ -3,9 +3,12 @@ from dataclasses import dataclass, field @dataclass class OrchestratorConfig: - quant_backtest_path: str = "/Users/chenshaojie/Downloads/quant_backtest" + # Must be set to the local quant backtest output directory before use + quant_backtest_path: str = "" trading_agents_config: dict = field(default_factory=dict) quant_weight_cap: float = 0.8 # quant 置信度上限 llm_weight_cap: float = 0.9 # llm 置信度上限 llm_batch_days: int = 7 # LLM 每隔几天运行一次(节省 API) cache_dir: str = "orchestrator/cache" # LLM 信号缓存目录 + llm_solo_penalty: float = 0.7 # LLM 单轨时的置信度折扣 + quant_solo_penalty: float = 0.8 # Quant 单轨时的置信度折扣 diff --git a/orchestrator/signals.py b/orchestrator/signals.py index f9549dcc..1ccecaa3 100644 --- a/orchestrator/signals.py +++ b/orchestrator/signals.py @@ -1,8 +1,10 @@ import logging from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone from typing import Optional +from orchestrator.config import OrchestratorConfig + logger = logging.getLogger(__name__) @@ -36,30 +38,27 @@ def _sign(x: float) -> int: class SignalMerger: + def __init__(self, config: OrchestratorConfig) -> None: + self._config = config + def merge(self, quant: Optional[Signal], llm: Optional[Signal]) -> FinalSignal: - now = datetime.utcnow() + now = datetime.now(timezone.utc) # 两者均失败 if quant is None and llm is None: - ticker = "" - return FinalSignal( - ticker=ticker, - direction=0, - confidence=0.0, - quant_signal=None, - llm_signal=None, - timestamp=now, - ) + raise ValueError("both quant and llm signals are None") ticker = (quant or llm).ticker # type: ignore[union-attr] # 只有 LLM(quant 失败) if quant is None: - assert llm is not None + if llm is None: + raise ValueError("llm signal is None when quant is None") return FinalSignal( ticker=ticker, direction=llm.direction, - confidence=llm.confidence * 0.7, + confidence=min(llm.confidence * self._config.llm_solo_penalty, + self._config.llm_weight_cap), quant_signal=None, llm_signal=llm, timestamp=now, @@ -70,7 +69,8 @@ class SignalMerger: return FinalSignal( ticker=ticker, direction=quant.direction, - confidence=quant.confidence * 0.8, + confidence=min(quant.confidence * self._config.quant_solo_penalty, + self._config.quant_weight_cap), quant_signal=quant, llm_signal=None, timestamp=now, @@ -88,7 +88,9 @@ class SignalMerger: ticker, ) total_conf = quant.confidence + llm.confidence - final_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0 + raw_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0 + final_confidence = min(raw_confidence, self._config.quant_weight_cap, + self._config.llm_weight_cap) return FinalSignal( ticker=ticker, From 7a03c29330ec178bdeceb239687e15734291f967 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 21:44:34 +0800 Subject: [PATCH 03/14] feat(orchestrator): implement QuantRunner with BollingerStrategy signal generation --- orchestrator/quant_runner.py | 164 +++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 orchestrator/quant_runner.py diff --git a/orchestrator/quant_runner.py b/orchestrator/quant_runner.py new file mode 100644 index 00000000..87102b05 --- /dev/null +++ b/orchestrator/quant_runner.py @@ -0,0 +1,164 @@ +import json +import logging +import sqlite3 +import sys +from datetime import datetime, timezone, timedelta +from typing import Any + +import yfinance as yf + +from orchestrator.config import OrchestratorConfig +from orchestrator.signals import Signal + +logger = logging.getLogger(__name__) + + +class QuantRunner: + def __init__(self, config: OrchestratorConfig): + if not config.quant_backtest_path: + raise ValueError("OrchestratorConfig.quant_backtest_path must be set") + self._config = config + path = config.quant_backtest_path + if path not in sys.path: + sys.path.insert(0, path) + + def get_signal(self, ticker: str, date: str) -> Signal: + """ + 获取指定股票在指定日期的量化信号。 + date 格式:'YYYY-MM-DD' + 返回 Signal(source="quant") + """ + result = self._load_best_params(ticker) + params: dict = result["params"] + sharpe: float = result["sharpe_ratio"] + + # 获取 date 前 60 天的历史数据 + end_dt = datetime.strptime(date, "%Y-%m-%d") + start_dt = end_dt - timedelta(days=60) + start_str = start_dt.strftime("%Y-%m-%d") + + df = yf.download(ticker, start=start_str, end=date, progress=False, auto_adjust=True) + if df.empty: + logger.warning("No price data for %s between %s and %s", ticker, start_str, date) + return Signal( + ticker=ticker, + direction=0, + confidence=0.0, + source="quant", + timestamp=datetime.now(timezone.utc), + metadata={"reason": "no_data"}, + ) + + # 标准化列名为小写 + df.columns = [c[0].lower() if isinstance(c, tuple) else c.lower() for c in df.columns] + + # 用最佳参数创建 BollingerStrategy 实例 + from strategies.momentum import BollingerStrategy + from core.data_models import Bar, OrderDirection + + strategy = BollingerStrategy( + period=params.get("period", 20), + num_std=params.get("num_std", 2.0), + position_pct=params.get("position_pct", 0.20), + stop_loss_pct=params.get("stop_loss_pct", 0.05), + take_profit_pct=params.get("take_profit_pct", 0.15), + ) + + # 逐 bar 喂给策略,模拟历史回放 + direction = 0 + context: dict[str, Any] = {"positions": {}} + + for ts, row in df.iterrows(): + bar = Bar( + symbol=ticker, + timestamp=ts.to_pydatetime() if hasattr(ts, "to_pydatetime") else ts, + open=float(row["open"]), + high=float(row["high"]), + low=float(row["low"]), + close=float(row["close"]), + volume=float(row.get("volume", 0)), + ) + orders = strategy.on_bar(bar, context) + # 更新模拟持仓 + for order in orders: + if order.direction == OrderDirection.BUY: + context["positions"][ticker] = order.volume + elif order.direction == OrderDirection.SELL: + context["positions"][ticker] = 0 + + # 最后一个 bar 的信号 + last_orders = orders if df.shape[0] > 0 else [] + for order in last_orders: + if order.direction == OrderDirection.BUY: + direction = 1 + break + elif order.direction == OrderDirection.SELL: + direction = -1 + break + + # 计算 max_sharpe(从 DB 中取全局最大值) + db_path = f"{self._config.quant_backtest_path}/research_results/runs.db" + try: + conn = sqlite3.connect(db_path) + cur = conn.cursor() + cur.execute("SELECT MAX(sharpe_ratio) FROM backtest_results") + row = cur.fetchone() + max_sharpe = float(row[0]) if row and row[0] is not None else sharpe + conn.close() + except Exception: + max_sharpe = sharpe + + confidence = self._calc_confidence(sharpe, max_sharpe) + + return Signal( + ticker=ticker, + direction=direction, + confidence=confidence, + source="quant", + timestamp=datetime.now(timezone.utc), + metadata={"params": params, "sharpe_ratio": sharpe, "max_sharpe": max_sharpe}, + ) + + def _load_best_params(self, ticker: str) -> dict: + """ + 直接查 SQLite 获取 BollingerStrategy 最佳参数。 + strategy_type 支持 'BollingerStrategy' 和 'bollinger'(兼容两种写法)。 + """ + db_path = f"{self._config.quant_backtest_path}/research_results/runs.db" + conn = sqlite3.connect(db_path) + try: + cur = conn.cursor() + # 先按规格查 'BollingerStrategy',再 fallback 到 'bollinger' + cur.execute( + """ + SELECT params, sharpe_ratio + FROM backtest_results + WHERE strategy_type IN ('BollingerStrategy', 'bollinger') + ORDER BY sharpe_ratio DESC + LIMIT 1 + """, + ) + row = cur.fetchone() + finally: + conn.close() + + if row is None: + raise ValueError( + "No BollingerStrategy results found in ResultStore. " + "Run optimization first: python quant_backtest/run_research.py" + ) + + params = json.loads(row[0]) if isinstance(row[0], str) else row[0] + return {"params": params, "sharpe_ratio": float(row[1])} + + def _calc_confidence(self, sharpe: float, max_sharpe: float) -> float: + """ + Sharpe 归一化为置信度。 + - max_sharpe=0 时返回 0.5(默认值,避免除零) + - sharpe/max_sharpe 上限截断到 1.0 + - 下限截断到 0.0(负 Sharpe 不产生负置信度) + """ + if max_sharpe == 0: + return 0.5 + ratio = sharpe / max_sharpe + return max(0.0, min(1.0, ratio)) From 30d8f9046700bb892fc1785257262eeaa45284f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 21:51:38 +0800 Subject: [PATCH 04/14] fix(quant_runner): fix 3 critical issues and 2 important improvements - Critical 1: initialize orders=[] before loop to prevent NameError when df is empty - Critical 2: replace bare sqlite3 conn with context manager (with statement) in get_signal - Critical 3: remove ticker param from _load_best_params (table has no ticker col, params are global) - Important: extract db_path as self._db_path attribute in __init__ (DRY) - Important: add comment explaining lazy imports require sys.path set in __init__ --- orchestrator/quant_runner.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/orchestrator/quant_runner.py b/orchestrator/quant_runner.py index 87102b05..42d2b8b1 100644 --- a/orchestrator/quant_runner.py +++ b/orchestrator/quant_runner.py @@ -21,6 +21,7 @@ class QuantRunner: path = config.quant_backtest_path if path not in sys.path: sys.path.insert(0, path) + self._db_path = f"{path}/research_results/runs.db" def get_signal(self, ticker: str, date: str) -> Signal: """ @@ -28,7 +29,7 @@ class QuantRunner: date 格式:'YYYY-MM-DD' 返回 Signal(source="quant") """ - result = self._load_best_params(ticker) + result = self._load_best_params() params: dict = result["params"] sharpe: float = result["sharpe_ratio"] @@ -53,6 +54,7 @@ class QuantRunner: df.columns = [c[0].lower() if isinstance(c, tuple) else c.lower() for c in df.columns] # 用最佳参数创建 BollingerStrategy 实例 + # Lazy import: requires quant_backtest_path to be in sys.path (set in __init__) from strategies.momentum import BollingerStrategy from core.data_models import Bar, OrderDirection @@ -66,6 +68,7 @@ class QuantRunner: # 逐 bar 喂给策略,模拟历史回放 direction = 0 + orders: list = [] context: dict[str, Any] = {"positions": {}} for ts, row in df.iterrows(): @@ -97,14 +100,12 @@ class QuantRunner: break # 计算 max_sharpe(从 DB 中取全局最大值) - db_path = f"{self._config.quant_backtest_path}/research_results/runs.db" try: - conn = sqlite3.connect(db_path) - cur = conn.cursor() - cur.execute("SELECT MAX(sharpe_ratio) FROM backtest_results") - row = cur.fetchone() - max_sharpe = float(row[0]) if row and row[0] is not None else sharpe - conn.close() + with sqlite3.connect(self._db_path) as conn: + cur = conn.cursor() + cur.execute("SELECT MAX(sharpe_ratio) FROM backtest_results") + row = cur.fetchone() + max_sharpe = float(row[0]) if row and row[0] is not None else sharpe except Exception: max_sharpe = sharpe @@ -119,14 +120,13 @@ class QuantRunner: metadata={"params": params, "sharpe_ratio": sharpe, "max_sharpe": max_sharpe}, ) - def _load_best_params(self, ticker: str) -> dict: + def _load_best_params(self) -> dict: """ 直接查 SQLite 获取 BollingerStrategy 最佳参数。 + 参数是全局最优,不区分股票(backtest_results 表无 ticker 列,优化是全局的)。 strategy_type 支持 'BollingerStrategy' 和 'bollinger'(兼容两种写法)。 """ - db_path = f"{self._config.quant_backtest_path}/research_results/runs.db" - conn = sqlite3.connect(db_path) - try: + with sqlite3.connect(self._db_path) as conn: cur = conn.cursor() # 先按规格查 'BollingerStrategy',再 fallback 到 'bollinger' cur.execute( @@ -139,8 +139,6 @@ class QuantRunner: """, ) row = cur.fetchone() - finally: - conn.close() if row is None: raise ValueError( From 29aae4bb18d6be5fb8ab93f3e9934f75f5387f4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 21:54:48 +0800 Subject: [PATCH 05/14] feat(orchestrator): implement LLMRunner with caching and rating mapping --- orchestrator/llm_runner.py | 88 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 orchestrator/llm_runner.py diff --git a/orchestrator/llm_runner.py b/orchestrator/llm_runner.py new file mode 100644 index 00000000..527586d9 --- /dev/null +++ b/orchestrator/llm_runner.py @@ -0,0 +1,88 @@ +import json +import logging +import os +from datetime import datetime, timezone + +from orchestrator.config import OrchestratorConfig +from orchestrator.signals import Signal + +logger = logging.getLogger(__name__) + + +class LLMRunner: + def __init__(self, config: OrchestratorConfig): + from tradingagents.graph.trading_graph import TradingAgentsGraph + + trading_cfg = config.trading_agents_config if config.trading_agents_config else None + self.graph = TradingAgentsGraph(config=trading_cfg) + self.cache_dir = config.cache_dir + os.makedirs(self.cache_dir, exist_ok=True) + + def get_signal(self, ticker: str, date: str) -> Signal: + """获取指定股票在指定日期的 LLM 信号,带缓存。""" + cache_path = os.path.join(self.cache_dir, f"{ticker}_{date}.json") + + if os.path.exists(cache_path): + logger.info("LLMRunner: cache hit for %s %s", ticker, date) + with open(cache_path, "r", encoding="utf-8") as f: + data = json.load(f) + direction, confidence = self._map_rating(data["rating"]) + return Signal( + ticker=ticker, + direction=direction, + confidence=confidence, + source="llm", + timestamp=datetime.fromisoformat(data["timestamp"]), + metadata=data, + ) + + try: + _final_state, processed_signal = self.graph.propagate(ticker, date) + rating = processed_signal if isinstance(processed_signal, str) else str(processed_signal) + direction, confidence = self._map_rating(rating) + now = datetime.now(timezone.utc) + + cache_data = { + "rating": rating, + "direction": direction, + "confidence": confidence, + "timestamp": now.isoformat(), + "ticker": ticker, + "date": date, + } + with open(cache_path, "w", encoding="utf-8") as f: + json.dump(cache_data, f, ensure_ascii=False, indent=2) + + return Signal( + ticker=ticker, + direction=direction, + confidence=confidence, + source="llm", + timestamp=now, + metadata=cache_data, + ) + except Exception as e: + logger.error("LLMRunner: propagate failed for %s %s: %s", ticker, date, e) + return Signal( + ticker=ticker, + direction=0, + confidence=0.0, + source="llm", + timestamp=datetime.now(timezone.utc), + metadata={"error": str(e)}, + ) + + def _map_rating(self, rating: str) -> tuple[int, float]: + """将 5 级评级映射为 (direction, confidence)。""" + mapping = { + "BUY": (1, 0.9), + "OVERWEIGHT": (1, 0.6), + "HOLD": (0, 0.5), + "UNDERWEIGHT": (-1, 0.6), + "SELL": (-1, 0.9), + } + result = mapping.get(rating.upper() if rating else "", None) + if result is None: + logger.warning("LLMRunner: unknown rating %r, falling back to HOLD", rating) + return (0, 0.5) + return result From 852b6c98e3c369f367ebf9b9fc3fb5531e9a492a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 21:58:38 +0800 Subject: [PATCH 06/14] feat(orchestrator): implement LLMRunner with lazy graph init and JSON cache --- orchestrator/llm_runner.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/orchestrator/llm_runner.py b/orchestrator/llm_runner.py index 527586d9..6f36e6a3 100644 --- a/orchestrator/llm_runner.py +++ b/orchestrator/llm_runner.py @@ -11,13 +11,19 @@ logger = logging.getLogger(__name__) class LLMRunner: def __init__(self, config: OrchestratorConfig): - from tradingagents.graph.trading_graph import TradingAgentsGraph - - trading_cfg = config.trading_agents_config if config.trading_agents_config else None - self.graph = TradingAgentsGraph(config=trading_cfg) + self._config = config + self._graph = None # Lazy-initialized on first get_signal() call (requires API key) self.cache_dir = config.cache_dir os.makedirs(self.cache_dir, exist_ok=True) + def _get_graph(self): + """Lazy-initialize TradingAgentsGraph (heavy, requires API key at init time).""" + if self._graph is None: + from tradingagents.graph.trading_graph import TradingAgentsGraph + trading_cfg = self._config.trading_agents_config if self._config.trading_agents_config else None + self._graph = TradingAgentsGraph(config=trading_cfg) + return self._graph + def get_signal(self, ticker: str, date: str) -> Signal: """获取指定股票在指定日期的 LLM 信号,带缓存。""" cache_path = os.path.join(self.cache_dir, f"{ticker}_{date}.json") @@ -37,7 +43,7 @@ class LLMRunner: ) try: - _final_state, processed_signal = self.graph.propagate(ticker, date) + _final_state, processed_signal = self._get_graph().propagate(ticker, date) rating = processed_signal if isinstance(processed_signal, str) else str(processed_signal) direction, confidence = self._map_rating(rating) now = datetime.now(timezone.utc) From ba3297a696beaf23fa6d1ee0f9caca8fb2f0b05b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 22:03:17 +0800 Subject: [PATCH 07/14] fix(llm_runner): use stored direction/confidence on cache hit, sanitize ticker path --- orchestrator/llm_runner.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/orchestrator/llm_runner.py b/orchestrator/llm_runner.py index 6f36e6a3..8dcb3c46 100644 --- a/orchestrator/llm_runner.py +++ b/orchestrator/llm_runner.py @@ -26,17 +26,18 @@ class LLMRunner: def get_signal(self, ticker: str, date: str) -> Signal: """获取指定股票在指定日期的 LLM 信号,带缓存。""" - cache_path = os.path.join(self.cache_dir, f"{ticker}_{date}.json") + safe_ticker = ticker.replace("/", "_") # sanitize for filesystem (e.g. BRK/B) + cache_path = os.path.join(self.cache_dir, f"{safe_ticker}_{date}.json") if os.path.exists(cache_path): logger.info("LLMRunner: cache hit for %s %s", ticker, date) with open(cache_path, "r", encoding="utf-8") as f: data = json.load(f) - direction, confidence = self._map_rating(data["rating"]) + # Use stored direction/confidence directly to avoid re-mapping drift return Signal( ticker=ticker, - direction=direction, - confidence=confidence, + direction=data["direction"], + confidence=data["confidence"], source="llm", timestamp=datetime.fromisoformat(data["timestamp"]), metadata=data, From 14191abc297db05ff6c36a3ce6f0031d99d406e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 22:05:03 +0800 Subject: [PATCH 08/14] feat(orchestrator): TradingOrchestrator main class with get_combined_signal --- orchestrator/orchestrator.py | 63 ++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 orchestrator/orchestrator.py diff --git a/orchestrator/orchestrator.py b/orchestrator/orchestrator.py new file mode 100644 index 00000000..baf042eb --- /dev/null +++ b/orchestrator/orchestrator.py @@ -0,0 +1,63 @@ +import logging +from datetime import datetime, timezone +from typing import Optional + +from orchestrator.config import OrchestratorConfig +from orchestrator.signals import Signal, FinalSignal, SignalMerger +from orchestrator.quant_runner import QuantRunner +from orchestrator.llm_runner import LLMRunner + +logger = logging.getLogger(__name__) + + +class TradingOrchestrator: + def __init__(self, config: OrchestratorConfig): + self._config = config + self._merger = SignalMerger(config) + self._quant: Optional[QuantRunner] = None + self._llm: Optional[LLMRunner] = None + + # Initialize runners (quant requires quant_backtest_path) + if config.quant_backtest_path: + try: + self._quant = QuantRunner(config) + except Exception as e: + logger.warning("TradingOrchestrator: QuantRunner init failed: %s", e) + + self._llm = LLMRunner(config) + + def get_combined_signal(self, ticker: str, date: str) -> FinalSignal: + """ + Get merged signal for ticker on date. + Degradation: + - quant fails (error signal): use llm only with llm_solo_penalty + - llm fails (error signal): use quant only with quant_solo_penalty + - both fail: raises ValueError + """ + quant_sig: Optional[Signal] = None + llm_sig: Optional[Signal] = None + + # Get quant signal + if self._quant is not None: + try: + quant_sig = self._quant.get_signal(ticker, date) + # Treat error signals (confidence=0, direction=0 with error metadata) as None + if quant_sig.metadata.get("error") or quant_sig.metadata.get("reason") == "no_data": + logger.warning("TradingOrchestrator: quant signal degraded for %s %s", ticker, date) + quant_sig = None + except Exception as e: + logger.error("TradingOrchestrator: quant get_signal failed: %s", e) + quant_sig = None + + # Get llm signal + try: + llm_sig = self._llm.get_signal(ticker, date) + if llm_sig.metadata.get("error"): + logger.warning("TradingOrchestrator: llm signal degraded for %s %s", ticker, date) + llm_sig = None + except Exception as e: + logger.error("TradingOrchestrator: llm get_signal failed: %s", e) + llm_sig = None + + # merge raises ValueError if both None + return self._merger.merge(quant_sig, llm_sig) From 928f0691849ea61b1a2cffb78a806f2c0a94acab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 22:07:21 +0800 Subject: [PATCH 09/14] test(orchestrator): unit tests for SignalMerger, LLMRunner._map_rating, QuantRunner._calc_confidence --- orchestrator/tests/__init__.py | 0 orchestrator/tests/test_llm_runner.py | 41 +++++++++ orchestrator/tests/test_quant_runner.py | 65 +++++++++++++ orchestrator/tests/test_signals.py | 117 ++++++++++++++++++++++++ 4 files changed, 223 insertions(+) create mode 100644 orchestrator/tests/__init__.py create mode 100644 orchestrator/tests/test_llm_runner.py create mode 100644 orchestrator/tests/test_quant_runner.py create mode 100644 orchestrator/tests/test_signals.py diff --git a/orchestrator/tests/__init__.py b/orchestrator/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestrator/tests/test_llm_runner.py b/orchestrator/tests/test_llm_runner.py new file mode 100644 index 00000000..a4b7bbeb --- /dev/null +++ b/orchestrator/tests/test_llm_runner.py @@ -0,0 +1,41 @@ +"""Tests for LLMRunner._map_rating().""" +import tempfile +import pytest + +from orchestrator.config import OrchestratorConfig +from orchestrator.llm_runner import LLMRunner + + +@pytest.fixture +def runner(tmp_path): + cfg = OrchestratorConfig(cache_dir=str(tmp_path)) + return LLMRunner(cfg) + + +# All 5 known ratings +@pytest.mark.parametrize("rating,expected", [ + ("BUY", (1, 0.9)), + ("OVERWEIGHT", (1, 0.6)), + ("HOLD", (0, 0.5)), + ("UNDERWEIGHT", (-1, 0.6)), + ("SELL", (-1, 0.9)), +]) +def test_map_rating_known(runner, rating, expected): + assert runner._map_rating(rating) == expected + + +# Unknown rating → (0, 0.5) +def test_map_rating_unknown(runner): + assert runner._map_rating("STRONG_BUY") == (0, 0.5) + + +# Case-insensitive +def test_map_rating_lowercase(runner): + assert runner._map_rating("buy") == (1, 0.9) + assert runner._map_rating("sell") == (-1, 0.9) + assert runner._map_rating("hold") == (0, 0.5) + + +# Empty string → (0, 0.5) +def test_map_rating_empty_string(runner): + assert runner._map_rating("") == (0, 0.5) diff --git a/orchestrator/tests/test_quant_runner.py b/orchestrator/tests/test_quant_runner.py new file mode 100644 index 00000000..73b95da5 --- /dev/null +++ b/orchestrator/tests/test_quant_runner.py @@ -0,0 +1,65 @@ +"""Tests for QuantRunner._calc_confidence().""" +import json +import sqlite3 +import tempfile +import os +import pytest + +from orchestrator.config import OrchestratorConfig +from orchestrator.quant_runner import QuantRunner + + +def _make_runner(tmp_path): + """Create a QuantRunner with a minimal SQLite DB so __init__ succeeds.""" + db_dir = tmp_path / "research_results" + db_dir.mkdir(parents=True) + db_path = db_dir / "runs.db" + + with sqlite3.connect(str(db_path)) as conn: + conn.execute( + """CREATE TABLE backtest_results ( + id INTEGER PRIMARY KEY, + strategy_type TEXT, + params TEXT, + sharpe_ratio REAL + )""" + ) + conn.execute( + "INSERT INTO backtest_results (strategy_type, params, sharpe_ratio) VALUES (?, ?, ?)", + ("BollingerStrategy", json.dumps({"period": 20, "num_std": 2.0, + "position_pct": 0.2, + "stop_loss_pct": 0.05, + "take_profit_pct": 0.15}), 1.5), + ) + + cfg = OrchestratorConfig(quant_backtest_path=str(tmp_path)) + return QuantRunner(cfg) + + +@pytest.fixture +def runner(tmp_path): + return _make_runner(tmp_path) + + +def test_calc_confidence_max_sharpe_zero(runner): + assert runner._calc_confidence(1.0, 0) == 0.5 + + +def test_calc_confidence_half(runner): + result = runner._calc_confidence(1.0, 2.0) + assert result == pytest.approx(0.5) + + +def test_calc_confidence_full(runner): + result = runner._calc_confidence(2.0, 2.0) + assert result == pytest.approx(1.0) + + +def test_calc_confidence_clamped_above(runner): + result = runner._calc_confidence(3.0, 2.0) + assert result == pytest.approx(1.0) + + +def test_calc_confidence_clamped_below(runner): + result = runner._calc_confidence(-1.0, 2.0) + assert result == pytest.approx(0.0) diff --git a/orchestrator/tests/test_signals.py b/orchestrator/tests/test_signals.py new file mode 100644 index 00000000..9e8ebfd8 --- /dev/null +++ b/orchestrator/tests/test_signals.py @@ -0,0 +1,117 @@ +"""Tests for SignalMerger in orchestrator/signals.py.""" +import math +import pytest +from datetime import datetime, timezone + +from orchestrator.config import OrchestratorConfig +from orchestrator.signals import Signal, SignalMerger + + +def _make_signal(ticker="AAPL", direction=1, confidence=0.8, source="quant"): + return Signal( + ticker=ticker, + direction=direction, + confidence=confidence, + source=source, + timestamp=datetime.now(timezone.utc), + ) + + +@pytest.fixture +def merger(): + return SignalMerger(OrchestratorConfig()) + + +# Branch 1: both None → ValueError +def test_merge_both_none_raises(merger): + with pytest.raises(ValueError): + merger.merge(None, None) + + +# Branch 2: quant only +def test_merge_quant_only(merger): + cfg = OrchestratorConfig() + q = _make_signal(direction=1, confidence=0.8, source="quant") + result = merger.merge(q, None) + assert result.direction == 1 + expected_conf = min(0.8 * cfg.quant_solo_penalty, cfg.quant_weight_cap) + assert math.isclose(result.confidence, expected_conf) + assert result.quant_signal is q + assert result.llm_signal is None + + +def test_merge_quant_only_capped(merger): + cfg = OrchestratorConfig() + # confidence=1.0 * quant_solo_penalty=0.8 → 0.8 == quant_weight_cap=0.8, no clamp needed + q = _make_signal(direction=-1, confidence=1.0, source="quant") + result = merger.merge(q, None) + expected_conf = min(1.0 * cfg.quant_solo_penalty, cfg.quant_weight_cap) + assert math.isclose(result.confidence, expected_conf) + assert result.direction == -1 + + +# Branch 3: llm only +def test_merge_llm_only(merger): + cfg = OrchestratorConfig() + l = _make_signal(direction=-1, confidence=0.9, source="llm") + result = merger.merge(None, l) + assert result.direction == -1 + expected_conf = min(0.9 * cfg.llm_solo_penalty, cfg.llm_weight_cap) + assert math.isclose(result.confidence, expected_conf) + assert result.llm_signal is l + assert result.quant_signal is None + + +def test_merge_llm_only_capped(merger): + cfg = OrchestratorConfig() + # Force cap: confidence=1.0, llm_solo_penalty=0.7 → 0.7 < llm_weight_cap=0.9, no cap + l = _make_signal(direction=1, confidence=1.0, source="llm") + result = merger.merge(None, l) + expected_conf = min(1.0 * cfg.llm_solo_penalty, cfg.llm_weight_cap) + assert math.isclose(result.confidence, expected_conf) + + +# Branch 4: both present, same direction +def test_merge_both_same_direction(merger): + cfg = OrchestratorConfig() + q = _make_signal(direction=1, confidence=0.6, source="quant") + l = _make_signal(direction=1, confidence=0.8, source="llm") + result = merger.merge(q, l) + assert result.direction == 1 + weighted_sum = 1 * 0.6 + 1 * 0.8 # 1.4 + total_conf = 0.6 + 0.8 # 1.4 + raw_conf = abs(weighted_sum) / total_conf # 1.0 + # actual code caps at min(raw, quant_weight_cap, llm_weight_cap) + expected_conf = min(raw_conf, cfg.quant_weight_cap, cfg.llm_weight_cap) + assert math.isclose(result.confidence, expected_conf) + + +# Branch 5: both present, opposite direction +def test_merge_both_opposite_direction_quant_wins(merger): + cfg = OrchestratorConfig() + # quant stronger: direction should be quant's + q = _make_signal(direction=1, confidence=0.9, source="quant") + l = _make_signal(direction=-1, confidence=0.3, source="llm") + result = merger.merge(q, l) + weighted_sum = 1 * 0.9 + (-1) * 0.3 # 0.6 + assert result.direction == 1 + total_conf = 0.9 + 0.3 + raw_conf = abs(weighted_sum) / total_conf + expected_conf = min(raw_conf, cfg.quant_weight_cap, cfg.llm_weight_cap) + assert math.isclose(result.confidence, expected_conf) + + +def test_merge_both_opposite_direction_llm_wins(merger): + q = _make_signal(direction=1, confidence=0.2, source="quant") + l = _make_signal(direction=-1, confidence=0.8, source="llm") + result = merger.merge(q, l) + assert result.direction == -1 + + +# weighted_sum=0 → direction=HOLD +def test_merge_weighted_sum_zero(merger): + q = _make_signal(direction=1, confidence=0.5, source="quant") + l = _make_signal(direction=-1, confidence=0.5, source="llm") + result = merger.merge(q, l) + assert result.direction == 0 + assert math.isclose(result.confidence, 0.0) From 724c447720127d550fc1698fea84418fd29da920 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 22:09:38 +0800 Subject: [PATCH 10/14] feat(orchestrator): BacktestMode for historical signal collection --- orchestrator/backtest_mode.py | 65 +++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 orchestrator/backtest_mode.py diff --git a/orchestrator/backtest_mode.py b/orchestrator/backtest_mode.py new file mode 100644 index 00000000..a0e2488e --- /dev/null +++ b/orchestrator/backtest_mode.py @@ -0,0 +1,65 @@ +import logging +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import List, Optional + +from orchestrator.config import OrchestratorConfig +from orchestrator.signals import FinalSignal + +logger = logging.getLogger(__name__) + + +@dataclass +class BacktestResult: + records: List[dict] = field(default_factory=list) + summary: dict = field(default_factory=dict) + + +class BacktestMode: + def __init__(self, orchestrator): + self._orchestrator = orchestrator + + def run(self, tickers: List[str], start_date: str, end_date: str) -> BacktestResult: + start = datetime.strptime(start_date, "%Y-%m-%d") + end = datetime.strptime(end_date, "%Y-%m-%d") + + records = [] + current = start + while current <= end: + if current.weekday() < 5: # skip weekends + date_str = current.strftime("%Y-%m-%d") + for ticker in tickers: + try: + sig = self._orchestrator.get_combined_signal(ticker, date_str) + records.append({ + "ticker": ticker, + "date": date_str, + "direction": sig.direction, + "confidence": sig.confidence, + "quant_direction": sig.quant_signal.direction if sig.quant_signal else None, + "llm_direction": sig.llm_signal.direction if sig.llm_signal else None, + }) + except Exception as e: + logger.error("BacktestMode: failed for %s %s: %s", ticker, date_str, e) + current += timedelta(days=1) + + summary = self._compute_summary(records, tickers) + return BacktestResult(records=records, summary=summary) + + def _compute_summary(self, records: List[dict], tickers: List[str]) -> dict: + summary = {} + for ticker in tickers: + ticker_records = [r for r in records if r["ticker"] == ticker] + if not ticker_records: + summary[ticker] = {"total_days": 0} + continue + directions = [r["direction"] for r in ticker_records] + confidences = [r["confidence"] for r in ticker_records] + summary[ticker] = { + "total_days": len(ticker_records), + "buy_days": directions.count(1), + "sell_days": directions.count(-1), + "hold_days": directions.count(0), + "avg_confidence": sum(confidences) / len(confidences), + } + return summary From 480f0299b050f078283bd13abd697afbb7a3a76b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 22:10:15 +0800 Subject: [PATCH 11/14] feat(orchestrator): LiveMode + /ws/orchestrator WebSocket endpoint --- orchestrator/live_mode.py | 47 +++++++++++++++++++++++++++++++++++ web_dashboard/backend/main.py | 34 +++++++++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 orchestrator/live_mode.py diff --git a/orchestrator/live_mode.py b/orchestrator/live_mode.py new file mode 100644 index 00000000..b96b5e04 --- /dev/null +++ b/orchestrator/live_mode.py @@ -0,0 +1,47 @@ +import asyncio +import json +import logging +from datetime import datetime, timezone +from typing import List, Optional + +logger = logging.getLogger(__name__) + + +class LiveMode: + """ + Triggers signal computation for a list of tickers and broadcasts + results via a callback (e.g., WebSocket send). + """ + + def __init__(self, orchestrator): + self._orchestrator = orchestrator + + async def run_once(self, tickers: List[str], date: Optional[str] = None) -> List[dict]: + """ + Compute combined signals for all tickers on the given date (default: today). + Returns list of signal dicts. + """ + if date is None: + date = datetime.now(timezone.utc).strftime("%Y-%m-%d") + + results = [] + for ticker in tickers: + try: + sig = self._orchestrator.get_combined_signal(ticker, date) + results.append({ + "ticker": ticker, + "date": date, + "direction": sig.direction, + "confidence": sig.confidence, + "quant_direction": sig.quant_signal.direction if sig.quant_signal else None, + "llm_direction": sig.llm_signal.direction if sig.llm_signal else None, + "timestamp": sig.timestamp.isoformat(), + }) + except Exception as e: + logger.error("LiveMode: failed for %s %s: %s", ticker, date, e) + results.append({ + "ticker": ticker, + "date": date, + "error": str(e), + }) + return results diff --git a/web_dashboard/backend/main.py b/web_dashboard/backend/main.py index 05c70daa..229b2852 100644 --- a/web_dashboard/backend/main.py +++ b/web_dashboard/backend/main.py @@ -1100,6 +1100,40 @@ async def root(): return {"message": "TradingAgents Web Dashboard API", "version": "0.1.0"} +@app.websocket("/ws/orchestrator") +async def ws_orchestrator(websocket: WebSocket): + """WebSocket endpoint for orchestrator live signals.""" + await websocket.accept() + try: + while True: + data = await websocket.receive_text() + payload = json.loads(data) + tickers = payload.get("tickers", []) + date = payload.get("date") + + # Lazy import to avoid loading heavy deps at startup + import sys + sys.path.insert(0, str(REPO_ROOT)) + from orchestrator.config import OrchestratorConfig + from orchestrator.orchestrator import TradingOrchestrator + from orchestrator.live_mode import LiveMode + + config = OrchestratorConfig( + quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""), + ) + orchestrator = TradingOrchestrator(config) + live = LiveMode(orchestrator) + results = await live.run_once(tickers, date) + await websocket.send_text(json.dumps({"signals": results})) + except WebSocketDisconnect: + pass + except Exception as e: + try: + await websocket.send_text(json.dumps({"error": str(e)})) + except Exception: + pass + + if __name__ == "__main__": import uvicorn # Run with: cd web_dashboard && ../env312/bin/python -m uvicorn main:app --reload From ce2e6d32cc6072d1bc4134a9ef9fdc1c2e594fbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 22:12:02 +0800 Subject: [PATCH 12/14] feat(orchestrator): example scripts for backtest and live mode --- orchestrator/examples/__init__.py | 0 orchestrator/examples/run_backtest.py | 41 +++++++++++++++++++++++++++ orchestrator/examples/run_live.py | 41 +++++++++++++++++++++++++++ 3 files changed, 82 insertions(+) create mode 100644 orchestrator/examples/__init__.py create mode 100644 orchestrator/examples/run_backtest.py create mode 100644 orchestrator/examples/run_live.py diff --git a/orchestrator/examples/__init__.py b/orchestrator/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestrator/examples/run_backtest.py b/orchestrator/examples/run_backtest.py new file mode 100644 index 00000000..f7f5bc49 --- /dev/null +++ b/orchestrator/examples/run_backtest.py @@ -0,0 +1,41 @@ +""" +Example: Run orchestrator backtest for 宁德时代 (300750.SZ) over 2023. + +Usage: + cd /path/to/TradingAgents + QUANT_BACKTEST_PATH=/path/to/quant_backtest python orchestrator/examples/run_backtest.py +""" +import json +import logging +import os +import sys + +# Add repo root to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +from orchestrator.config import OrchestratorConfig +from orchestrator.orchestrator import TradingOrchestrator +from orchestrator.backtest_mode import BacktestMode + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") + +config = OrchestratorConfig( + quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""), + cache_dir="orchestrator/cache", +) + +orchestrator = TradingOrchestrator(config) +backtest = BacktestMode(orchestrator) + +result = backtest.run( + tickers=["300750.SZ"], + start_date="2023-01-01", + end_date="2023-12-31", +) + +print(f"\n=== Backtest Summary ===") +print(json.dumps(result.summary, indent=2, ensure_ascii=False)) +print(f"\nTotal records: {len(result.records)}") +if result.records: + print(f"First record: {result.records[0]}") + print(f"Last record: {result.records[-1]}") diff --git a/orchestrator/examples/run_live.py b/orchestrator/examples/run_live.py new file mode 100644 index 00000000..4a652bfa --- /dev/null +++ b/orchestrator/examples/run_live.py @@ -0,0 +1,41 @@ +""" +Example: Run orchestrator live mode for a list of tickers. + +Usage: + cd /path/to/TradingAgents + QUANT_BACKTEST_PATH=/path/to/quant_backtest python orchestrator/examples/run_live.py +""" +import asyncio +import json +import logging +import os +import sys +from datetime import datetime, timezone + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +from orchestrator.config import OrchestratorConfig +from orchestrator.orchestrator import TradingOrchestrator +from orchestrator.live_mode import LiveMode + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") + +TICKERS = ["300750.SZ", "603259.SS"] + +config = OrchestratorConfig( + quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""), + cache_dir="orchestrator/cache", +) + +orchestrator = TradingOrchestrator(config) +live = LiveMode(orchestrator) + + +async def main(): + today = datetime.now(timezone.utc).strftime("%Y-%m-%d") + print(f"\n=== Live Signals for {today} ===") + results = await live.run_once(TICKERS, date=today) + print(json.dumps(results, indent=2, ensure_ascii=False)) + + +asyncio.run(main()) From 28a95f34a77545f579ec67bbebec7114d701a74f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 22:55:36 +0800 Subject: [PATCH 13/14] =?UTF-8?q?fix(review):=20api=5Fkey=E2=86=92anthropi?= =?UTF-8?q?c=5Fkey=20bug,=20sync-in-async=20event=20loop=20block,=20orches?= =?UTF-8?q?trator=20per-message=20re-init,=20dead=20code=20cleanup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- orchestrator/backtest_mode.py | 3 +-- orchestrator/live_mode.py | 5 +++-- orchestrator/signals.py | 2 -- web_dashboard/backend/main.py | 26 +++++++++++++------------- 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/orchestrator/backtest_mode.py b/orchestrator/backtest_mode.py index a0e2488e..604b81d2 100644 --- a/orchestrator/backtest_mode.py +++ b/orchestrator/backtest_mode.py @@ -1,9 +1,8 @@ import logging from dataclasses import dataclass, field from datetime import datetime, timedelta -from typing import List, Optional +from typing import List -from orchestrator.config import OrchestratorConfig from orchestrator.signals import FinalSignal logger = logging.getLogger(__name__) diff --git a/orchestrator/live_mode.py b/orchestrator/live_mode.py index b96b5e04..76c04c51 100644 --- a/orchestrator/live_mode.py +++ b/orchestrator/live_mode.py @@ -1,5 +1,4 @@ import asyncio -import json import logging from datetime import datetime, timezone from typing import List, Optional @@ -27,7 +26,9 @@ class LiveMode: results = [] for ticker in tickers: try: - sig = self._orchestrator.get_combined_signal(ticker, date) + sig = await asyncio.to_thread( + self._orchestrator.get_combined_signal, ticker, date + ) results.append({ "ticker": ticker, "date": date, diff --git a/orchestrator/signals.py b/orchestrator/signals.py index 1ccecaa3..0715409d 100644 --- a/orchestrator/signals.py +++ b/orchestrator/signals.py @@ -52,8 +52,6 @@ class SignalMerger: # 只有 LLM(quant 失败) if quant is None: - if llm is None: - raise ValueError("llm signal is None when quant is None") return FinalSignal( ticker=ticker, direction=llm.direction, diff --git a/web_dashboard/backend/main.py b/web_dashboard/backend/main.py index 229b2852..e27d2671 100644 --- a/web_dashboard/backend/main.py +++ b/web_dashboard/backend/main.py @@ -363,7 +363,7 @@ async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Head # Use clean environment - don't inherit parent env clean_env = {k: v for k, v in os.environ.items() if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))} - clean_env["ANTHROPIC_API_KEY"] = api_key + clean_env["ANTHROPIC_API_KEY"] = anthropic_key clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic" proc = await asyncio.create_subprocess_exec( @@ -1103,6 +1103,18 @@ async def root(): @app.websocket("/ws/orchestrator") async def ws_orchestrator(websocket: WebSocket): """WebSocket endpoint for orchestrator live signals.""" + import sys + sys.path.insert(0, str(REPO_ROOT)) + from orchestrator.config import OrchestratorConfig + from orchestrator.orchestrator import TradingOrchestrator + from orchestrator.live_mode import LiveMode + + config = OrchestratorConfig( + quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""), + ) + orchestrator = TradingOrchestrator(config) + live = LiveMode(orchestrator) + await websocket.accept() try: while True: @@ -1111,18 +1123,6 @@ async def ws_orchestrator(websocket: WebSocket): tickers = payload.get("tickers", []) date = payload.get("date") - # Lazy import to avoid loading heavy deps at startup - import sys - sys.path.insert(0, str(REPO_ROOT)) - from orchestrator.config import OrchestratorConfig - from orchestrator.orchestrator import TradingOrchestrator - from orchestrator.live_mode import LiveMode - - config = OrchestratorConfig( - quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""), - ) - orchestrator = TradingOrchestrator(config) - live = LiveMode(orchestrator) results = await live.run_once(tickers, date) await websocket.send_text(json.dumps({"signals": results})) except WebSocketDisconnect: From b50e5b47253fa9bcccc335b85cf008a63c4802c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 9 Apr 2026 23:00:20 +0800 Subject: [PATCH 14/14] fix(review): hmac.compare_digest for API key, ws/orchestrator auth, SignalMerger per-signal cap logic --- orchestrator/signals.py | 13 +++++++------ orchestrator/tests/test_signals.py | 21 ++++++++++++--------- web_dashboard/backend/main.py | 12 ++++++++++-- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/orchestrator/signals.py b/orchestrator/signals.py index 0715409d..7283c725 100644 --- a/orchestrator/signals.py +++ b/orchestrator/signals.py @@ -75,9 +75,12 @@ class SignalMerger: ) # 两者都有:加权合并 + # Cap each signal's contribution before merging + quant_conf = min(quant.confidence, self._config.quant_weight_cap) + llm_conf = min(llm.confidence, self._config.llm_weight_cap) weighted_sum = ( - quant.direction * quant.confidence - + llm.direction * llm.confidence + quant.direction * quant_conf + + llm.direction * llm_conf ) final_direction = _sign(weighted_sum) if final_direction == 0: @@ -85,10 +88,8 @@ class SignalMerger: "SignalMerger: weighted_sum=0 for %s — signals cancel out, HOLD", ticker, ) - total_conf = quant.confidence + llm.confidence - raw_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0 - final_confidence = min(raw_confidence, self._config.quant_weight_cap, - self._config.llm_weight_cap) + total_conf = quant_conf + llm_conf + final_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0 return FinalSignal( ticker=ticker, diff --git a/orchestrator/tests/test_signals.py b/orchestrator/tests/test_signals.py index 9e8ebfd8..bbd5b2aa 100644 --- a/orchestrator/tests/test_signals.py +++ b/orchestrator/tests/test_signals.py @@ -78,11 +78,12 @@ def test_merge_both_same_direction(merger): l = _make_signal(direction=1, confidence=0.8, source="llm") result = merger.merge(q, l) assert result.direction == 1 - weighted_sum = 1 * 0.6 + 1 * 0.8 # 1.4 - total_conf = 0.6 + 0.8 # 1.4 - raw_conf = abs(weighted_sum) / total_conf # 1.0 - # actual code caps at min(raw, quant_weight_cap, llm_weight_cap) - expected_conf = min(raw_conf, cfg.quant_weight_cap, cfg.llm_weight_cap) + # caps applied per-signal before merging + quant_conf = min(0.6, cfg.quant_weight_cap) # 0.6 + llm_conf = min(0.8, cfg.llm_weight_cap) # 0.8 + weighted_sum = 1 * quant_conf + 1 * llm_conf # 1.4 + total_conf = quant_conf + llm_conf # 1.4 + expected_conf = abs(weighted_sum) / total_conf # 1.0 assert math.isclose(result.confidence, expected_conf) @@ -93,11 +94,13 @@ def test_merge_both_opposite_direction_quant_wins(merger): q = _make_signal(direction=1, confidence=0.9, source="quant") l = _make_signal(direction=-1, confidence=0.3, source="llm") result = merger.merge(q, l) - weighted_sum = 1 * 0.9 + (-1) * 0.3 # 0.6 assert result.direction == 1 - total_conf = 0.9 + 0.3 - raw_conf = abs(weighted_sum) / total_conf - expected_conf = min(raw_conf, cfg.quant_weight_cap, cfg.llm_weight_cap) + # caps applied per-signal before merging + quant_conf = min(0.9, cfg.quant_weight_cap) # 0.8 + llm_conf = min(0.3, cfg.llm_weight_cap) # 0.3 + weighted_sum = 1 * quant_conf + (-1) * llm_conf # 0.5 + total_conf = quant_conf + llm_conf # 1.1 + expected_conf = abs(weighted_sum) / total_conf assert math.isclose(result.confidence, expected_conf) diff --git a/web_dashboard/backend/main.py b/web_dashboard/backend/main.py index e27d2671..2e859540 100644 --- a/web_dashboard/backend/main.py +++ b/web_dashboard/backend/main.py @@ -4,6 +4,7 @@ FastAPI REST API + WebSocket for real-time analysis progress """ import asyncio import fcntl +import hmac import json import os import subprocess @@ -105,7 +106,9 @@ def _check_api_key(api_key: Optional[str]) -> bool: required = _get_api_key() if not required: return True - return api_key == required + if not api_key: + return False + return hmac.compare_digest(api_key, required) def _auth_error(): raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required") @@ -1101,8 +1104,13 @@ async def root(): @app.websocket("/ws/orchestrator") -async def ws_orchestrator(websocket: WebSocket): +async def ws_orchestrator(websocket: WebSocket, api_key: Optional[str] = None): """WebSocket endpoint for orchestrator live signals.""" + # Auth check before accepting — reject unauthenticated connections + if not _check_api_key(api_key): + await websocket.close(code=4401) + return + import sys sys.path.insert(0, str(REPO_ROOT)) from orchestrator.config import OrchestratorConfig