TradingAgents/orchestrator/tests/test_quant_runner.py

202 lines
6.2 KiB
Python

"""Tests for QuantRunner._calc_confidence()."""
import json
import sqlite3
import pandas as pd
import pytest
from orchestrator.config import OrchestratorConfig
from orchestrator.contracts.error_taxonomy import ReasonCode
from orchestrator.quant_runner import QuantRunner
def _make_runner(tmp_path):
    """Create a QuantRunner with a minimal SQLite DB so __init__ succeeds.

    Writes ``<tmp_path>/research_results/runs.db`` containing a
    ``backtest_results`` table with a single ``BollingerStrategy`` row
    (Sharpe 1.5, JSON-encoded params), then builds the runner from an
    ``OrchestratorConfig`` whose ``quant_backtest_path`` is ``tmp_path``.

    Args:
        tmp_path: pytest's per-test temporary directory (a ``pathlib.Path``).

    Returns:
        A ``QuantRunner`` constructed against the freshly written DB.
    """
    db_dir = tmp_path / "research_results"
    db_dir.mkdir(parents=True)
    db_path = db_dir / "runs.db"
    strategy_params = {
        "period": 20,
        "num_std": 2.0,
        "position_pct": 0.2,
        "stop_loss_pct": 0.05,
        "take_profit_pct": 0.15,
    }
    # A Connection used as a context manager only commits/rolls back the
    # transaction; it does NOT close the handle. Close it explicitly so no
    # handle to runs.db is leaked before QuantRunner opens the same file.
    conn = sqlite3.connect(str(db_path))
    try:
        with conn:  # commit on success, roll back on error
            conn.execute(
                """CREATE TABLE backtest_results (
                    id INTEGER PRIMARY KEY,
                    strategy_type TEXT,
                    params TEXT,
                    sharpe_ratio REAL
                )"""
            )
            conn.execute(
                "INSERT INTO backtest_results (strategy_type, params, sharpe_ratio) VALUES (?, ?, ?)",
                ("BollingerStrategy", json.dumps(strategy_params), 1.5),
            )
    finally:
        conn.close()
    cfg = OrchestratorConfig(quant_backtest_path=str(tmp_path))
    return QuantRunner(cfg)
@pytest.fixture
def runner(tmp_path):
    """Provide a QuantRunner backed by a throwaway per-test SQLite results DB."""
    quant_runner = _make_runner(tmp_path)
    return quant_runner
def test_calc_confidence_max_sharpe_zero(runner):
    """A zero max Sharpe (second argument) yields the neutral confidence 0.5."""
    confidence = runner._calc_confidence(1.0, 0)
    assert confidence == 0.5
def test_calc_confidence_half(runner):
    """A Sharpe of half the max maps to a confidence of 0.5."""
    assert runner._calc_confidence(1.0, 2.0) == pytest.approx(0.5)
def test_calc_confidence_full(runner):
    """A Sharpe equal to the max maps to full confidence (1.0)."""
    assert runner._calc_confidence(2.0, 2.0) == pytest.approx(1.0)
def test_calc_confidence_clamped_above(runner):
    """A Sharpe above the max is clamped to a confidence of 1.0."""
    assert runner._calc_confidence(3.0, 2.0) == pytest.approx(1.0)
def test_calc_confidence_clamped_below(runner):
    """A negative Sharpe is clamped to a confidence of 0.0."""
    assert runner._calc_confidence(-1.0, 2.0) == pytest.approx(0.0)
def test_get_signal_returns_reason_code_when_no_data(runner, monkeypatch):
    """An empty download on a regular day degrades with QUANT_NO_DATA."""
    def _empty_download(*_args, **_kwargs):
        # Duck-typed stand-in: only exposes ``.empty = True``.
        return type("EmptyFrame", (), {"empty": True})()

    monkeypatch.setattr("orchestrator.quant_runner.yf.download", _empty_download)

    signal = runner.get_signal("AAPL", "2024-01-02")

    assert signal.degraded is True
    assert signal.reason_code == ReasonCode.QUANT_NO_DATA.value
def test_get_signal_marks_non_trading_day_on_a_share_holiday(runner, monkeypatch):
    """A Shanghai (.SS) ticker on an A-share holiday is flagged non_trading_day."""
    def _no_rows(*_args, **_kwargs):
        return pd.DataFrame()

    monkeypatch.setattr("orchestrator.quant_runner.yf.download", _no_rows)

    signal = runner.get_signal("600519.SS", "2024-10-02")
    quality = signal.metadata["data_quality"]

    assert signal.degraded is True
    assert signal.reason_code == ReasonCode.NON_TRADING_DAY.value
    assert quality["state"] == "non_trading_day"
