# Source: TradingAgents/tests/unit/backtest/test_backtest_engine.py
# (965 lines, 31 KiB, Python)

"""Tests for Backtest Engine.
Issue #42: [BT-41] Backtest engine - historical replay, slippage
"""
from datetime import datetime, timedelta
from decimal import Decimal
import pytest
from tradingagents.backtest import (
# Enums
OrderSide,
OrderType,
FillStatus,
# Data Classes
OHLCV,
Signal,
BacktestConfig,
BacktestPosition,
BacktestTrade,
BacktestSnapshot,
BacktestResult,
# Slippage Models
SlippageModel,
NoSlippage,
FixedSlippage,
PercentageSlippage,
VolumeSlippage,
# Commission Models
CommissionModel,
NoCommission,
FixedCommission,
PerShareCommission,
PercentageCommission,
TieredCommission,
# Main Classes
BacktestEngine,
# Factory Functions
create_backtest_engine,
)
# Shared Decimal zero used as the comparison baseline in the assertions below.
ZERO = Decimal("0")
# ============================================================================
# Enum Tests
# ============================================================================
class TestOrderSide:
    """Checks the string values behind the OrderSide enum members."""

    def test_values(self):
        """BUY and SELL carry their lowercase string values."""
        buy_value = OrderSide.BUY.value
        assert buy_value == "buy"

        sell_value = OrderSide.SELL.value
        assert sell_value == "sell"
class TestOrderType:
    """Checks the string values behind the OrderType enum members."""

    def test_values(self):
        """Every order type carries its expected string value."""
        market = OrderType.MARKET.value
        assert market == "market"

        limit = OrderType.LIMIT.value
        assert limit == "limit"

        stop = OrderType.STOP.value
        assert stop == "stop"

        stop_limit = OrderType.STOP_LIMIT.value
        assert stop_limit == "stop_limit"
class TestFillStatus:
    """Checks the string values behind the FillStatus enum members."""

    def test_values(self):
        """Every fill status carries its expected string value."""
        unfilled = FillStatus.UNFILLED.value
        assert unfilled == "unfilled"

        filled = FillStatus.FILLED.value
        assert filled == "filled"

        partial = FillStatus.PARTIAL.value
        assert partial == "partial"
# ============================================================================
# Data Class Tests
# ============================================================================
class TestOHLCV:
    """Covers construction of OHLCV bars."""

    def test_creation(self):
        """A bar built from Decimal fields exposes those fields unchanged."""
        bar_fields = {
            "timestamp": datetime(2023, 1, 3),
            "open": Decimal("100"),
            "high": Decimal("105"),
            "low": Decimal("99"),
            "close": Decimal("103"),
            "volume": Decimal("1000000"),
            "symbol": "AAPL",
        }
        candle = OHLCV(**bar_fields)

        assert candle.open == Decimal("100")
        assert candle.close == Decimal("103")
        assert candle.symbol == "AAPL"

    def test_numeric_conversion(self):
        """Plain int inputs come back from the bar as Decimal."""
        int_fields = {
            "timestamp": datetime(2023, 1, 3),
            "open": 100,
            "high": 105,
            "low": 99,
            "close": 103,
            "volume": 1000000,
        }
        candle = OHLCV(**int_fields)

        opening = candle.open
        assert isinstance(opening, Decimal)
        assert candle.open == Decimal("100")
class TestSignal:
    """Covers construction of Signal objects and their default fields."""

    def test_creation(self):
        """A fully specified signal exposes symbol, side and quantity."""
        when = datetime(2023, 1, 3)
        order = Signal(
            timestamp=when,
            symbol="AAPL",
            side=OrderSide.BUY,
            quantity=Decimal("100"),
        )

        assert order.symbol == "AAPL"
        assert order.side == OrderSide.BUY
        assert order.quantity == Decimal("100")

    def test_defaults(self):
        """Omitted fields default to zero quantity, MARKET type, confidence 1."""
        when = datetime(2023, 1, 3)
        order = Signal(timestamp=when, symbol="AAPL", side=OrderSide.BUY)

        assert order.quantity == ZERO
        assert order.order_type == OrderType.MARKET
        assert order.confidence == Decimal("1")
class TestBacktestConfig:
    """Tests for BacktestConfig dataclass."""

    def test_defaults(self):
        """A config built with no arguments exposes the documented defaults."""
        config = BacktestConfig()

        assert config.initial_capital == Decimal("100000")
        assert config.allow_shorting is False
        assert config.max_position_pct == Decimal("20")

    def test_custom_config(self):
        """Every explicitly supplied field is readable back from the config."""
        config = BacktestConfig(
            initial_capital=Decimal("50000"),
            allow_shorting=True,
            max_position_pct=Decimal("10"),
        )

        assert config.initial_capital == Decimal("50000")
        assert config.allow_shorting is True
        # Previously the overridden position cap was set but never verified,
        # so a config that silently ignored it would still have passed.
        assert config.max_position_pct == Decimal("10")
class TestBacktestPosition:
    """Covers BacktestPosition construction and its derived properties."""

    def test_creation(self):
        """Symbol and quantity are readable after construction."""
        holding = BacktestPosition(
            symbol="AAPL",
            quantity=Decimal("100"),
            average_cost=Decimal("150"),
            current_price=Decimal("155"),
        )

        assert holding.symbol == "AAPL"
        assert holding.quantity == Decimal("100")

    def test_market_value(self):
        """100 shares at a current price of 160 report a market value of 16000."""
        holding = BacktestPosition(
            symbol="AAPL",
            quantity=Decimal("100"),
            average_cost=Decimal("150"),
            current_price=Decimal("160"),
        )

        expected_value = Decimal("16000")
        assert holding.market_value == expected_value

    def test_cost_basis(self):
        """100 shares at an average cost of 150 report a cost basis of 15000."""
        holding = BacktestPosition(
            symbol="AAPL",
            quantity=Decimal("100"),
            average_cost=Decimal("150"),
        )

        expected_basis = Decimal("15000")
        assert holding.cost_basis == expected_basis

    def test_is_long(self):
        """A positive quantity is long and not short."""
        long_qty = Decimal("100")
        holding = BacktestPosition(symbol="AAPL", quantity=long_qty)

        assert holding.is_long is True
        assert holding.is_short is False

    def test_is_short(self):
        """A negative quantity is short and not long."""
        short_qty = Decimal("-100")
        holding = BacktestPosition(symbol="AAPL", quantity=short_qty)

        assert holding.is_short is True
        assert holding.is_long is False

    def test_update_price(self):
        """Updating the price refreshes current_price and unrealized_pnl."""
        holding = BacktestPosition(
            symbol="AAPL",
            quantity=Decimal("100"),
            average_cost=Decimal("150"),
            current_price=Decimal("150"),
        )

        new_price = Decimal("160")
        as_of = datetime(2023, 1, 4)
        holding.update_price(new_price, as_of)

        assert holding.current_price == Decimal("160")
        # (160 - 150) * 100 shares
        assert holding.unrealized_pnl == Decimal("1000")
# ============================================================================
# Slippage Model Tests
# ============================================================================
class TestNoSlippage:
    """Covers the NoSlippage model."""

    def test_calculate(self):
        """NoSlippage reports zero slippage for a buy order."""
        order = {
            "price": Decimal("100"),
            "quantity": Decimal("100"),
            "side": OrderSide.BUY,
            "volume": Decimal("1000000"),
        }
        result = NoSlippage().calculate(**order)

        assert result == ZERO
class TestFixedSlippage:
    """Covers the FixedSlippage model."""

    def test_calculate(self):
        """A model configured with 0.01 reports 0.01 of slippage."""
        order = {
            "price": Decimal("100"),
            "quantity": Decimal("100"),
            "side": OrderSide.BUY,
            "volume": Decimal("1000000"),
        }
        fixed = FixedSlippage(Decimal("0.01"))
        result = fixed.calculate(**order)

        assert result == Decimal("0.01")
class TestPercentageSlippage:
    """Covers the PercentageSlippage model."""

    def test_calculate(self):
        """A 0.1% model applied to a price of 100 reports 0.1 of slippage."""
        order = {
            "price": Decimal("100"),
            "quantity": Decimal("100"),
            "side": OrderSide.BUY,
            "volume": Decimal("1000000"),
        }
        pct_model = PercentageSlippage(Decimal("0.1"))  # 0.1%
        result = pct_model.calculate(**order)

        # 0.1% of 100
        assert result == Decimal("0.1")
class TestVolumeSlippage:
    """Covers the VolumeSlippage model at low, high and zero bar volume."""

    def test_calculate_low_volume(self):
        """An order that is 1% of bar volume yields positive slippage."""
        impact_model = VolumeSlippage(
            base_percentage=Decimal("0.05"),
            volume_impact=Decimal("0.1"),
        )
        order = {
            "price": Decimal("100"),
            "quantity": Decimal("100"),
            "side": OrderSide.BUY,
            "volume": Decimal("10000"),
        }
        result = impact_model.calculate(**order)

        # 100 / 10000 = 1% participation; only positivity is checked here.
        assert result > ZERO

    def test_calculate_high_volume(self):
        """An order that is 50% of bar volume stays within the configured cap."""
        impact_model = VolumeSlippage(
            base_percentage=Decimal("0.05"),
            volume_impact=Decimal("0.1"),
            max_percentage=Decimal("1.0"),
        )
        order = {
            "price": Decimal("100"),
            "quantity": Decimal("5000"),
            "side": OrderSide.BUY,
            "volume": Decimal("10000"),
        }
        result = impact_model.calculate(**order)

        # 5000 / 10000 = 50% participation; expected to be limited by the 1% max.
        assert result <= Decimal("1.0")

    def test_calculate_no_volume(self):
        """With zero bar volume the result equals the base percentage figure."""
        impact_model = VolumeSlippage(base_percentage=Decimal("0.05"))
        order = {
            "price": Decimal("100"),
            "quantity": Decimal("100"),
            "side": OrderSide.BUY,
            "volume": ZERO,
        }
        result = impact_model.calculate(**order)

        # No volume data, so only the base percentage applies.
        assert result == Decimal("0.05")
# ============================================================================
# Commission Model Tests
# ============================================================================
class TestNoCommission:
    """Covers the NoCommission model."""

    def test_calculate(self):
        """NoCommission charges nothing on a 10000 trade."""
        trade = {
            "price": Decimal("100"),
            "quantity": Decimal("100"),
            "trade_value": Decimal("10000"),
        }
        fee = NoCommission().calculate(**trade)

        assert fee == ZERO
class TestFixedCommission:
    """Covers the FixedCommission model, with and without a minimum."""

    def test_calculate(self):
        """A 9.99 flat fee is charged as-is."""
        trade = {
            "price": Decimal("100"),
            "quantity": Decimal("100"),
            "trade_value": Decimal("10000"),
        }
        flat = FixedCommission(Decimal("9.99"))
        fee = flat.calculate(**trade)

        assert fee == Decimal("9.99")

    def test_minimum(self):
        """A flat fee of 5 with a minimum of 10 is charged as 10."""
        trade = {
            "price": Decimal("100"),
            "quantity": Decimal("100"),
            "trade_value": Decimal("10000"),
        }
        flat = FixedCommission(Decimal("5"), minimum=Decimal("10"))
        fee = flat.calculate(**trade)

        assert fee == Decimal("10")
class TestPerShareCommission:
    """Covers the PerShareCommission model and its minimum/maximum bounds."""

    def test_calculate(self):
        """100 shares at $0.005 per share cost 0.5."""
        per_share = PerShareCommission(Decimal("0.005"))  # $0.005/share
        trade = {
            "price": Decimal("100"),
            "quantity": Decimal("100"),
            "trade_value": Decimal("10000"),
        }
        fee = per_share.calculate(**trade)

        # 100 * 0.005
        assert fee == Decimal("0.5")

    def test_minimum(self):
        """A 10-share order is charged the 1.0 minimum."""
        per_share = PerShareCommission(Decimal("0.005"), minimum=Decimal("1.0"))
        trade = {
            "price": Decimal("100"),
            "quantity": Decimal("10"),  # Only 10 shares
            "trade_value": Decimal("1000"),
        }
        fee = per_share.calculate(**trade)

        assert fee == Decimal("1.0")

    def test_maximum(self):
        """A 10000-share order is charged the 10 maximum."""
        per_share = PerShareCommission(
            Decimal("0.005"),
            minimum=ZERO,
            maximum=Decimal("10"),
        )
        trade = {
            "price": Decimal("100"),
            "quantity": Decimal("10000"),  # Many shares
            "trade_value": Decimal("1000000"),
        }
        fee = per_share.calculate(**trade)

        assert fee == Decimal("10")
class TestPercentageCommission:
    """Covers the PercentageCommission model, with and without a minimum."""

    def test_calculate(self):
        """0.1% of a 10000 trade is charged as 10."""
        pct_fee = PercentageCommission(Decimal("0.1"))  # 0.1%
        trade = {
            "price": Decimal("100"),
            "quantity": Decimal("100"),
            "trade_value": Decimal("10000"),
        }
        fee = pct_fee.calculate(**trade)

        # 0.1% of 10000
        assert fee == Decimal("10")

    def test_minimum(self):
        """A 100 trade at 0.1% is charged the minimum of 5."""
        pct_fee = PercentageCommission(Decimal("0.1"), minimum=Decimal("5"))
        trade = {
            "price": Decimal("100"),
            "quantity": Decimal("1"),
            "trade_value": Decimal("100"),
        }
        fee = pct_fee.calculate(**trade)

        assert fee == Decimal("5")
class TestTieredCommission:
    """Covers TieredCommission across its three configured tiers."""

    @staticmethod
    def _three_tier_model():
        """Build the (threshold, percent) schedule used by every test here."""
        schedule = [
            (Decimal("0"), Decimal("0.2")),
            (Decimal("10000"), Decimal("0.15")),
            (Decimal("50000"), Decimal("0.1")),
        ]
        return TieredCommission(schedule)

    def test_calculate_low_tier(self):
        """A 5000 trade is charged 0.2%, i.e. 10."""
        tiered = self._three_tier_model()
        fee = tiered.calculate(
            price=Decimal("100"),
            quantity=Decimal("50"),
            trade_value=Decimal("5000"),
        )

        # 0.2% of 5000
        assert fee == Decimal("10")

    def test_calculate_mid_tier(self):
        """A 20000 trade is charged 0.15%, i.e. 30."""
        tiered = self._three_tier_model()
        fee = tiered.calculate(
            price=Decimal("100"),
            quantity=Decimal("200"),
            trade_value=Decimal("20000"),
        )

        # 0.15% of 20000
        assert fee == Decimal("30")

    def test_calculate_high_tier(self):
        """A 100000 trade is charged 0.1%, i.e. 100."""
        tiered = self._three_tier_model()
        fee = tiered.calculate(
            price=Decimal("100"),
            quantity=Decimal("1000"),
            trade_value=Decimal("100000"),
        )

        # 0.1% of 100000
        assert fee == Decimal("100")
# ============================================================================
# BacktestEngine Tests
# ============================================================================
class TestBacktestEngine:
    """Exercises BacktestEngine runs against a five-bar AAPL price series."""

    @staticmethod
    def _aapl_order(day, side, quantity):
        """Build an AAPL signal stamped on the given day of January 2023."""
        return Signal(
            timestamp=datetime(2023, 1, day),
            symbol="AAPL",
            side=side,
            quantity=quantity,
        )

    @pytest.fixture
    def config(self):
        """Config with 100000 of starting capital and all other defaults."""
        starting_capital = Decimal("100000")
        return BacktestConfig(
            initial_capital=starting_capital,
        )

    @pytest.fixture
    def engine(self, config):
        """Engine built from the `config` fixture."""
        return BacktestEngine(config)

    @pytest.fixture
    def price_data(self):
        """Five AAPL bars covering 2023-01-03 through 2023-01-09."""
        # (day, open, high, low, close, volume)
        rows = [
            (3, 130, 132, 129, 131, 1000000),
            (4, 131, 135, 130, 134, 1200000),
            (5, 134, 136, 133, 135, 1100000),
            (6, 135, 138, 134, 137, 1300000),
            (9, 137, 140, 136, 139, 1400000),
        ]
        bars = [
            OHLCV(datetime(2023, 1, day), open_, high, low, close, volume, "AAPL")
            for day, open_, high, low, close, volume in rows
        ]
        return {"AAPL": bars}

    def test_initialization(self, engine):
        """A fresh engine holds the initial cash and no positions or trades."""
        assert engine.cash == Decimal("100000")
        assert len(engine.positions) == 0
        assert len(engine.trades) == 0

    def test_reset(self, engine):
        """reset() restores the initial cash and clears positions."""
        engine.cash = Decimal("50000")
        engine.positions["AAPL"] = BacktestPosition(symbol="AAPL")

        engine.reset()

        assert engine.cash == Decimal("100000")
        assert len(engine.positions) == 0

    def test_run_empty(self, engine):
        """Running with no prices and no signals leaves capital untouched."""
        no_prices = {}
        no_signals = []
        outcome = engine.run(no_prices, no_signals)

        assert outcome.total_trades == 0
        assert outcome.final_value == Decimal("100000")

    def test_run_no_signals(self, engine, price_data):
        """Running prices without signals yields one snapshot per bar."""
        outcome = engine.run(price_data, [])

        assert outcome.total_trades == 0
        assert outcome.final_value == Decimal("100000")
        assert len(outcome.snapshots) == 5

    def test_run_buy_signal(self, engine, price_data):
        """A single buy produces one trade and one AAPL position."""
        orders = [self._aapl_order(3, OrderSide.BUY, Decimal("100"))]
        outcome = engine.run(price_data, orders)

        assert outcome.total_trades == 1
        assert len(engine.positions) == 1
        assert "AAPL" in engine.positions

    def test_run_buy_and_sell(self, engine, price_data):
        """A buy followed by a matching sell closes the position at a profit."""
        entry = self._aapl_order(3, OrderSide.BUY, Decimal("100"))
        exit_ = self._aapl_order(6, OrderSide.SELL, Decimal("100"))
        outcome = engine.run(price_data, [entry, exit_])

        assert outcome.total_trades == 2
        # Position closed
        assert len(engine.positions) == 0
        # Should have profit: bought at ~131, sold at ~137
        assert outcome.final_value > Decimal("100000")

    def test_run_with_slippage(self, price_data):
        """With fixed slippage, a buy fills above its base price."""
        slipping_config = BacktestConfig(
            initial_capital=Decimal("100000"),
            slippage_model=FixedSlippage(Decimal("0.10")),
        )
        slipping_engine = BacktestEngine(slipping_config)
        orders = [self._aapl_order(3, OrderSide.BUY, Decimal("100"))]

        outcome = slipping_engine.run(price_data, orders)

        # Check slippage was applied
        first_fill = outcome.trades[0]
        assert first_fill.slippage > ZERO
        # Buy price increased by slippage
        assert first_fill.price > first_fill.base_price

    def test_run_with_commission(self, price_data):
        """A fixed commission of 10 appears on the trade and in the total."""
        paying_config = BacktestConfig(
            initial_capital=Decimal("100000"),
            commission_model=FixedCommission(Decimal("10")),
        )
        paying_engine = BacktestEngine(paying_config)
        orders = [self._aapl_order(3, OrderSide.BUY, Decimal("100"))]

        outcome = paying_engine.run(price_data, orders)

        assert outcome.total_commission == Decimal("10")
        first_fill = outcome.trades[0]
        assert first_fill.commission == Decimal("10")

    def test_run_insufficient_cash(self, price_data):
        """With only 1000 of capital, any fill is for fewer than 100 shares."""
        # Not enough for 100 shares at $131
        small_config = BacktestConfig(
            initial_capital=Decimal("1000"),
        )
        small_engine = BacktestEngine(small_config)
        orders = [self._aapl_order(3, OrderSide.BUY, Decimal("100"))]

        outcome = small_engine.run(price_data, orders)

        # Should have bought fewer shares.
        # If no trades, that's also acceptable (couldn't afford any).
        traded = outcome.total_trades > 0
        if traded:
            assert small_engine.positions["AAPL"].quantity < Decimal("100")

    def test_run_no_shorting(self, engine, price_data):
        """A sell with no position is rejected when shorting is disabled."""
        orders = [self._aapl_order(3, OrderSide.SELL, Decimal("100"))]
        outcome = engine.run(price_data, orders)

        # Sell should be rejected (no position and shorting disabled)
        assert outcome.total_trades == 0

    def test_run_with_shorting(self, price_data):
        """With shorting enabled, a sell opens a -100 share position."""
        short_config = BacktestConfig(
            initial_capital=Decimal("100000"),
            allow_shorting=True,
        )
        short_engine = BacktestEngine(short_config)
        orders = [self._aapl_order(3, OrderSide.SELL, Decimal("100"))]

        outcome = short_engine.run(price_data, orders)

        # Should have short position
        assert outcome.total_trades == 1
        assert short_engine.positions["AAPL"].quantity == Decimal("-100")

    def test_run_position_sizing(self, price_data):
        """A zero-quantity buy under "equal" sizing trades a positive quantity."""
        sizing_config = BacktestConfig(
            initial_capital=Decimal("100000"),
            position_sizing="equal",
        )
        sizing_engine = BacktestEngine(sizing_config)
        # Zero quantity requests auto-sizing.
        orders = [self._aapl_order(3, OrderSide.BUY, ZERO)]

        outcome = sizing_engine.run(price_data, orders)

        # Should have calculated quantity
        traded = outcome.total_trades > 0
        if traded:
            assert outcome.trades[0].quantity > ZERO

    def test_get_position(self, engine, price_data):
        """get_position returns the held position, or None for unknown symbols."""
        orders = [self._aapl_order(3, OrderSide.BUY, Decimal("100"))]
        engine.run(price_data, orders)

        held = engine.get_position("AAPL")
        assert held is not None
        assert held.quantity == Decimal("100")

        missing = engine.get_position("GOOG")
        assert missing is None

    def test_get_cash(self, engine, price_data):
        """get_cash starts at the initial capital and drops after a buy."""
        cash_before = engine.get_cash()
        assert cash_before == Decimal("100000")

        orders = [self._aapl_order(3, OrderSide.BUY, Decimal("100"))]
        engine.run(price_data, orders)

        cash_after = engine.get_cash()
        assert cash_after < cash_before

    def test_get_portfolio_value(self, engine, price_data):
        """Portfolio value after one buy stays within 1000 of initial capital."""
        orders = [self._aapl_order(3, OrderSide.BUY, Decimal("100"))]
        engine.run(price_data, orders)

        total = engine.get_portfolio_value()

        # Should be approximately initial capital (cash + position value)
        assert total > Decimal("99000")
        assert total < Decimal("101000")
