fix(quant_runner): fix 3 critical issues and 2 important improvements
- Critical 1: initialize orders=[] before loop to prevent NameError when df is empty - Critical 2: replace bare sqlite3 conn with context manager (with statement) in get_signal - Critical 3: remove ticker param from _load_best_params (table has no ticker col, params are global) - Important: extract db_path as self._db_path attribute in __init__ (DRY) - Important: add comment explaining lazy imports require sys.path set in __init__
This commit is contained in:
parent
7a03c29330
commit
30d8f90467
|
|
@ -21,6 +21,7 @@ class QuantRunner:
|
||||||
path = config.quant_backtest_path
|
path = config.quant_backtest_path
|
||||||
if path not in sys.path:
|
if path not in sys.path:
|
||||||
sys.path.insert(0, path)
|
sys.path.insert(0, path)
|
||||||
|
self._db_path = f"{path}/research_results/runs.db"
|
||||||
|
|
||||||
def get_signal(self, ticker: str, date: str) -> Signal:
|
def get_signal(self, ticker: str, date: str) -> Signal:
|
||||||
"""
|
"""
|
||||||
|
|
@ -28,7 +29,7 @@ class QuantRunner:
|
||||||
date 格式:'YYYY-MM-DD'
|
date 格式:'YYYY-MM-DD'
|
||||||
返回 Signal(source="quant")
|
返回 Signal(source="quant")
|
||||||
"""
|
"""
|
||||||
result = self._load_best_params(ticker)
|
result = self._load_best_params()
|
||||||
params: dict = result["params"]
|
params: dict = result["params"]
|
||||||
sharpe: float = result["sharpe_ratio"]
|
sharpe: float = result["sharpe_ratio"]
|
||||||
|
|
||||||
|
|
@ -53,6 +54,7 @@ class QuantRunner:
|
||||||
df.columns = [c[0].lower() if isinstance(c, tuple) else c.lower() for c in df.columns]
|
df.columns = [c[0].lower() if isinstance(c, tuple) else c.lower() for c in df.columns]
|
||||||
|
|
||||||
# 用最佳参数创建 BollingerStrategy 实例
|
# 用最佳参数创建 BollingerStrategy 实例
|
||||||
|
# Lazy import: requires quant_backtest_path to be in sys.path (set in __init__)
|
||||||
from strategies.momentum import BollingerStrategy
|
from strategies.momentum import BollingerStrategy
|
||||||
from core.data_models import Bar, OrderDirection
|
from core.data_models import Bar, OrderDirection
|
||||||
|
|
||||||
|
|
@ -66,6 +68,7 @@ class QuantRunner:
|
||||||
|
|
||||||
# 逐 bar 喂给策略,模拟历史回放
|
# 逐 bar 喂给策略,模拟历史回放
|
||||||
direction = 0
|
direction = 0
|
||||||
|
orders: list = []
|
||||||
context: dict[str, Any] = {"positions": {}}
|
context: dict[str, Any] = {"positions": {}}
|
||||||
|
|
||||||
for ts, row in df.iterrows():
|
for ts, row in df.iterrows():
|
||||||
|
|
@ -97,14 +100,12 @@ class QuantRunner:
|
||||||
break
|
break
|
||||||
|
|
||||||
# 计算 max_sharpe(从 DB 中取全局最大值)
|
# 计算 max_sharpe(从 DB 中取全局最大值)
|
||||||
db_path = f"{self._config.quant_backtest_path}/research_results/runs.db"
|
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
with sqlite3.connect(self._db_path) as conn:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute("SELECT MAX(sharpe_ratio) FROM backtest_results")
|
cur.execute("SELECT MAX(sharpe_ratio) FROM backtest_results")
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
max_sharpe = float(row[0]) if row and row[0] is not None else sharpe
|
max_sharpe = float(row[0]) if row and row[0] is not None else sharpe
|
||||||
conn.close()
|
|
||||||
except Exception:
|
except Exception:
|
||||||
max_sharpe = sharpe
|
max_sharpe = sharpe
|
||||||
|
|
||||||
|
|
@ -119,14 +120,13 @@ class QuantRunner:
|
||||||
metadata={"params": params, "sharpe_ratio": sharpe, "max_sharpe": max_sharpe},
|
metadata={"params": params, "sharpe_ratio": sharpe, "max_sharpe": max_sharpe},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_best_params(self, ticker: str) -> dict:
|
def _load_best_params(self) -> dict:
|
||||||
"""
|
"""
|
||||||
直接查 SQLite 获取 BollingerStrategy 最佳参数。
|
直接查 SQLite 获取 BollingerStrategy 最佳参数。
|
||||||
|
参数是全局最优,不区分股票(backtest_results 表无 ticker 列,优化是全局的)。
|
||||||
strategy_type 支持 'BollingerStrategy' 和 'bollinger'(兼容两种写法)。
|
strategy_type 支持 'BollingerStrategy' 和 'bollinger'(兼容两种写法)。
|
||||||
"""
|
"""
|
||||||
db_path = f"{self._config.quant_backtest_path}/research_results/runs.db"
|
with sqlite3.connect(self._db_path) as conn:
|
||||||
conn = sqlite3.connect(db_path)
|
|
||||||
try:
|
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
# 先按规格查 'BollingerStrategy',再 fallback 到 'bollinger'
|
# 先按规格查 'BollingerStrategy',再 fallback 到 'bollinger'
|
||||||
cur.execute(
|
cur.execute(
|
||||||
|
|
@ -139,8 +139,6 @@ class QuantRunner:
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if row is None:
|
if row is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue