TradingAgents/orchestrator/quant_runner.py

268 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import sqlite3
import sys
from contextlib import closing
from datetime import datetime, timezone, timedelta
from typing import Any

import pandas as pd
import yfinance as yf

from orchestrator.config import OrchestratorConfig
from orchestrator.contracts.error_taxonomy import ReasonCode
from orchestrator.contracts.result_contract import Signal, build_error_signal
from orchestrator.market_calendar import is_non_trading_day
from tradingagents.dataflows.stockstats_utils import yf_retry
logger = logging.getLogger(__name__)
def _build_data_quality(state: str, **details: Any) -> dict[str, Any]:
    """Assemble a data-quality report dictionary.

    Args:
        state: Value stored under the ``"state"`` key.
        **details: Extra fields to attach; entries whose value is ``None``
            are left out of the result.

    Returns:
        A new dict with ``"state"`` first, followed by the non-``None``
        details in the order they were passed.
    """
    report: dict[str, Any] = {"state": state}
    for field, value in details.items():
        # Only None is dropped; falsy values such as 0 or "" are kept.
        if value is not None:
            report[field] = value
    return report
class QuantRunner:
    """Produce "quant" signals by replaying a Bollinger strategy over recent prices.

    The best strategy parameters and their Sharpe ratio are read from the
    SQLite results database located under ``config.quant_backtest_path``.
    """

    def __init__(self, config: OrchestratorConfig):
        """Store the configuration and make the backtest package importable.

        Args:
            config: Orchestrator configuration; ``quant_backtest_path`` must
                be set to a non-empty value.

        Raises:
            ValueError: If ``config.quant_backtest_path`` is empty or unset.
        """
        if not config.quant_backtest_path:
            raise ValueError("OrchestratorConfig.quant_backtest_path must be set")
        self._config = config
        path = config.quant_backtest_path
        # The strategy modules are imported lazily in _replay_direction and
        # are resolved through sys.path, so the backtest root must be on it.
        if path not in sys.path:
            sys.path.insert(0, path)
        self._db_path = f"{path}/research_results/runs.db"

    def get_signal(self, ticker: str, date: str) -> Signal:
        """Return the quant signal for ``ticker`` on ``date``.

        Args:
            ticker: Symbol passed to the price download and to the strategy.
            date: Requested date in ``'YYYY-MM-DD'`` format.

        Returns:
            A ``Signal`` with ``source="quant"``. ``direction`` is 1 (buy),
            -1 (sell) or 0 (no order on the last bar). When price data is
            missing, incomplete, stale, or the date is not a trading day, an
            error signal built by ``build_error_signal`` is returned instead.

        Raises:
            ValueError: If ``date`` does not match ``'%Y-%m-%d'`` or no
                Bollinger results exist in the results database.
        """
        best = self._load_best_params()
        params: dict = best["params"]
        sharpe: float = best["sharpe_ratio"]

        # Fetch the 60 calendar days of history leading up to ``date``.
        end_dt = datetime.strptime(date, "%Y-%m-%d")
        start_str = (end_dt - timedelta(days=60)).strftime("%Y-%m-%d")
        # One day past the requested date, so the requested day itself is
        # presumably inside the returned window (end treated as exclusive).
        end_exclusive = (end_dt + timedelta(days=1)).strftime("%Y-%m-%d")
        df = yf_retry(
            lambda: yf.download(
                ticker,
                start=start_str,
                end=end_exclusive,
                progress=False,
                auto_adjust=True,
            )
        )

        if df is None or df.empty:
            logger.warning("No price data for %s between %s and %s", ticker, start_str, date)
            if is_non_trading_day(ticker, end_dt.date()):
                return self._error_signal(
                    ticker,
                    ReasonCode.NON_TRADING_DAY,
                    f"{date} is not a trading day",
                    start_str,
                    date,
                    _build_data_quality("non_trading_day", requested_date=date),
                )
            return self._error_signal(
                ticker,
                ReasonCode.QUANT_NO_DATA,
                f"no price data between {start_str} and {date}",
                start_str,
                date,
            )

        # Normalise column names to lowercase; tuple (multi-level) column
        # labels are reduced to their first element.
        df.columns = [c[0].lower() if isinstance(c, tuple) else c.lower() for c in df.columns]
        missing_columns = sorted({"open", "high", "low", "close"} - set(df.columns))
        if missing_columns:
            return self._error_signal(
                ticker,
                ReasonCode.PARTIAL_DATA,
                f"missing price columns: {', '.join(missing_columns)}",
                start_str,
                date,
                _build_data_quality("partial_data", missing_fields=missing_columns),
            )

        # The requested date itself must be present, otherwise the replay
        # would report a signal for an earlier bar.
        df.index = pd.to_datetime(df.index)
        requested_ts = pd.Timestamp(end_dt.date())
        if requested_ts not in df.index.normalize():
            last_ts = df.index.max()
            last_available_date = (
                last_ts.strftime("%Y-%m-%d") if hasattr(last_ts, "strftime") else str(last_ts)
            )
            if is_non_trading_day(ticker, end_dt.date()):
                return self._error_signal(
                    ticker,
                    ReasonCode.NON_TRADING_DAY,
                    f"{date} is not a trading day",
                    start_str,
                    date,
                    _build_data_quality(
                        "non_trading_day",
                        requested_date=date,
                        last_available_date=last_available_date,
                    ),
                )
            return self._error_signal(
                ticker,
                ReasonCode.STALE_DATA,
                f"latest price data stops at {last_available_date}",
                start_str,
                date,
                _build_data_quality(
                    "stale_data",
                    requested_date=date,
                    last_available_date=last_available_date,
                ),
            )

        direction = self._replay_direction(df, ticker, params)
        max_sharpe = self._query_max_sharpe(sharpe)
        confidence = self._calc_confidence(sharpe, max_sharpe)
        return Signal(
            ticker=ticker,
            direction=direction,
            confidence=confidence,
            source="quant",
            timestamp=datetime.now(timezone.utc),
            metadata={
                "params": params,
                "sharpe_ratio": sharpe,
                "max_sharpe": max_sharpe,
                "data_quality": _build_data_quality(
                    "ok",
                    requested_date=date,
                    last_available_date=date,
                ),
            },
        )

    def _error_signal(
        self,
        ticker: str,
        reason: ReasonCode,
        message: str,
        start_str: str,
        date: str,
        data_quality: dict[str, Any] | None = None,
    ) -> Signal:
        """Build a quant error signal carrying the requested date window."""
        metadata: dict[str, Any] = {"start_date": start_str, "end_date": date}
        if data_quality is not None:
            metadata["data_quality"] = data_quality
        return build_error_signal(
            ticker=ticker,
            source="quant",
            reason_code=reason.value,
            message=message,
            metadata=metadata,
        )

    def _replay_direction(self, df: pd.DataFrame, ticker: str, params: dict) -> int:
        """Replay ``df`` bar by bar and return the direction of the last bar's orders.

        Returns 1 if the first BUY/SELL order emitted on the final bar is a
        BUY, -1 if it is a SELL, and 0 if the final bar produced neither.
        """
        # Lazy import: requires quant_backtest_path to be on sys.path (set in __init__).
        from strategies.momentum import BollingerStrategy
        from core.data_models import Bar, OrderDirection

        strategy = BollingerStrategy(
            period=params.get("period", 20),
            num_std=params.get("num_std", 2.0),
            position_pct=params.get("position_pct", 0.20),
            stop_loss_pct=params.get("stop_loss_pct", 0.05),
            take_profit_pct=params.get("take_profit_pct", 0.15),
        )

        context: dict[str, Any] = {"positions": {}}
        orders: list = []
        for ts, row in df.iterrows():
            bar = Bar(
                symbol=ticker,
                timestamp=ts.to_pydatetime() if hasattr(ts, "to_pydatetime") else ts,
                open=float(row["open"]),
                high=float(row["high"]),
                low=float(row["low"]),
                close=float(row["close"]),
                volume=float(row.get("volume", 0)),
            )
            orders = strategy.on_bar(bar, context)
            # Keep the simulated position in step with the emitted orders so
            # the strategy sees its own holdings on the next bar.
            for order in orders:
                if order.direction == OrderDirection.BUY:
                    context["positions"][ticker] = order.volume
                elif order.direction == OrderDirection.SELL:
                    context["positions"][ticker] = 0

        # ``orders`` now holds whatever the final bar produced.
        for order in orders:
            if order.direction == OrderDirection.BUY:
                return 1
            if order.direction == OrderDirection.SELL:
                return -1
        return 0

    def _query_max_sharpe(self, fallback: float) -> float:
        """Return the largest ``sharpe_ratio`` in ``backtest_results``.

        This is a best-effort lookup: if the query fails or yields no value,
        ``fallback`` is returned instead of raising.
        """
        try:
            # closing() is required: sqlite3's own context manager only
            # commits/rolls back the transaction and leaves the connection open.
            with closing(sqlite3.connect(self._db_path)) as conn:
                row = conn.execute(
                    "SELECT MAX(sharpe_ratio) FROM backtest_results"
                ).fetchone()
            return float(row[0]) if row and row[0] is not None else fallback
        except (sqlite3.Error, TypeError, ValueError):
            logger.warning(
                "Could not read max sharpe_ratio from %s; using %s",
                self._db_path,
                fallback,
                exc_info=True,
            )
            return fallback

    def _load_best_params(self) -> dict:
        """Load the best Bollinger parameters directly from SQLite.

        The query selects the row with the highest ``sharpe_ratio`` among
        rows whose ``strategy_type`` is ``'BollingerStrategy'`` or
        ``'bollinger'`` (both spellings are accepted). The query does not
        filter by ticker, so the result is global rather than per-symbol.

        Returns:
            ``{"params": <dict>, "sharpe_ratio": <float>}``.

        Raises:
            ValueError: If no matching row exists.
        """
        with closing(sqlite3.connect(self._db_path)) as conn:
            row = conn.execute(
                """
                SELECT params, sharpe_ratio
                FROM backtest_results
                WHERE strategy_type IN ('BollingerStrategy', 'bollinger')
                ORDER BY sharpe_ratio DESC
                LIMIT 1
                """
            ).fetchone()
        if row is None:
            raise ValueError(
                "No BollingerStrategy results found in ResultStore. "
                "Run optimization first: python quant_backtest/run_research.py"
            )
        # ``params`` may be stored as JSON text; decode it when it is a string.
        params = json.loads(row[0]) if isinstance(row[0], str) else row[0]
        return {"params": params, "sharpe_ratio": float(row[1])}

    def _calc_confidence(self, sharpe: float, max_sharpe: float) -> float:
        """Normalise a Sharpe ratio into a confidence in ``[0.0, 1.0]``.

        - ``max_sharpe == 0`` returns 0.5 (default that avoids dividing by zero).
        - ``max_sharpe < 0`` returns 0.0: dividing two negative numbers would
          otherwise give a positive ratio and report high confidence for a
          losing strategy.
        - Otherwise ``sharpe / max_sharpe`` is clamped to ``[0.0, 1.0]``, so a
          negative Sharpe never produces a negative confidence.
        """
        if max_sharpe == 0:
            return 0.5
        if max_sharpe < 0:
            return 0.0
        return max(0.0, min(1.0, sharpe / max_sharpe))