feat(orchestrator): merge orchestrator module into main
This commit is contained in:
commit
8960fdf321
|
|
@ -0,0 +1,64 @@
|
|||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
from orchestrator.signals import FinalSignal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BacktestResult:
|
||||
records: List[dict] = field(default_factory=list)
|
||||
summary: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class BacktestMode:
|
||||
def __init__(self, orchestrator):
|
||||
self._orchestrator = orchestrator
|
||||
|
||||
def run(self, tickers: List[str], start_date: str, end_date: str) -> BacktestResult:
|
||||
start = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
|
||||
records = []
|
||||
current = start
|
||||
while current <= end:
|
||||
if current.weekday() < 5: # skip weekends
|
||||
date_str = current.strftime("%Y-%m-%d")
|
||||
for ticker in tickers:
|
||||
try:
|
||||
sig = self._orchestrator.get_combined_signal(ticker, date_str)
|
||||
records.append({
|
||||
"ticker": ticker,
|
||||
"date": date_str,
|
||||
"direction": sig.direction,
|
||||
"confidence": sig.confidence,
|
||||
"quant_direction": sig.quant_signal.direction if sig.quant_signal else None,
|
||||
"llm_direction": sig.llm_signal.direction if sig.llm_signal else None,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("BacktestMode: failed for %s %s: %s", ticker, date_str, e)
|
||||
current += timedelta(days=1)
|
||||
|
||||
summary = self._compute_summary(records, tickers)
|
||||
return BacktestResult(records=records, summary=summary)
|
||||
|
||||
def _compute_summary(self, records: List[dict], tickers: List[str]) -> dict:
|
||||
summary = {}
|
||||
for ticker in tickers:
|
||||
ticker_records = [r for r in records if r["ticker"] == ticker]
|
||||
if not ticker_records:
|
||||
summary[ticker] = {"total_days": 0}
|
||||
continue
|
||||
directions = [r["direction"] for r in ticker_records]
|
||||
confidences = [r["confidence"] for r in ticker_records]
|
||||
summary[ticker] = {
|
||||
"total_days": len(ticker_records),
|
||||
"buy_days": directions.count(1),
|
||||
"sell_days": directions.count(-1),
|
||||
"hold_days": directions.count(0),
|
||||
"avg_confidence": sum(confidences) / len(confidences),
|
||||
}
|
||||
return summary
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrchestratorConfig:
|
||||
# Must be set to the local quant backtest output directory before use
|
||||
quant_backtest_path: str = ""
|
||||
trading_agents_config: dict = field(default_factory=dict)
|
||||
quant_weight_cap: float = 0.8 # quant 置信度上限
|
||||
llm_weight_cap: float = 0.9 # llm 置信度上限
|
||||
llm_batch_days: int = 7 # LLM 每隔几天运行一次(节省 API)
|
||||
cache_dir: str = "orchestrator/cache" # LLM 信号缓存目录
|
||||
llm_solo_penalty: float = 0.7 # LLM 单轨时的置信度折扣
|
||||
quant_solo_penalty: float = 0.8 # Quant 单轨时的置信度折扣
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
"""
|
||||
Example: Run orchestrator backtest for 宁德时代 (300750.SZ) over 2023.
|
||||
|
||||
Usage:
|
||||
cd /path/to/TradingAgents
|
||||
QUANT_BACKTEST_PATH=/path/to/quant_backtest python orchestrator/examples/run_backtest.py
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add repo root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
from orchestrator.orchestrator import TradingOrchestrator
|
||||
from orchestrator.backtest_mode import BacktestMode
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
|
||||
config = OrchestratorConfig(
|
||||
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
|
||||
cache_dir="orchestrator/cache",
|
||||
)
|
||||
|
||||
orchestrator = TradingOrchestrator(config)
|
||||
backtest = BacktestMode(orchestrator)
|
||||
|
||||
result = backtest.run(
|
||||
tickers=["300750.SZ"],
|
||||
start_date="2023-01-01",
|
||||
end_date="2023-12-31",
|
||||
)
|
||||
|
||||
print(f"\n=== Backtest Summary ===")
|
||||
print(json.dumps(result.summary, indent=2, ensure_ascii=False))
|
||||
print(f"\nTotal records: {len(result.records)}")
|
||||
if result.records:
|
||||
print(f"First record: {result.records[0]}")
|
||||
print(f"Last record: {result.records[-1]}")
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
"""
|
||||
Example: Run orchestrator live mode for a list of tickers.
|
||||
|
||||
Usage:
|
||||
cd /path/to/TradingAgents
|
||||
QUANT_BACKTEST_PATH=/path/to/quant_backtest python orchestrator/examples/run_live.py
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
from orchestrator.orchestrator import TradingOrchestrator
|
||||
from orchestrator.live_mode import LiveMode
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
|
||||
TICKERS = ["300750.SZ", "603259.SS"]
|
||||
|
||||
config = OrchestratorConfig(
|
||||
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
|
||||
cache_dir="orchestrator/cache",
|
||||
)
|
||||
|
||||
orchestrator = TradingOrchestrator(config)
|
||||
live = LiveMode(orchestrator)
|
||||
|
||||
|
||||
async def main():
|
||||
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
print(f"\n=== Live Signals for {today} ===")
|
||||
results = await live.run_once(TICKERS, date=today)
|
||||
print(json.dumps(results, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LiveMode:
|
||||
"""
|
||||
Triggers signal computation for a list of tickers and broadcasts
|
||||
results via a callback (e.g., WebSocket send).
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator):
|
||||
self._orchestrator = orchestrator
|
||||
|
||||
async def run_once(self, tickers: List[str], date: Optional[str] = None) -> List[dict]:
|
||||
"""
|
||||
Compute combined signals for all tickers on the given date (default: today).
|
||||
Returns list of signal dicts.
|
||||
"""
|
||||
if date is None:
|
||||
date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
results = []
|
||||
for ticker in tickers:
|
||||
try:
|
||||
sig = await asyncio.to_thread(
|
||||
self._orchestrator.get_combined_signal, ticker, date
|
||||
)
|
||||
results.append({
|
||||
"ticker": ticker,
|
||||
"date": date,
|
||||
"direction": sig.direction,
|
||||
"confidence": sig.confidence,
|
||||
"quant_direction": sig.quant_signal.direction if sig.quant_signal else None,
|
||||
"llm_direction": sig.llm_signal.direction if sig.llm_signal else None,
|
||||
"timestamp": sig.timestamp.isoformat(),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("LiveMode: failed for %s %s: %s", ticker, date, e)
|
||||
results.append({
|
||||
"ticker": ticker,
|
||||
"date": date,
|
||||
"error": str(e),
|
||||
})
|
||||
return results
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
from orchestrator.signals import Signal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMRunner:
|
||||
def __init__(self, config: OrchestratorConfig):
|
||||
self._config = config
|
||||
self._graph = None # Lazy-initialized on first get_signal() call (requires API key)
|
||||
self.cache_dir = config.cache_dir
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
|
||||
def _get_graph(self):
|
||||
"""Lazy-initialize TradingAgentsGraph (heavy, requires API key at init time)."""
|
||||
if self._graph is None:
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
trading_cfg = self._config.trading_agents_config if self._config.trading_agents_config else None
|
||||
self._graph = TradingAgentsGraph(config=trading_cfg)
|
||||
return self._graph
|
||||
|
||||
def get_signal(self, ticker: str, date: str) -> Signal:
|
||||
"""获取指定股票在指定日期的 LLM 信号,带缓存。"""
|
||||
safe_ticker = ticker.replace("/", "_") # sanitize for filesystem (e.g. BRK/B)
|
||||
cache_path = os.path.join(self.cache_dir, f"{safe_ticker}_{date}.json")
|
||||
|
||||
if os.path.exists(cache_path):
|
||||
logger.info("LLMRunner: cache hit for %s %s", ticker, date)
|
||||
with open(cache_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
# Use stored direction/confidence directly to avoid re-mapping drift
|
||||
return Signal(
|
||||
ticker=ticker,
|
||||
direction=data["direction"],
|
||||
confidence=data["confidence"],
|
||||
source="llm",
|
||||
timestamp=datetime.fromisoformat(data["timestamp"]),
|
||||
metadata=data,
|
||||
)
|
||||
|
||||
try:
|
||||
_final_state, processed_signal = self._get_graph().propagate(ticker, date)
|
||||
rating = processed_signal if isinstance(processed_signal, str) else str(processed_signal)
|
||||
direction, confidence = self._map_rating(rating)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
cache_data = {
|
||||
"rating": rating,
|
||||
"direction": direction,
|
||||
"confidence": confidence,
|
||||
"timestamp": now.isoformat(),
|
||||
"ticker": ticker,
|
||||
"date": date,
|
||||
}
|
||||
with open(cache_path, "w", encoding="utf-8") as f:
|
||||
json.dump(cache_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
return Signal(
|
||||
ticker=ticker,
|
||||
direction=direction,
|
||||
confidence=confidence,
|
||||
source="llm",
|
||||
timestamp=now,
|
||||
metadata=cache_data,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("LLMRunner: propagate failed for %s %s: %s", ticker, date, e)
|
||||
return Signal(
|
||||
ticker=ticker,
|
||||
direction=0,
|
||||
confidence=0.0,
|
||||
source="llm",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
metadata={"error": str(e)},
|
||||
)
|
||||
|
||||
def _map_rating(self, rating: str) -> tuple[int, float]:
|
||||
"""将 5 级评级映射为 (direction, confidence)。"""
|
||||
mapping = {
|
||||
"BUY": (1, 0.9),
|
||||
"OVERWEIGHT": (1, 0.6),
|
||||
"HOLD": (0, 0.5),
|
||||
"UNDERWEIGHT": (-1, 0.6),
|
||||
"SELL": (-1, 0.9),
|
||||
}
|
||||
result = mapping.get(rating.upper() if rating else "", None)
|
||||
if result is None:
|
||||
logger.warning("LLMRunner: unknown rating %r, falling back to HOLD", rating)
|
||||
return (0, 0.5)
|
||||
return result
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
from orchestrator.signals import Signal, FinalSignal, SignalMerger
|
||||
from orchestrator.quant_runner import QuantRunner
|
||||
from orchestrator.llm_runner import LLMRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TradingOrchestrator:
|
||||
def __init__(self, config: OrchestratorConfig):
|
||||
self._config = config
|
||||
self._merger = SignalMerger(config)
|
||||
self._quant: Optional[QuantRunner] = None
|
||||
self._llm: Optional[LLMRunner] = None
|
||||
|
||||
# Initialize runners (quant requires quant_backtest_path)
|
||||
if config.quant_backtest_path:
|
||||
try:
|
||||
self._quant = QuantRunner(config)
|
||||
except Exception as e:
|
||||
logger.warning("TradingOrchestrator: QuantRunner init failed: %s", e)
|
||||
|
||||
self._llm = LLMRunner(config)
|
||||
|
||||
def get_combined_signal(self, ticker: str, date: str) -> FinalSignal:
|
||||
"""
|
||||
Get merged signal for ticker on date.
|
||||
Degradation:
|
||||
- quant fails (error signal): use llm only with llm_solo_penalty
|
||||
- llm fails (error signal): use quant only with quant_solo_penalty
|
||||
- both fail: raises ValueError
|
||||
"""
|
||||
quant_sig: Optional[Signal] = None
|
||||
llm_sig: Optional[Signal] = None
|
||||
|
||||
# Get quant signal
|
||||
if self._quant is not None:
|
||||
try:
|
||||
quant_sig = self._quant.get_signal(ticker, date)
|
||||
# Treat error signals (confidence=0, direction=0 with error metadata) as None
|
||||
if quant_sig.metadata.get("error") or quant_sig.metadata.get("reason") == "no_data":
|
||||
logger.warning("TradingOrchestrator: quant signal degraded for %s %s", ticker, date)
|
||||
quant_sig = None
|
||||
except Exception as e:
|
||||
logger.error("TradingOrchestrator: quant get_signal failed: %s", e)
|
||||
quant_sig = None
|
||||
|
||||
# Get llm signal
|
||||
try:
|
||||
llm_sig = self._llm.get_signal(ticker, date)
|
||||
if llm_sig.metadata.get("error"):
|
||||
logger.warning("TradingOrchestrator: llm signal degraded for %s %s", ticker, date)
|
||||
llm_sig = None
|
||||
except Exception as e:
|
||||
logger.error("TradingOrchestrator: llm get_signal failed: %s", e)
|
||||
llm_sig = None
|
||||
|
||||
# merge raises ValueError if both None
|
||||
return self._merger.merge(quant_sig, llm_sig)
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import sys
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Any
|
||||
|
||||
import yfinance as yf
|
||||
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
from orchestrator.signals import Signal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QuantRunner:
|
||||
def __init__(self, config: OrchestratorConfig):
|
||||
if not config.quant_backtest_path:
|
||||
raise ValueError("OrchestratorConfig.quant_backtest_path must be set")
|
||||
self._config = config
|
||||
path = config.quant_backtest_path
|
||||
if path not in sys.path:
|
||||
sys.path.insert(0, path)
|
||||
self._db_path = f"{path}/research_results/runs.db"
|
||||
|
||||
def get_signal(self, ticker: str, date: str) -> Signal:
|
||||
"""
|
||||
获取指定股票在指定日期的量化信号。
|
||||
date 格式:'YYYY-MM-DD'
|
||||
返回 Signal(source="quant")
|
||||
"""
|
||||
result = self._load_best_params()
|
||||
params: dict = result["params"]
|
||||
sharpe: float = result["sharpe_ratio"]
|
||||
|
||||
# 获取 date 前 60 天的历史数据
|
||||
end_dt = datetime.strptime(date, "%Y-%m-%d")
|
||||
start_dt = end_dt - timedelta(days=60)
|
||||
start_str = start_dt.strftime("%Y-%m-%d")
|
||||
|
||||
df = yf.download(ticker, start=start_str, end=date, progress=False, auto_adjust=True)
|
||||
if df.empty:
|
||||
logger.warning("No price data for %s between %s and %s", ticker, start_str, date)
|
||||
return Signal(
|
||||
ticker=ticker,
|
||||
direction=0,
|
||||
confidence=0.0,
|
||||
source="quant",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
metadata={"reason": "no_data"},
|
||||
)
|
||||
|
||||
# 标准化列名为小写
|
||||
df.columns = [c[0].lower() if isinstance(c, tuple) else c.lower() for c in df.columns]
|
||||
|
||||
# 用最佳参数创建 BollingerStrategy 实例
|
||||
# Lazy import: requires quant_backtest_path to be in sys.path (set in __init__)
|
||||
from strategies.momentum import BollingerStrategy
|
||||
from core.data_models import Bar, OrderDirection
|
||||
|
||||
strategy = BollingerStrategy(
|
||||
period=params.get("period", 20),
|
||||
num_std=params.get("num_std", 2.0),
|
||||
position_pct=params.get("position_pct", 0.20),
|
||||
stop_loss_pct=params.get("stop_loss_pct", 0.05),
|
||||
take_profit_pct=params.get("take_profit_pct", 0.15),
|
||||
)
|
||||
|
||||
# 逐 bar 喂给策略,模拟历史回放
|
||||
direction = 0
|
||||
orders: list = []
|
||||
context: dict[str, Any] = {"positions": {}}
|
||||
|
||||
for ts, row in df.iterrows():
|
||||
bar = Bar(
|
||||
symbol=ticker,
|
||||
timestamp=ts.to_pydatetime() if hasattr(ts, "to_pydatetime") else ts,
|
||||
open=float(row["open"]),
|
||||
high=float(row["high"]),
|
||||
low=float(row["low"]),
|
||||
close=float(row["close"]),
|
||||
volume=float(row.get("volume", 0)),
|
||||
)
|
||||
orders = strategy.on_bar(bar, context)
|
||||
# 更新模拟持仓
|
||||
for order in orders:
|
||||
if order.direction == OrderDirection.BUY:
|
||||
context["positions"][ticker] = order.volume
|
||||
elif order.direction == OrderDirection.SELL:
|
||||
context["positions"][ticker] = 0
|
||||
|
||||
# 最后一个 bar 的信号
|
||||
last_orders = orders if df.shape[0] > 0 else []
|
||||
for order in last_orders:
|
||||
if order.direction == OrderDirection.BUY:
|
||||
direction = 1
|
||||
break
|
||||
elif order.direction == OrderDirection.SELL:
|
||||
direction = -1
|
||||
break
|
||||
|
||||
# 计算 max_sharpe(从 DB 中取全局最大值)
|
||||
try:
|
||||
with sqlite3.connect(self._db_path) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT MAX(sharpe_ratio) FROM backtest_results")
|
||||
row = cur.fetchone()
|
||||
max_sharpe = float(row[0]) if row and row[0] is not None else sharpe
|
||||
except Exception:
|
||||
max_sharpe = sharpe
|
||||
|
||||
confidence = self._calc_confidence(sharpe, max_sharpe)
|
||||
|
||||
return Signal(
|
||||
ticker=ticker,
|
||||
direction=direction,
|
||||
confidence=confidence,
|
||||
source="quant",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
metadata={"params": params, "sharpe_ratio": sharpe, "max_sharpe": max_sharpe},
|
||||
)
|
||||
|
||||
def _load_best_params(self) -> dict:
|
||||
"""
|
||||
直接查 SQLite 获取 BollingerStrategy 最佳参数。
|
||||
参数是全局最优,不区分股票(backtest_results 表无 ticker 列,优化是全局的)。
|
||||
strategy_type 支持 'BollingerStrategy' 和 'bollinger'(兼容两种写法)。
|
||||
"""
|
||||
with sqlite3.connect(self._db_path) as conn:
|
||||
cur = conn.cursor()
|
||||
# 先按规格查 'BollingerStrategy',再 fallback 到 'bollinger'
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT params, sharpe_ratio
|
||||
FROM backtest_results
|
||||
WHERE strategy_type IN ('BollingerStrategy', 'bollinger')
|
||||
ORDER BY sharpe_ratio DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if row is None:
|
||||
raise ValueError(
|
||||
"No BollingerStrategy results found in ResultStore. "
|
||||
"Run optimization first: python quant_backtest/run_research.py"
|
||||
)
|
||||
|
||||
params = json.loads(row[0]) if isinstance(row[0], str) else row[0]
|
||||
return {"params": params, "sharpe_ratio": float(row[1])}
|
||||
|
||||
def _calc_confidence(self, sharpe: float, max_sharpe: float) -> float:
|
||||
"""
|
||||
Sharpe 归一化为置信度。
|
||||
- max_sharpe=0 时返回 0.5(默认值,避免除零)
|
||||
- sharpe/max_sharpe 上限截断到 1.0
|
||||
- 下限截断到 0.0(负 Sharpe 不产生负置信度)
|
||||
"""
|
||||
if max_sharpe == 0:
|
||||
return 0.5
|
||||
ratio = sharpe / max_sharpe
|
||||
return max(0.0, min(1.0, ratio))
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Signal:
|
||||
ticker: str
|
||||
direction: int # +1 买入, -1 卖出, 0 持有
|
||||
confidence: float # 0.0 ~ 1.0
|
||||
source: str # "quant" | "llm"
|
||||
timestamp: datetime
|
||||
metadata: dict = field(default_factory=dict) # 原始输出,用于调试
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinalSignal:
|
||||
ticker: str
|
||||
direction: int # sign(quant_dir×quant_conf + llm_dir×llm_conf),sign(0)→0(HOLD)
|
||||
confidence: float # abs(weighted_sum) / total_conf
|
||||
quant_signal: Optional[Signal]
|
||||
llm_signal: Optional[Signal]
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
def _sign(x: float) -> int:
|
||||
"""Return +1, -1, or 0."""
|
||||
if x > 0:
|
||||
return 1
|
||||
elif x < 0:
|
||||
return -1
|
||||
return 0
|
||||
|
||||
|
||||
class SignalMerger:
|
||||
def __init__(self, config: OrchestratorConfig) -> None:
|
||||
self._config = config
|
||||
|
||||
def merge(self, quant: Optional[Signal], llm: Optional[Signal]) -> FinalSignal:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# 两者均失败
|
||||
if quant is None and llm is None:
|
||||
raise ValueError("both quant and llm signals are None")
|
||||
|
||||
ticker = (quant or llm).ticker # type: ignore[union-attr]
|
||||
|
||||
# 只有 LLM(quant 失败)
|
||||
if quant is None:
|
||||
return FinalSignal(
|
||||
ticker=ticker,
|
||||
direction=llm.direction,
|
||||
confidence=min(llm.confidence * self._config.llm_solo_penalty,
|
||||
self._config.llm_weight_cap),
|
||||
quant_signal=None,
|
||||
llm_signal=llm,
|
||||
timestamp=now,
|
||||
)
|
||||
|
||||
# 只有 Quant(llm 失败)
|
||||
if llm is None:
|
||||
return FinalSignal(
|
||||
ticker=ticker,
|
||||
direction=quant.direction,
|
||||
confidence=min(quant.confidence * self._config.quant_solo_penalty,
|
||||
self._config.quant_weight_cap),
|
||||
quant_signal=quant,
|
||||
llm_signal=None,
|
||||
timestamp=now,
|
||||
)
|
||||
|
||||
# 两者都有:加权合并
|
||||
# Cap each signal's contribution before merging
|
||||
quant_conf = min(quant.confidence, self._config.quant_weight_cap)
|
||||
llm_conf = min(llm.confidence, self._config.llm_weight_cap)
|
||||
weighted_sum = (
|
||||
quant.direction * quant_conf
|
||||
+ llm.direction * llm_conf
|
||||
)
|
||||
final_direction = _sign(weighted_sum)
|
||||
if final_direction == 0:
|
||||
logger.info(
|
||||
"SignalMerger: weighted_sum=0 for %s — signals cancel out, HOLD",
|
||||
ticker,
|
||||
)
|
||||
total_conf = quant_conf + llm_conf
|
||||
final_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0
|
||||
|
||||
return FinalSignal(
|
||||
ticker=ticker,
|
||||
direction=final_direction,
|
||||
confidence=final_confidence,
|
||||
quant_signal=quant,
|
||||
llm_signal=llm,
|
||||
timestamp=now,
|
||||
)
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
"""Tests for LLMRunner._map_rating()."""
|
||||
import tempfile
|
||||
import pytest
|
||||
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
from orchestrator.llm_runner import LLMRunner
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(tmp_path):
|
||||
cfg = OrchestratorConfig(cache_dir=str(tmp_path))
|
||||
return LLMRunner(cfg)
|
||||
|
||||
|
||||
# All 5 known ratings
|
||||
@pytest.mark.parametrize("rating,expected", [
|
||||
("BUY", (1, 0.9)),
|
||||
("OVERWEIGHT", (1, 0.6)),
|
||||
("HOLD", (0, 0.5)),
|
||||
("UNDERWEIGHT", (-1, 0.6)),
|
||||
("SELL", (-1, 0.9)),
|
||||
])
|
||||
def test_map_rating_known(runner, rating, expected):
|
||||
assert runner._map_rating(rating) == expected
|
||||
|
||||
|
||||
# Unknown rating → (0, 0.5)
|
||||
def test_map_rating_unknown(runner):
|
||||
assert runner._map_rating("STRONG_BUY") == (0, 0.5)
|
||||
|
||||
|
||||
# Case-insensitive
|
||||
def test_map_rating_lowercase(runner):
|
||||
assert runner._map_rating("buy") == (1, 0.9)
|
||||
assert runner._map_rating("sell") == (-1, 0.9)
|
||||
assert runner._map_rating("hold") == (0, 0.5)
|
||||
|
||||
|
||||
# Empty string → (0, 0.5)
|
||||
def test_map_rating_empty_string(runner):
|
||||
assert runner._map_rating("") == (0, 0.5)
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
"""Tests for QuantRunner._calc_confidence()."""
|
||||
import json
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
from orchestrator.quant_runner import QuantRunner
|
||||
|
||||
|
||||
def _make_runner(tmp_path):
|
||||
"""Create a QuantRunner with a minimal SQLite DB so __init__ succeeds."""
|
||||
db_dir = tmp_path / "research_results"
|
||||
db_dir.mkdir(parents=True)
|
||||
db_path = db_dir / "runs.db"
|
||||
|
||||
with sqlite3.connect(str(db_path)) as conn:
|
||||
conn.execute(
|
||||
"""CREATE TABLE backtest_results (
|
||||
id INTEGER PRIMARY KEY,
|
||||
strategy_type TEXT,
|
||||
params TEXT,
|
||||
sharpe_ratio REAL
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO backtest_results (strategy_type, params, sharpe_ratio) VALUES (?, ?, ?)",
|
||||
("BollingerStrategy", json.dumps({"period": 20, "num_std": 2.0,
|
||||
"position_pct": 0.2,
|
||||
"stop_loss_pct": 0.05,
|
||||
"take_profit_pct": 0.15}), 1.5),
|
||||
)
|
||||
|
||||
cfg = OrchestratorConfig(quant_backtest_path=str(tmp_path))
|
||||
return QuantRunner(cfg)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(tmp_path):
|
||||
return _make_runner(tmp_path)
|
||||
|
||||
|
||||
def test_calc_confidence_max_sharpe_zero(runner):
|
||||
assert runner._calc_confidence(1.0, 0) == 0.5
|
||||
|
||||
|
||||
def test_calc_confidence_half(runner):
|
||||
result = runner._calc_confidence(1.0, 2.0)
|
||||
assert result == pytest.approx(0.5)
|
||||
|
||||
|
||||
def test_calc_confidence_full(runner):
|
||||
result = runner._calc_confidence(2.0, 2.0)
|
||||
assert result == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_calc_confidence_clamped_above(runner):
|
||||
result = runner._calc_confidence(3.0, 2.0)
|
||||
assert result == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_calc_confidence_clamped_below(runner):
|
||||
result = runner._calc_confidence(-1.0, 2.0)
|
||||
assert result == pytest.approx(0.0)
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
"""Tests for SignalMerger in orchestrator/signals.py."""
|
||||
import math
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
from orchestrator.signals import Signal, SignalMerger
|
||||
|
||||
|
||||
def _make_signal(ticker="AAPL", direction=1, confidence=0.8, source="quant"):
|
||||
return Signal(
|
||||
ticker=ticker,
|
||||
direction=direction,
|
||||
confidence=confidence,
|
||||
source=source,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def merger():
|
||||
return SignalMerger(OrchestratorConfig())
|
||||
|
||||
|
||||
# Branch 1: both None → ValueError
|
||||
def test_merge_both_none_raises(merger):
|
||||
with pytest.raises(ValueError):
|
||||
merger.merge(None, None)
|
||||
|
||||
|
||||
# Branch 2: quant only
|
||||
def test_merge_quant_only(merger):
|
||||
cfg = OrchestratorConfig()
|
||||
q = _make_signal(direction=1, confidence=0.8, source="quant")
|
||||
result = merger.merge(q, None)
|
||||
assert result.direction == 1
|
||||
expected_conf = min(0.8 * cfg.quant_solo_penalty, cfg.quant_weight_cap)
|
||||
assert math.isclose(result.confidence, expected_conf)
|
||||
assert result.quant_signal is q
|
||||
assert result.llm_signal is None
|
||||
|
||||
|
||||
def test_merge_quant_only_capped(merger):
|
||||
cfg = OrchestratorConfig()
|
||||
# confidence=1.0 * quant_solo_penalty=0.8 → 0.8 == quant_weight_cap=0.8, no clamp needed
|
||||
q = _make_signal(direction=-1, confidence=1.0, source="quant")
|
||||
result = merger.merge(q, None)
|
||||
expected_conf = min(1.0 * cfg.quant_solo_penalty, cfg.quant_weight_cap)
|
||||
assert math.isclose(result.confidence, expected_conf)
|
||||
assert result.direction == -1
|
||||
|
||||
|
||||
# Branch 3: llm only
|
||||
def test_merge_llm_only(merger):
|
||||
cfg = OrchestratorConfig()
|
||||
l = _make_signal(direction=-1, confidence=0.9, source="llm")
|
||||
result = merger.merge(None, l)
|
||||
assert result.direction == -1
|
||||
expected_conf = min(0.9 * cfg.llm_solo_penalty, cfg.llm_weight_cap)
|
||||
assert math.isclose(result.confidence, expected_conf)
|
||||
assert result.llm_signal is l
|
||||
assert result.quant_signal is None
|
||||
|
||||
|
||||
def test_merge_llm_only_capped(merger):
|
||||
cfg = OrchestratorConfig()
|
||||
# Force cap: confidence=1.0, llm_solo_penalty=0.7 → 0.7 < llm_weight_cap=0.9, no cap
|
||||
l = _make_signal(direction=1, confidence=1.0, source="llm")
|
||||
result = merger.merge(None, l)
|
||||
expected_conf = min(1.0 * cfg.llm_solo_penalty, cfg.llm_weight_cap)
|
||||
assert math.isclose(result.confidence, expected_conf)
|
||||
|
||||
|
||||
# Branch 4: both present, same direction
|
||||
def test_merge_both_same_direction(merger):
|
||||
cfg = OrchestratorConfig()
|
||||
q = _make_signal(direction=1, confidence=0.6, source="quant")
|
||||
l = _make_signal(direction=1, confidence=0.8, source="llm")
|
||||
result = merger.merge(q, l)
|
||||
assert result.direction == 1
|
||||
# caps applied per-signal before merging
|
||||
quant_conf = min(0.6, cfg.quant_weight_cap) # 0.6
|
||||
llm_conf = min(0.8, cfg.llm_weight_cap) # 0.8
|
||||
weighted_sum = 1 * quant_conf + 1 * llm_conf # 1.4
|
||||
total_conf = quant_conf + llm_conf # 1.4
|
||||
expected_conf = abs(weighted_sum) / total_conf # 1.0
|
||||
assert math.isclose(result.confidence, expected_conf)
|
||||
|
||||
|
||||
# Branch 5: both present, opposite direction
|
||||
def test_merge_both_opposite_direction_quant_wins(merger):
|
||||
cfg = OrchestratorConfig()
|
||||
# quant stronger: direction should be quant's
|
||||
q = _make_signal(direction=1, confidence=0.9, source="quant")
|
||||
l = _make_signal(direction=-1, confidence=0.3, source="llm")
|
||||
result = merger.merge(q, l)
|
||||
assert result.direction == 1
|
||||
# caps applied per-signal before merging
|
||||
quant_conf = min(0.9, cfg.quant_weight_cap) # 0.8
|
||||
llm_conf = min(0.3, cfg.llm_weight_cap) # 0.3
|
||||
weighted_sum = 1 * quant_conf + (-1) * llm_conf # 0.5
|
||||
total_conf = quant_conf + llm_conf # 1.1
|
||||
expected_conf = abs(weighted_sum) / total_conf
|
||||
assert math.isclose(result.confidence, expected_conf)
|
||||
|
||||
|
||||
def test_merge_both_opposite_direction_llm_wins(merger):
|
||||
q = _make_signal(direction=1, confidence=0.2, source="quant")
|
||||
l = _make_signal(direction=-1, confidence=0.8, source="llm")
|
||||
result = merger.merge(q, l)
|
||||
assert result.direction == -1
|
||||
|
||||
|
||||
# weighted_sum=0 → direction=HOLD
|
||||
def test_merge_weighted_sum_zero(merger):
|
||||
q = _make_signal(direction=1, confidence=0.5, source="quant")
|
||||
l = _make_signal(direction=-1, confidence=0.5, source="llm")
|
||||
result = merger.merge(q, l)
|
||||
assert result.direction == 0
|
||||
assert math.isclose(result.confidence, 0.0)
|
||||
|
|
@ -4,6 +4,7 @@ FastAPI REST API + WebSocket for real-time analysis progress
|
|||
"""
|
||||
import asyncio
|
||||
import fcntl
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
|
|
@ -105,7 +106,9 @@ def _check_api_key(api_key: Optional[str]) -> bool:
|
|||
required = _get_api_key()
|
||||
if not required:
|
||||
return True
|
||||
return api_key == required
|
||||
if not api_key:
|
||||
return False
|
||||
return hmac.compare_digest(api_key, required)
|
||||
|
||||
def _auth_error():
|
||||
raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required")
|
||||
|
|
@ -363,7 +366,7 @@ async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Head
|
|||
# Use clean environment - don't inherit parent env
|
||||
clean_env = {k: v for k, v in os.environ.items()
|
||||
if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))}
|
||||
clean_env["ANTHROPIC_API_KEY"] = api_key
|
||||
clean_env["ANTHROPIC_API_KEY"] = anthropic_key
|
||||
clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
|
|
@ -1100,6 +1103,45 @@ async def root():
|
|||
return {"message": "TradingAgents Web Dashboard API", "version": "0.1.0"}
|
||||
|
||||
|
||||
@app.websocket("/ws/orchestrator")
|
||||
async def ws_orchestrator(websocket: WebSocket, api_key: Optional[str] = None):
|
||||
"""WebSocket endpoint for orchestrator live signals."""
|
||||
# Auth check before accepting — reject unauthenticated connections
|
||||
if not _check_api_key(api_key):
|
||||
await websocket.close(code=4401)
|
||||
return
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
from orchestrator.orchestrator import TradingOrchestrator
|
||||
from orchestrator.live_mode import LiveMode
|
||||
|
||||
config = OrchestratorConfig(
|
||||
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
|
||||
)
|
||||
orchestrator = TradingOrchestrator(config)
|
||||
live = LiveMode(orchestrator)
|
||||
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
payload = json.loads(data)
|
||||
tickers = payload.get("tickers", [])
|
||||
date = payload.get("date")
|
||||
|
||||
results = await live.run_once(tickers, date)
|
||||
await websocket.send_text(json.dumps({"signals": results}))
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as e:
|
||||
try:
|
||||
await websocket.send_text(json.dumps({"error": str(e)}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
# Run with: cd web_dashboard && ../env312/bin/python -m uvicorn main:app --reload
|
||||
|
|
|
|||
Loading…
Reference in New Issue