class TestBacktestResult:
    """Tests for BacktestResult metrics."""

    @pytest.fixture
    def price_data(self):
        """Create five AAPL bars whose close rises from 100 to 120."""
        return {
            "AAPL": [
                OHLCV(datetime(2023, 1, 3), 100, 102, 99, 100, 1000000, "AAPL"),
                OHLCV(datetime(2023, 1, 4), 100, 105, 99, 105, 1200000, "AAPL"),
                OHLCV(datetime(2023, 1, 5), 105, 110, 104, 110, 1100000, "AAPL"),
                OHLCV(datetime(2023, 1, 6), 110, 115, 109, 115, 1300000, "AAPL"),
                OHLCV(datetime(2023, 1, 9), 115, 120, 114, 120, 1400000, "AAPL"),
            ],
        }

    def test_winning_trade(self, price_data):
        """Buying on the first bar and selling on the last yields a gain."""
        engine = BacktestEngine(BacktestConfig(initial_capital=Decimal("100000")))
        signals = [
            Signal(datetime(2023, 1, 3), "AAPL", OrderSide.BUY, Decimal("100")),
            Signal(datetime(2023, 1, 9), "AAPL", OrderSide.SELL, Decimal("100")),
        ]

        result = engine.run(price_data, signals)

        assert result.total_trades == 2
        assert result.winning_trades >= 1
        assert result.total_return > ZERO
        assert result.final_value > result.initial_capital

    def test_max_drawdown(self, price_data):
        """Max drawdown is reported as a non-negative number."""
        # Add some volatility: an extra intraday bar that closes at 95.
        price_data["AAPL"].insert(2, OHLCV(
            datetime(2023, 1, 4, 12), 105, 106, 95, 95, 1000000, "AAPL"
        ))
        engine = BacktestEngine(BacktestConfig(initial_capital=Decimal("100000")))
        signals = [
            Signal(datetime(2023, 1, 3), "AAPL", OrderSide.BUY, Decimal("100")),
        ]

        result = engine.run(price_data, signals)

        # Should have some drawdown recorded
        assert result.max_drawdown >= ZERO

    def test_snapshots(self, price_data):
        """One snapshot per bar, each with positive value and non-negative cash."""
        engine = BacktestEngine(BacktestConfig(initial_capital=Decimal("100000")))
        signals = [
            Signal(datetime(2023, 1, 3), "AAPL", OrderSide.BUY, Decimal("100")),
        ]

        result = engine.run(price_data, signals)

        assert len(result.snapshots) == 5
        for snapshot in result.snapshots:
            assert snapshot.total_value > ZERO
            assert snapshot.cash >= ZERO

    def test_daily_returns(self, price_data):
        """Five snapshots produce four daily returns."""
        engine = BacktestEngine(BacktestConfig(initial_capital=Decimal("100000")))
        signals = [
            Signal(datetime(2023, 1, 3), "AAPL", OrderSide.BUY, Decimal("100")),
        ]

        result = engine.run(price_data, signals)

        # 5 snapshots = 4 returns
        assert len(result.daily_returns) == 4

    def test_trade_stats(self, price_data):
        """Win/loss counters are non-negative and never exceed total trades."""
        engine = BacktestEngine(BacktestConfig(initial_capital=Decimal("100000")))
        signals = [
            Signal(datetime(2023, 1, 3), "AAPL", OrderSide.BUY, Decimal("50")),
            Signal(datetime(2023, 1, 5), "AAPL", OrderSide.SELL, Decimal("50")),
            Signal(datetime(2023, 1, 5), "AAPL", OrderSide.BUY, Decimal("50")),
            Signal(datetime(2023, 1, 9), "AAPL", OrderSide.SELL, Decimal("50")),
        ]

        result = engine.run(price_data, signals)

        assert result.total_trades == 4
        # The previous check, w + l + (total - w - l) == total, was an
        # algebraic identity and could never fail. These checks constrain
        # the counters themselves.
        assert result.winning_trades >= 0
        assert result.losing_trades >= 0
        assert result.winning_trades + result.losing_trades <= result.total_trades
class TestBacktestEngineIntegration:
    """End-to-end checks: exports, factory, callbacks, symbols, date range."""

    def test_module_imports(self):
        """Every public name can be imported from tradingagents.backtest."""
        from tradingagents.backtest import (
            OrderSide,
            OrderType,
            FillStatus,
            OHLCV,
            Signal,
            BacktestConfig,
            BacktestPosition,
            BacktestTrade,
            BacktestSnapshot,
            BacktestResult,
            SlippageModel,
            NoSlippage,
            FixedSlippage,
            PercentageSlippage,
            VolumeSlippage,
            CommissionModel,
            NoCommission,
            FixedCommission,
            PerShareCommission,
            PercentageCommission,
            TieredCommission,
            BacktestEngine,
            create_backtest_engine,
        )

        # Reaching this point means the import statement itself succeeded.
        assert BacktestEngine is not None
        assert OrderSide.BUY is not None

    def test_create_backtest_engine_factory(self):
        """The factory passes capital, slippage and commission into the config."""
        built = create_backtest_engine(
            initial_capital=Decimal("50000"),
            slippage=PercentageSlippage(Decimal("0.1")),
            commission=FixedCommission(Decimal("10")),
        )
        built_config = built.config

        assert built_config.initial_capital == Decimal("50000")
        assert isinstance(built.config.slippage_model, PercentageSlippage)
        assert isinstance(built.config.commission_model, FixedCommission)

    def test_strategy_callback(self):
        """Signals returned by a strategy callback result in trades."""
        runner = BacktestEngine(BacktestConfig(initial_capital=Decimal("100000")))
        # (day, open, high, low, close, volume)
        rows = [
            (3, 100, 102, 99, 101, 1000000),
            (4, 101, 105, 100, 104, 1200000),
            (5, 104, 108, 103, 107, 1100000),
        ]
        bars_by_symbol = {
            "AAPL": [
                OHLCV(datetime(2023, 1, day), open_, high, low, close, volume, "AAPL")
                for day, open_, high, low, close, volume in rows
            ],
        }

        def strategy(timestamp, bars):
            """Buy 10 AAPL whenever its bar closes above 102."""
            if "AAPL" not in bars:
                return []
            if not bars["AAPL"].close > Decimal("102"):
                return []
            buy = Signal(
                timestamp=timestamp,
                symbol="AAPL",
                side=OrderSide.BUY,
                quantity=Decimal("10"),
            )
            return [buy]

        outcome = runner.run(bars_by_symbol, [], strategy_callback=strategy)

        # Strategy should have generated signals on days 2 and 3
        assert outcome.total_trades >= 1

    def test_multi_symbol(self):
        """Buys in two symbols on the same day open two positions."""
        runner = BacktestEngine(BacktestConfig(initial_capital=Decimal("100000")))
        aapl_bars = [
            OHLCV(datetime(2023, 1, 3), 100, 102, 99, 101, 1000000, "AAPL"),
            OHLCV(datetime(2023, 1, 4), 101, 105, 100, 104, 1200000, "AAPL"),
        ]
        goog_bars = [
            OHLCV(datetime(2023, 1, 3), 90, 92, 89, 91, 500000, "GOOG"),
            OHLCV(datetime(2023, 1, 4), 91, 94, 90, 93, 600000, "GOOG"),
        ]
        bars_by_symbol = {"AAPL": aapl_bars, "GOOG": goog_bars}
        first_day = datetime(2023, 1, 3)
        orders = [
            Signal(first_day, "AAPL", OrderSide.BUY, Decimal("50")),
            Signal(first_day, "GOOG", OrderSide.BUY, Decimal("50")),
        ]

        outcome = runner.run(bars_by_symbol, orders)

        assert outcome.total_trades == 2
        assert "AAPL" in runner.positions
        assert "GOOG" in runner.positions

    def test_date_range_filter(self):
        """Only bars and signals inside start_date..end_date are processed."""
        windowed_config = BacktestConfig(
            initial_capital=Decimal("100000"),
            start_date=datetime(2023, 1, 4),
            end_date=datetime(2023, 1, 5),
        )
        runner = BacktestEngine(windowed_config)
        # (day, open, high, low, close, volume)
        rows = [
            (3, 100, 102, 99, 101, 1000000),
            (4, 101, 105, 100, 104, 1200000),
            (5, 104, 108, 103, 107, 1100000),
            (6, 107, 110, 106, 109, 1300000),
        ]
        bars_by_symbol = {
            "AAPL": [
                OHLCV(datetime(2023, 1, day), open_, high, low, close, volume, "AAPL")
                for day, open_, high, low, close, volume in rows
            ],
        }
        before_range = Signal(datetime(2023, 1, 3), "AAPL", OrderSide.BUY, Decimal("50"))
        in_range = Signal(datetime(2023, 1, 4), "AAPL", OrderSide.BUY, Decimal("50"))
        after_range = Signal(datetime(2023, 1, 6), "AAPL", OrderSide.SELL, Decimal("50"))

        outcome = runner.run(bars_by_symbol, [before_range, in_range, after_range])

        # Only Jan 4 signal should execute
        assert outcome.total_trades == 1
        # Only Jan 4-5
        assert len(outcome.snapshots) == 2