def test_get_signal_marks_non_trading_day_on_market_holiday(runner, monkeypatch):
    """A US ticker on a market holiday (2024-03-29) is flagged non_trading_day."""
    def _no_rows(*_args, **_kwargs):
        return pd.DataFrame()

    monkeypatch.setattr("orchestrator.quant_runner.yf.download", _no_rows)

    signal = runner.get_signal("AAPL", "2024-03-29")
    quality = signal.metadata["data_quality"]

    assert signal.degraded is True
    assert signal.reason_code == ReasonCode.NON_TRADING_DAY.value
    assert quality["state"] == "non_trading_day"
def test_get_signal_marks_non_trading_day_on_weekend(runner, monkeypatch):
    """A Saturday request (2024-01-06) is flagged non_trading_day, not missing data."""
    def _no_rows(*_args, **_kwargs):
        return pd.DataFrame()

    monkeypatch.setattr("orchestrator.quant_runner.yf.download", _no_rows)

    saturday = "2024-01-06"
    signal = runner.get_signal("AAPL", saturday)
    quality = signal.metadata["data_quality"]

    assert signal.degraded is True
    assert signal.reason_code == ReasonCode.NON_TRADING_DAY.value
    assert quality["state"] == "non_trading_day"
# Renamed from a duplicate ``test_get_signal_marks_non_trading_day_on_market_holiday``:
# reusing that name rebound the module attribute, so pytest silently dropped the
# earlier (empty-download) holiday test.
def test_get_signal_marks_non_trading_day_on_holiday_with_prior_bar(runner, monkeypatch):
    """A holiday request with an earlier bar available is still non_trading_day.

    The download returns only the 2024-07-03 bar for a 2024-07-04 request.
    The signal must be classified as a non-trading day (not stale data) and
    must report the last available bar's date.
    """
    holiday_frame = pd.DataFrame(
        {
            "Open": [10.0],
            "High": [11.0],
            "Low": [9.0],
            "Close": [10.5],
            "Volume": [1000],
        },
        index=pd.to_datetime(["2024-07-03"]),
    )
    monkeypatch.setattr(
        "orchestrator.quant_runner.yf.download",
        lambda *args, **kwargs: holiday_frame,
    )

    signal = runner.get_signal("AAPL", "2024-07-04")

    assert signal.degraded is True
    assert signal.reason_code == ReasonCode.NON_TRADING_DAY.value
    assert signal.metadata["data_quality"]["state"] == "non_trading_day"
    assert signal.metadata["data_quality"]["last_available_date"] == "2024-07-03"
def test_get_signal_marks_stale_data_when_requested_day_missing(runner, monkeypatch):
    """If only an earlier bar is returned for a trading day, the signal is stale_data."""
    ohlcv = {
        "Open": [10.0],
        "High": [11.0],
        "Low": [9.0],
        "Close": [10.5],
        "Volume": [1000],
    }
    last_bar_index = pd.to_datetime(["2024-01-01"])
    stale_frame = pd.DataFrame(ohlcv, index=last_bar_index)

    def _stale_download(*_args, **_kwargs):
        return stale_frame

    monkeypatch.setattr("orchestrator.quant_runner.yf.download", _stale_download)

    signal = runner.get_signal("AAPL", "2024-01-02")
    quality = signal.metadata["data_quality"]

    assert signal.degraded is True
    assert signal.reason_code == ReasonCode.STALE_DATA.value
    assert quality["state"] == "stale_data"
def test_get_signal_marks_partial_data_when_required_columns_missing(runner, monkeypatch):
    """A bar for the requested day that lacks the High column is partial_data."""
    columns_without_high = {
        "Open": [10.0],
        "Low": [9.0],
        "Close": [10.5],
        "Volume": [1000],
    }
    partial_frame = pd.DataFrame(
        columns_without_high,
        index=pd.to_datetime(["2024-01-02"]),
    )

    def _partial_download(*_args, **_kwargs):
        return partial_frame

    monkeypatch.setattr("orchestrator.quant_runner.yf.download", _partial_download)

    signal = runner.get_signal("AAPL", "2024-01-02")
    quality = signal.metadata["data_quality"]

    assert signal.degraded is True
    assert signal.reason_code == ReasonCode.PARTIAL_DATA.value
    assert quality["state"] == "partial_data"
def test_get_signal_uses_yf_retry_wrapper(runner, monkeypatch):
    """get_signal fetches through yf_retry with max_retries=3, base_delay=2.0."""
    recorded = []

    def fake_retry(func, max_retries=3, base_delay=2.0):
        # Capture the retry settings seen by the wrapper (the defaults here
        # if get_signal passes none) and return an empty frame.
        recorded.append((max_retries, base_delay))
        return pd.DataFrame()

    def _always_trading_day(*_args, **_kwargs):
        return False

    monkeypatch.setattr("orchestrator.quant_runner.yf_retry", fake_retry)
    # Pin the calendar check off so the empty frame is judged as missing data.
    monkeypatch.setattr("orchestrator.quant_runner.is_non_trading_day", _always_trading_day)

    signal = runner.get_signal("AAPL", "2024-01-02")

    assert recorded == [(3, 2.0)]
    assert signal.reason_code == ReasonCode.QUANT_NO_DATA.value