feat(orchestrator): merge orchestrator module into main
This commit is contained in:
commit
8960fdf321
|
|
@ -0,0 +1,64 @@
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from orchestrator.signals import FinalSignal
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BacktestResult:
    """Container for the output of a BacktestMode run."""

    # One dict per (ticker, trading day): direction, confidence and per-track directions.
    records: List[dict] = field(default_factory=list)
    # Per-ticker aggregate stats produced by BacktestMode._compute_summary().
    summary: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class BacktestMode:
    """Replays combined orchestrator signals over a historical date range.

    Walks every weekday between the start and end dates, asks the
    orchestrator for a combined signal per ticker, and aggregates the
    per-day records into a BacktestResult.
    """

    def __init__(self, orchestrator):
        self._orchestrator = orchestrator

    def run(self, tickers: List[str], start_date: str, end_date: str) -> BacktestResult:
        """Run the backtest over [start_date, end_date], both 'YYYY-MM-DD'."""
        first_day = datetime.strptime(start_date, "%Y-%m-%d")
        last_day = datetime.strptime(end_date, "%Y-%m-%d")

        records: List[dict] = []
        day = first_day
        while day <= last_day:
            # Markets trade Mon-Fri only; weekday() < 5 excludes Sat/Sun.
            if day.weekday() < 5:
                day_str = day.strftime("%Y-%m-%d")
                for symbol in tickers:
                    try:
                        sig = self._orchestrator.get_combined_signal(symbol, day_str)
                        records.append({
                            "ticker": symbol,
                            "date": day_str,
                            "direction": sig.direction,
                            "confidence": sig.confidence,
                            "quant_direction": sig.quant_signal.direction if sig.quant_signal else None,
                            "llm_direction": sig.llm_signal.direction if sig.llm_signal else None,
                        })
                    except Exception as exc:
                        # Best-effort: one failed (ticker, day) must not abort the run.
                        logger.error("BacktestMode: failed for %s %s: %s", symbol, day_str, exc)
            day = day + timedelta(days=1)

        return BacktestResult(records=records, summary=self._compute_summary(records, tickers))

    def _compute_summary(self, records: List[dict], tickers: List[str]) -> dict:
        """Aggregate per-ticker counts and average confidence over the records."""
        return {symbol: self._ticker_stats(symbol, records) for symbol in tickers}

    def _ticker_stats(self, ticker: str, records: List[dict]) -> dict:
        """Build the summary dict for a single ticker."""
        rows = [r for r in records if r["ticker"] == ticker]
        if not rows:
            return {"total_days": 0}
        dirs = [r["direction"] for r in rows]
        confs = [r["confidence"] for r in rows]
        return {
            "total_days": len(rows),
            "buy_days": dirs.count(1),
            "sell_days": dirs.count(-1),
            "hold_days": dirs.count(0),
            "avg_confidence": sum(confs) / len(confs),
        }
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class OrchestratorConfig:
    """Configuration shared by the orchestrator and its quant/LLM runners."""

    # Must be set to the local quant backtest output directory before use
    quant_backtest_path: str = ""
    # Passed through to TradingAgentsGraph; empty dict means library defaults.
    trading_agents_config: dict = field(default_factory=dict)
    quant_weight_cap: float = 0.8  # upper bound on quant confidence
    llm_weight_cap: float = 0.9  # upper bound on LLM confidence
    llm_batch_days: int = 7  # run the LLM only every N days (saves API calls)
    cache_dir: str = "orchestrator/cache"  # directory for cached LLM signals
    llm_solo_penalty: float = 0.7  # confidence discount when only the LLM signal is available
    quant_solo_penalty: float = 0.8  # confidence discount when only the quant signal is available
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
"""
|
||||||
|
Example: Run orchestrator backtest for 宁德时代 (300750.SZ) over 2023.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
cd /path/to/TradingAgents
|
||||||
|
QUANT_BACKTEST_PATH=/path/to/quant_backtest python orchestrator/examples/run_backtest.py
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Add repo root to path
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||||
|
|
||||||
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
from orchestrator.orchestrator import TradingOrchestrator
|
||||||
|
from orchestrator.backtest_mode import BacktestMode
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
|
||||||
|
config = OrchestratorConfig(
|
||||||
|
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
|
||||||
|
cache_dir="orchestrator/cache",
|
||||||
|
)
|
||||||
|
|
||||||
|
orchestrator = TradingOrchestrator(config)
|
||||||
|
backtest = BacktestMode(orchestrator)
|
||||||
|
|
||||||
|
result = backtest.run(
|
||||||
|
tickers=["300750.SZ"],
|
||||||
|
start_date="2023-01-01",
|
||||||
|
end_date="2023-12-31",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n=== Backtest Summary ===")
|
||||||
|
print(json.dumps(result.summary, indent=2, ensure_ascii=False))
|
||||||
|
print(f"\nTotal records: {len(result.records)}")
|
||||||
|
if result.records:
|
||||||
|
print(f"First record: {result.records[0]}")
|
||||||
|
print(f"Last record: {result.records[-1]}")
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
"""
|
||||||
|
Example: Run orchestrator live mode for a list of tickers.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
cd /path/to/TradingAgents
|
||||||
|
QUANT_BACKTEST_PATH=/path/to/quant_backtest python orchestrator/examples/run_live.py
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||||
|
|
||||||
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
from orchestrator.orchestrator import TradingOrchestrator
|
||||||
|
from orchestrator.live_mode import LiveMode
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
|
||||||
|
TICKERS = ["300750.SZ", "603259.SS"]
|
||||||
|
|
||||||
|
config = OrchestratorConfig(
|
||||||
|
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
|
||||||
|
cache_dir="orchestrator/cache",
|
||||||
|
)
|
||||||
|
|
||||||
|
orchestrator = TradingOrchestrator(config)
|
||||||
|
live = LiveMode(orchestrator)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||||
|
print(f"\n=== Live Signals for {today} ===")
|
||||||
|
results = await live.run_once(TICKERS, date=today)
|
||||||
|
print(json.dumps(results, indent=2, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LiveMode:
    """
    Triggers signal computation for a list of tickers and broadcasts
    results via a callback (e.g., WebSocket send).
    """

    def __init__(self, orchestrator):
        self._orchestrator = orchestrator

    async def run_once(self, tickers: List[str], date: Optional[str] = None) -> List[dict]:
        """
        Compute combined signals for all tickers on the given date
        (default: today, UTC) and return a list of signal dicts.
        Failed tickers get an {"error": ...} entry instead of aborting.
        """
        when = date if date is not None else datetime.now(timezone.utc).strftime("%Y-%m-%d")

        payloads: List[dict] = []
        for symbol in tickers:
            try:
                # get_combined_signal is blocking; keep it off the event loop.
                combined = await asyncio.to_thread(
                    self._orchestrator.get_combined_signal, symbol, when
                )
                payloads.append({
                    "ticker": symbol,
                    "date": when,
                    "direction": combined.direction,
                    "confidence": combined.confidence,
                    "quant_direction": combined.quant_signal.direction if combined.quant_signal else None,
                    "llm_direction": combined.llm_signal.direction if combined.llm_signal else None,
                    "timestamp": combined.timestamp.isoformat(),
                })
            except Exception as exc:
                logger.error("LiveMode: failed for %s %s: %s", symbol, when, exc)
                payloads.append({
                    "ticker": symbol,
                    "date": when,
                    "error": str(exc),
                })
        return payloads
|
||||||
|
|
@ -0,0 +1,95 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
from orchestrator.signals import Signal
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMRunner:
    """Produces LLM-based trading signals via TradingAgentsGraph, with a JSON file cache."""

    def __init__(self, config: OrchestratorConfig):
        self._config = config
        self._graph = None  # Lazy-initialized on first get_signal() call (requires API key)
        self.cache_dir = config.cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

    def _get_graph(self):
        """Lazy-initialize TradingAgentsGraph (heavy, requires API key at init time)."""
        if self._graph is None:
            from tradingagents.graph.trading_graph import TradingAgentsGraph

            self._graph = TradingAgentsGraph(config=self._config.trading_agents_config or None)
        return self._graph

    def get_signal(self, ticker: str, date: str) -> Signal:
        """Return the LLM signal for (ticker, date), serving from the cache when possible."""
        safe_ticker = ticker.replace("/", "_")  # sanitize for filesystem (e.g. BRK/B)
        cache_path = os.path.join(self.cache_dir, f"{safe_ticker}_{date}.json")

        cached = self._read_cache(cache_path, ticker, date)
        if cached is not None:
            return cached

        try:
            _final_state, processed_signal = self._get_graph().propagate(ticker, date)
            rating = processed_signal if isinstance(processed_signal, str) else str(processed_signal)
            direction, confidence = self._map_rating(rating)
            now = datetime.now(timezone.utc)

            payload = {
                "rating": rating,
                "direction": direction,
                "confidence": confidence,
                "timestamp": now.isoformat(),
                "ticker": ticker,
                "date": date,
            }
            with open(cache_path, "w", encoding="utf-8") as f:
                json.dump(payload, f, ensure_ascii=False, indent=2)

            return Signal(
                ticker=ticker,
                direction=direction,
                confidence=confidence,
                source="llm",
                timestamp=now,
                metadata=payload,
            )
        except Exception as e:
            logger.error("LLMRunner: propagate failed for %s %s: %s", ticker, date, e)
            return Signal(
                ticker=ticker,
                direction=0,
                confidence=0.0,
                source="llm",
                timestamp=datetime.now(timezone.utc),
                metadata={"error": str(e)},
            )

    def _read_cache(self, cache_path: str, ticker: str, date: str) -> Optional["Signal"]:
        """Rebuild a Signal from the cache file, or return None on a cache miss."""
        if not os.path.exists(cache_path):
            return None
        logger.info("LLMRunner: cache hit for %s %s", ticker, date)
        with open(cache_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Use stored direction/confidence directly to avoid re-mapping drift
        return Signal(
            ticker=ticker,
            direction=data["direction"],
            confidence=data["confidence"],
            source="llm",
            timestamp=datetime.fromisoformat(data["timestamp"]),
            metadata=data,
        )

    def _map_rating(self, rating: str) -> tuple[int, float]:
        """Map a 5-level rating string to a (direction, confidence) pair."""
        mapping = {
            "BUY": (1, 0.9),
            "OVERWEIGHT": (1, 0.6),
            "HOLD": (0, 0.5),
            "UNDERWEIGHT": (-1, 0.6),
            "SELL": (-1, 0.9),
        }
        key = rating.upper() if rating else ""
        try:
            return mapping[key]
        except KeyError:
            logger.warning("LLMRunner: unknown rating %r, falling back to HOLD", rating)
            return (0, 0.5)
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
from orchestrator.signals import Signal, FinalSignal, SignalMerger
|
||||||
|
from orchestrator.quant_runner import QuantRunner
|
||||||
|
from orchestrator.llm_runner import LLMRunner
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TradingOrchestrator:
    """Combines quant and LLM signals into one FinalSignal per (ticker, date)."""

    def __init__(self, config: OrchestratorConfig):
        self._config = config
        self._merger = SignalMerger(config)
        # Runners stay None when unavailable; get_combined_signal degrades accordingly.
        self._quant: Optional[QuantRunner] = None
        self._llm: Optional[LLMRunner] = None

        # Initialize runners (quant requires quant_backtest_path)
        if config.quant_backtest_path:
            try:
                self._quant = QuantRunner(config)
            except Exception as e:
                # Quant is optional: a failed init simply leaves self._quant as None.
                logger.warning("TradingOrchestrator: QuantRunner init failed: %s", e)

        self._llm = LLMRunner(config)

    def get_combined_signal(self, ticker: str, date: str) -> FinalSignal:
        """
        Get merged signal for ticker on date.
        Degradation:
        - quant fails (error signal): use llm only with llm_solo_penalty
        - llm fails (error signal): use quant only with quant_solo_penalty
        - both fail: raises ValueError
        """
        quant_sig: Optional[Signal] = None
        llm_sig: Optional[Signal] = None

        # Get quant signal
        if self._quant is not None:
            try:
                quant_sig = self._quant.get_signal(ticker, date)
                # Treat error signals (confidence=0, direction=0 with error metadata) as None
                if quant_sig.metadata.get("error") or quant_sig.metadata.get("reason") == "no_data":
                    logger.warning("TradingOrchestrator: quant signal degraded for %s %s", ticker, date)
                    quant_sig = None
            except Exception as e:
                logger.error("TradingOrchestrator: quant get_signal failed: %s", e)
                quant_sig = None

        # Get llm signal
        try:
            llm_sig = self._llm.get_signal(ticker, date)
            if llm_sig.metadata.get("error"):
                logger.warning("TradingOrchestrator: llm signal degraded for %s %s", ticker, date)
                llm_sig = None
        except Exception as e:
            logger.error("TradingOrchestrator: llm get_signal failed: %s", e)
            llm_sig = None

        # merge raises ValueError if both None
        return self._merger.merge(quant_sig, llm_sig)
|
||||||
|
|
@ -0,0 +1,162 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yfinance as yf
|
||||||
|
|
||||||
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
from orchestrator.signals import Signal
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QuantRunner:
    """Produces quant signals by replaying the best stored Bollinger strategy over recent prices."""

    def __init__(self, config: OrchestratorConfig):
        if not config.quant_backtest_path:
            raise ValueError("OrchestratorConfig.quant_backtest_path must be set")
        self._config = config
        path = config.quant_backtest_path
        # Make the quant project's `strategies`/`core` packages importable
        # (they are lazily imported inside get_signal).
        if path not in sys.path:
            sys.path.insert(0, path)
        self._db_path = f"{path}/research_results/runs.db"

    def get_signal(self, ticker: str, date: str) -> Signal:
        """
        Compute the quant signal for ticker on the given date.
        date format: 'YYYY-MM-DD'. Returns Signal(source="quant").
        """
        result = self._load_best_params()
        params: dict = result["params"]
        sharpe: float = result["sharpe_ratio"]

        # Fetch the 60 calendar days of history leading up to `date`
        end_dt = datetime.strptime(date, "%Y-%m-%d")
        start_dt = end_dt - timedelta(days=60)
        start_str = start_dt.strftime("%Y-%m-%d")

        df = yf.download(ticker, start=start_str, end=date, progress=False, auto_adjust=True)
        if df.empty:
            logger.warning("No price data for %s between %s and %s", ticker, start_str, date)
            return Signal(
                ticker=ticker,
                direction=0,
                confidence=0.0,
                source="quant",
                timestamp=datetime.now(timezone.utc),
                metadata={"reason": "no_data"},
            )

        # Normalize column names to lowercase (tuple entries occur when columns are a MultiIndex)
        df.columns = [c[0].lower() if isinstance(c, tuple) else c.lower() for c in df.columns]

        # Instantiate BollingerStrategy with the best stored parameters
        # Lazy import: requires quant_backtest_path to be in sys.path (set in __init__)
        from strategies.momentum import BollingerStrategy
        from core.data_models import Bar, OrderDirection

        strategy = BollingerStrategy(
            period=params.get("period", 20),
            num_std=params.get("num_std", 2.0),
            position_pct=params.get("position_pct", 0.20),
            stop_loss_pct=params.get("stop_loss_pct", 0.05),
            take_profit_pct=params.get("take_profit_pct", 0.15),
        )

        # Feed bars one by one to replay the historical window
        direction = 0
        orders: list = []
        context: dict[str, Any] = {"positions": {}}

        for ts, row in df.iterrows():
            bar = Bar(
                symbol=ticker,
                timestamp=ts.to_pydatetime() if hasattr(ts, "to_pydatetime") else ts,
                open=float(row["open"]),
                high=float(row["high"]),
                low=float(row["low"]),
                close=float(row["close"]),
                volume=float(row.get("volume", 0)),
            )
            orders = strategy.on_bar(bar, context)
            # Update the simulated position book
            for order in orders:
                if order.direction == OrderDirection.BUY:
                    context["positions"][ticker] = order.volume
                elif order.direction == OrderDirection.SELL:
                    context["positions"][ticker] = 0

        # Direction comes from the last bar's orders
        last_orders = orders if df.shape[0] > 0 else []
        for order in last_orders:
            if order.direction == OrderDirection.BUY:
                direction = 1
                break
            elif order.direction == OrderDirection.SELL:
                direction = -1
                break

        # Compute max_sharpe (global maximum from the DB)
        try:
            with sqlite3.connect(self._db_path) as conn:
                cur = conn.cursor()
                cur.execute("SELECT MAX(sharpe_ratio) FROM backtest_results")
                row = cur.fetchone()
                max_sharpe = float(row[0]) if row and row[0] is not None else sharpe
        except Exception:
            # Best-effort: fall back to this run's sharpe if the DB query fails
            max_sharpe = sharpe

        confidence = self._calc_confidence(sharpe, max_sharpe)

        return Signal(
            ticker=ticker,
            direction=direction,
            confidence=confidence,
            source="quant",
            timestamp=datetime.now(timezone.utc),
            metadata={"params": params, "sharpe_ratio": sharpe, "max_sharpe": max_sharpe},
        )

    def _load_best_params(self) -> dict:
        """
        Query SQLite directly for the best BollingerStrategy parameters.

        Parameters are the global optimum, not per-ticker (the backtest_results
        table has no ticker column; optimization is global).
        strategy_type accepts both 'BollingerStrategy' and 'bollinger' spellings.
        """
        with sqlite3.connect(self._db_path) as conn:
            cur = conn.cursor()
            # Prefer 'BollingerStrategy' rows, accepting the 'bollinger' spelling too
            cur.execute(
                """
                SELECT params, sharpe_ratio
                FROM backtest_results
                WHERE strategy_type IN ('BollingerStrategy', 'bollinger')
                ORDER BY sharpe_ratio DESC
                LIMIT 1
                """,
            )
            row = cur.fetchone()

        if row is None:
            raise ValueError(
                "No BollingerStrategy results found in ResultStore. "
                "Run optimization first: python quant_backtest/run_research.py"
            )

        params = json.loads(row[0]) if isinstance(row[0], str) else row[0]
        return {"params": params, "sharpe_ratio": float(row[1])}

    def _calc_confidence(self, sharpe: float, max_sharpe: float) -> float:
        """
        Normalize a Sharpe ratio into a confidence value.
        - max_sharpe=0 returns 0.5 (default value, avoids division by zero)
        - sharpe/max_sharpe is clamped to at most 1.0
        - and to at least 0.0 (a negative Sharpe never yields negative confidence)
        """
        if max_sharpe == 0:
            return 0.5
        ratio = sharpe / max_sharpe
        return max(0.0, min(1.0, ratio))
|
||||||
|
|
@ -0,0 +1,101 @@
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Signal:
    """A single-source trading signal (either quant or LLM)."""

    ticker: str
    direction: int  # +1 buy, -1 sell, 0 hold
    confidence: float  # 0.0 ~ 1.0
    source: str  # "quant" | "llm"
    timestamp: datetime
    metadata: dict = field(default_factory=dict)  # raw runner output, for debugging
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FinalSignal:
    """The merged quant+LLM signal produced by SignalMerger."""

    ticker: str
    direction: int  # sign(quant_dir*quant_conf + llm_dir*llm_conf); sign(0) -> 0 (HOLD)
    confidence: float  # abs(weighted_sum) / total_conf
    quant_signal: Optional[Signal]
    llm_signal: Optional[Signal]
    timestamp: datetime
|
||||||
|
|
||||||
|
|
||||||
|
def _sign(x: float) -> int:
    """Return +1, -1, or 0 depending on the sign of x."""
    # Boolean arithmetic: each comparison contributes 0 or 1.
    return (x > 0) - (x < 0)
|
||||||
|
|
||||||
|
|
||||||
|
class SignalMerger:
    """Merges an optional quant Signal and an optional LLM Signal into a FinalSignal."""

    def __init__(self, config: OrchestratorConfig) -> None:
        self._config = config

    def merge(self, quant: Optional[Signal], llm: Optional[Signal]) -> FinalSignal:
        """Merge the two signals, degrading gracefully when one is missing.

        Raises ValueError when both signals are None.
        """
        now = datetime.now(timezone.utc)

        # Both tracks failed: nothing sensible to emit.
        if quant is None and llm is None:
            raise ValueError("both quant and llm signals are None")

        ticker = (quant or llm).ticker  # type: ignore[union-attr]

        # LLM-only path (quant failed): discount confidence, then apply the LLM cap.
        if quant is None:
            solo_conf = min(
                llm.confidence * self._config.llm_solo_penalty,
                self._config.llm_weight_cap,
            )
            return FinalSignal(
                ticker=ticker,
                direction=llm.direction,
                confidence=solo_conf,
                quant_signal=None,
                llm_signal=llm,
                timestamp=now,
            )

        # Quant-only path (llm failed): discount confidence, then apply the quant cap.
        if llm is None:
            solo_conf = min(
                quant.confidence * self._config.quant_solo_penalty,
                self._config.quant_weight_cap,
            )
            return FinalSignal(
                ticker=ticker,
                direction=quant.direction,
                confidence=solo_conf,
                quant_signal=quant,
                llm_signal=None,
                timestamp=now,
            )

        # Dual-track merge: cap each signal's contribution before weighting.
        q_conf = min(quant.confidence, self._config.quant_weight_cap)
        l_conf = min(llm.confidence, self._config.llm_weight_cap)
        weighted_sum = quant.direction * q_conf + llm.direction * l_conf
        final_direction = _sign(weighted_sum)
        if final_direction == 0:
            logger.info(
                "SignalMerger: weighted_sum=0 for %s — signals cancel out, HOLD",
                ticker,
            )
        total_conf = q_conf + l_conf
        final_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0

        return FinalSignal(
            ticker=ticker,
            direction=final_direction,
            confidence=final_confidence,
            quant_signal=quant,
            llm_signal=llm,
            timestamp=now,
        )
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
"""Tests for LLMRunner._map_rating()."""
|
||||||
|
import tempfile
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
from orchestrator.llm_runner import LLMRunner
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def runner(tmp_path):
    """LLMRunner backed by a throwaway per-test cache directory."""
    return LLMRunner(OrchestratorConfig(cache_dir=str(tmp_path)))
|
||||||
|
|
||||||
|
|
||||||
|
# All 5 known ratings
|
||||||
|
@pytest.mark.parametrize("rating,expected", [
    ("BUY", (1, 0.9)),
    ("OVERWEIGHT", (1, 0.6)),
    ("HOLD", (0, 0.5)),
    ("UNDERWEIGHT", (-1, 0.6)),
    ("SELL", (-1, 0.9)),
])
def test_map_rating_known(runner, rating, expected):
    """Each rating on the 5-level scale maps to its exact (direction, confidence)."""
    assert runner._map_rating(rating) == expected
|
||||||
|
|
||||||
|
|
||||||
|
# Unknown rating → (0, 0.5)
|
||||||
|
def test_map_rating_unknown(runner):
    """Ratings outside the 5-level scale fall back to HOLD with 0.5 confidence."""
    direction, confidence = runner._map_rating("STRONG_BUY")
    assert (direction, confidence) == (0, 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
# Case-insensitive
|
||||||
|
def test_map_rating_lowercase(runner):
    """Rating lookup is case-insensitive."""
    cases = [("buy", (1, 0.9)), ("sell", (-1, 0.9)), ("hold", (0, 0.5))]
    for rating, expected in cases:
        assert runner._map_rating(rating) == expected
|
||||||
|
|
||||||
|
|
||||||
|
# Empty string → (0, 0.5)
|
||||||
|
def test_map_rating_empty_string(runner):
    """An empty rating string also falls back to HOLD."""
    result = runner._map_rating("")
    assert result == (0, 0.5)
|
||||||
|
|
@ -0,0 +1,65 @@
|
||||||
|
"""Tests for QuantRunner._calc_confidence()."""
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
from orchestrator.quant_runner import QuantRunner
|
||||||
|
|
||||||
|
|
||||||
|
def _make_runner(tmp_path):
    """Create a QuantRunner with a minimal SQLite DB so __init__ succeeds."""
    results_dir = tmp_path / "research_results"
    results_dir.mkdir(parents=True)

    best_params = {
        "period": 20,
        "num_std": 2.0,
        "position_pct": 0.2,
        "stop_loss_pct": 0.05,
        "take_profit_pct": 0.15,
    }
    with sqlite3.connect(str(results_dir / "runs.db")) as conn:
        conn.execute(
            """CREATE TABLE backtest_results (
            id INTEGER PRIMARY KEY,
            strategy_type TEXT,
            params TEXT,
            sharpe_ratio REAL
            )"""
        )
        conn.execute(
            "INSERT INTO backtest_results (strategy_type, params, sharpe_ratio) VALUES (?, ?, ?)",
            ("BollingerStrategy", json.dumps(best_params), 1.5),
        )

    return QuantRunner(OrchestratorConfig(quant_backtest_path=str(tmp_path)))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def runner(tmp_path):
    """QuantRunner fixture backed by a throwaway SQLite DB."""
    return _make_runner(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_confidence_max_sharpe_zero(runner):
    """max_sharpe == 0 short-circuits to the 0.5 default (no division by zero)."""
    result = runner._calc_confidence(1.0, 0)
    assert result == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_confidence_half(runner):
    """A sharpe at half of max_sharpe normalizes to 0.5."""
    assert runner._calc_confidence(1.0, 2.0) == pytest.approx(0.5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_confidence_full(runner):
    """A sharpe equal to max_sharpe normalizes to 1.0."""
    assert runner._calc_confidence(2.0, 2.0) == pytest.approx(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_confidence_clamped_above(runner):
    """A sharpe above max_sharpe is clamped to the 1.0 ceiling."""
    assert runner._calc_confidence(3.0, 2.0) == pytest.approx(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_confidence_clamped_below(runner):
    """A negative sharpe is clamped to the 0.0 floor."""
    assert runner._calc_confidence(-1.0, 2.0) == pytest.approx(0.0)
|
||||||
|
|
@ -0,0 +1,120 @@
|
||||||
|
"""Tests for SignalMerger in orchestrator/signals.py."""
|
||||||
|
import math
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
from orchestrator.signals import Signal, SignalMerger
|
||||||
|
|
||||||
|
|
||||||
|
def _make_signal(ticker="AAPL", direction=1, confidence=0.8, source="quant"):
    """Build a Signal stamped with the current UTC time; all fields overridable."""
    now = datetime.now(timezone.utc)
    return Signal(
        ticker=ticker,
        direction=direction,
        confidence=confidence,
        source=source,
        timestamp=now,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def merger():
    """SignalMerger built with default OrchestratorConfig values."""
    return SignalMerger(OrchestratorConfig())
|
||||||
|
|
||||||
|
|
||||||
|
# Branch 1: both None → ValueError
|
||||||
|
def test_merge_both_none_raises(merger):
    """Merging two missing signals raises ValueError rather than a silent HOLD."""
    with pytest.raises(ValueError):
        merger.merge(None, None)
|
||||||
|
|
||||||
|
|
||||||
|
# Branch 2: quant only
|
||||||
|
def test_merge_quant_only(merger):
    """Quant-only merge keeps the quant direction and applies penalty + cap."""
    cfg = OrchestratorConfig()
    quant = _make_signal(direction=1, confidence=0.8, source="quant")
    result = merger.merge(quant, None)
    assert result.direction == 1
    assert math.isclose(
        result.confidence,
        min(0.8 * cfg.quant_solo_penalty, cfg.quant_weight_cap),
    )
    assert result.quant_signal is quant
    assert result.llm_signal is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_quant_only_capped(merger):
    """Solo-quant confidence never exceeds quant_weight_cap."""
    cfg = OrchestratorConfig()
    # 1.0 * quant_solo_penalty (0.8) equals quant_weight_cap (0.8), so no clamp is needed.
    quant = _make_signal(direction=-1, confidence=1.0, source="quant")
    result = merger.merge(quant, None)
    assert math.isclose(
        result.confidence,
        min(1.0 * cfg.quant_solo_penalty, cfg.quant_weight_cap),
    )
    assert result.direction == -1
|
||||||
|
|
||||||
|
|
||||||
|
# Branch 3: llm only
|
||||||
|
def test_merge_llm_only(merger):
    """LLM-only merge keeps the LLM direction and applies penalty + cap."""
    cfg = OrchestratorConfig()
    llm_sig = _make_signal(direction=-1, confidence=0.9, source="llm")
    result = merger.merge(None, llm_sig)
    assert result.direction == -1
    assert math.isclose(
        result.confidence,
        min(0.9 * cfg.llm_solo_penalty, cfg.llm_weight_cap),
    )
    assert result.llm_signal is llm_sig
    assert result.quant_signal is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_llm_only_capped(merger):
    """Solo-LLM confidence stays within llm_weight_cap after the penalty."""
    cfg = OrchestratorConfig()
    # 1.0 * llm_solo_penalty (0.7) is already below llm_weight_cap (0.9), so no cap applies.
    llm_sig = _make_signal(direction=1, confidence=1.0, source="llm")
    result = merger.merge(None, llm_sig)
    assert math.isclose(
        result.confidence,
        min(1.0 * cfg.llm_solo_penalty, cfg.llm_weight_cap),
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Branch 4: both present, same direction
|
||||||
|
def test_merge_both_same_direction(merger):
    """Agreeing signals merge to full confidence in the shared direction."""
    cfg = OrchestratorConfig()
    quant = _make_signal(direction=1, confidence=0.6, source="quant")
    llm_sig = _make_signal(direction=1, confidence=0.8, source="llm")
    result = merger.merge(quant, llm_sig)
    assert result.direction == 1
    # Caps are applied per-signal before merging.
    q_conf = min(0.6, cfg.quant_weight_cap)  # 0.6
    l_conf = min(0.8, cfg.llm_weight_cap)  # 0.8
    expected = abs(1 * q_conf + 1 * l_conf) / (q_conf + l_conf)  # 1.0
    assert math.isclose(result.confidence, expected)
|
||||||
|
|
||||||
|
|
||||||
|
# Branch 5: both present, opposite direction
|
||||||
|
def test_merge_both_opposite_direction_quant_wins(merger):
    """Branch 5 — opposing signals, quant stronger: quant's direction wins
    and the disagreement discounts the merged confidence."""
    config = OrchestratorConfig()
    quant_sig = _make_signal(direction=1, confidence=0.9, source="quant")
    llm_sig = _make_signal(direction=-1, confidence=0.3, source="llm")

    merged = merger.merge(quant_sig, llm_sig)

    assert merged.direction == 1

    # Per-signal caps first, then a signed weighted sum normalized by the
    # total confidence mass.
    capped_quant = min(0.9, config.quant_weight_cap)       # 0.8
    capped_llm = min(0.3, config.llm_weight_cap)           # 0.3
    weighted = 1 * capped_quant + (-1) * capped_llm        # 0.5
    total = capped_quant + capped_llm                      # 1.1
    assert math.isclose(merged.confidence, abs(weighted) / total)
|
def test_merge_both_opposite_direction_llm_wins(merger):
    """Branch 5 mirror case — opposing signals, LLM stronger: LLM's
    direction carries the merge."""
    quant_sig = _make_signal(direction=1, confidence=0.2, source="quant")
    llm_sig = _make_signal(direction=-1, confidence=0.8, source="llm")

    merged = merger.merge(quant_sig, llm_sig)

    assert merged.direction == -1
|
# weighted_sum=0 → direction=HOLD
|
||||||
|
def test_merge_weighted_sum_zero(merger):
    """Exact tie between opposing signals: weighted sum cancels to zero,
    so the merger must emit HOLD (direction 0) with zero confidence."""
    quant_sig = _make_signal(direction=1, confidence=0.5, source="quant")
    llm_sig = _make_signal(direction=-1, confidence=0.5, source="llm")

    merged = merger.merge(quant_sig, llm_sig)

    assert merged.direction == 0
    assert math.isclose(merged.confidence, 0.0)
@ -4,6 +4,7 @@ FastAPI REST API + WebSocket for real-time analysis progress
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import fcntl
|
import fcntl
|
||||||
|
import hmac
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
def _check_api_key(api_key: Optional[str]) -> bool:
    """Return True when *api_key* matches the configured dashboard key.

    If no key is configured, authentication is disabled and every request
    is allowed. A missing/empty key from the client fails closed. The
    comparison uses hmac.compare_digest so it runs in constant time and
    does not leak key contents through timing differences (unlike ``==``).
    """
    required = _get_api_key()
    if not required:
        # No key configured: auth is intentionally disabled.
        return True
    if not api_key:
        return False
    return hmac.compare_digest(api_key, required)
||||||
def _auth_error():
    """Raise a 401 Unauthorized HTTPException for failed API-key checks."""
    raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required")
|
@ -363,7 +366,7 @@ async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Head
|
||||||
# Use clean environment - don't inherit parent env
|
# Use clean environment - don't inherit parent env
|
||||||
clean_env = {k: v for k, v in os.environ.items()
|
clean_env = {k: v for k, v in os.environ.items()
|
||||||
if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))}
|
if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))}
|
||||||
clean_env["ANTHROPIC_API_KEY"] = api_key
|
clean_env["ANTHROPIC_API_KEY"] = anthropic_key
|
||||||
clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
|
clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
|
||||||
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
|
@ -1100,6 +1103,45 @@ async def root():
|
||||||
return {"message": "TradingAgents Web Dashboard API", "version": "0.1.0"}
|
return {"message": "TradingAgents Web Dashboard API", "version": "0.1.0"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/ws/orchestrator")
|
||||||
|
async def ws_orchestrator(websocket: WebSocket, api_key: Optional[str] = None):
|
||||||
|
"""WebSocket endpoint for orchestrator live signals."""
|
||||||
|
# Auth check before accepting — reject unauthenticated connections
|
||||||
|
if not _check_api_key(api_key):
|
||||||
|
await websocket.close(code=4401)
|
||||||
|
return
|
||||||
|
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
from orchestrator.orchestrator import TradingOrchestrator
|
||||||
|
from orchestrator.live_mode import LiveMode
|
||||||
|
|
||||||
|
config = OrchestratorConfig(
|
||||||
|
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
|
||||||
|
)
|
||||||
|
orchestrator = TradingOrchestrator(config)
|
||||||
|
live = LiveMode(orchestrator)
|
||||||
|
|
||||||
|
await websocket.accept()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
data = await websocket.receive_text()
|
||||||
|
payload = json.loads(data)
|
||||||
|
tickers = payload.get("tickers", [])
|
||||||
|
date = payload.get("date")
|
||||||
|
|
||||||
|
results = await live.run_once(tickers, date)
|
||||||
|
await websocket.send_text(json.dumps({"signals": results}))
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
try:
|
||||||
|
await websocket.send_text(json.dumps({"error": str(e)}))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
# Run with: cd web_dashboard && ../env312/bin/python -m uvicorn main:app --reload
|
# Run with: cd web_dashboard && ../env312/bin/python -m uvicorn main:app --reload
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue