1271 lines
39 KiB
Python
1271 lines
39 KiB
Python
"""Backtest Engine for historical strategy replay.

Issue #42: [BT-41] Backtest engine - historical replay, slippage

This module provides backtesting capabilities for trading strategies:
- Historical price data replay
- Realistic slippage modeling
- Commission/fee handling
- Position and portfolio tracking
- Trade execution simulation

Classes:
    OHLCV: Single price bar
    Signal: Trading signal
    SlippageModel: Base class for slippage calculation
    NoSlippage: Zero slippage
    FixedSlippage: Fixed amount slippage
    PercentageSlippage: Percentage-based slippage
    VolumeSlippage: Volume-impact slippage
    CommissionModel: Base class for commission calculation
    NoCommission: Zero commission
    FixedCommission: Fixed per-trade commission
    PerShareCommission: Per-share commission
    PercentageCommission: Percentage-based commission
    TieredCommission: Tiered commission based on trade value
    BacktestConfig: Configuration for backtest
    BacktestPosition: Position tracking during backtest
    BacktestTrade: Individual trade record
    BacktestSnapshot: Portfolio snapshot at a point in time
    BacktestResult: Complete backtest result
    BacktestEngine: Main backtest engine

Example:
    >>> from tradingagents.backtest import BacktestEngine, BacktestConfig
    >>> from datetime import datetime
    >>> from decimal import Decimal
    >>>
    >>> config = BacktestConfig(
    ...     initial_capital=Decimal("100000"),
    ...     start_date=datetime(2023, 1, 1),
    ...     end_date=datetime(2023, 12, 31),
    ... )
    >>> engine = BacktestEngine(config)
    >>> result = engine.run(price_data, signals)
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timedelta
|
|
from decimal import Decimal
|
|
from enum import Enum
|
|
from typing import Any, Callable, Optional, Protocol
|
|
import logging
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ============================================================================
|
|
# Constants
|
|
# ============================================================================
|
|
|
|
# Frequently used Decimal values, shared across the module.
ZERO = Decimal(0)
ONE = Decimal(1)
HUNDRED = Decimal(100)
|
|
|
|
|
|
# ============================================================================
|
|
# Enums
|
|
# ============================================================================
|
|
|
|
class OrderSide(Enum):
    """Direction of an order: buy or sell."""

    BUY = "buy"
    SELL = "sell"
|
|
|
|
|
|
class OrderType(Enum):
    """Order types that a backtest signal can carry."""

    MARKET = "market"
    LIMIT = "limit"
    STOP = "stop"
    STOP_LIMIT = "stop_limit"
|
|
|
|
|
|
class FillStatus(Enum):
    """Possible fill states of an order."""

    UNFILLED = "unfilled"
    PARTIAL = "partial"
    FILLED = "filled"
    CANCELLED = "cancelled"
    REJECTED = "rejected"
|
|
|
|
|
|
# ============================================================================
|
|
# Price Data Types
|
|
# ============================================================================
|
|
|
|
@dataclass
class OHLCV:
    """One open/high/low/close/volume price bar.

    Attributes:
        timestamp: Bar timestamp
        open: Open price
        high: High price
        low: Low price
        close: Close price
        volume: Traded volume (defaults to zero)
        symbol: Optional symbol identifier
    """

    timestamp: datetime
    open: Decimal
    high: Decimal
    low: Decimal
    close: Decimal
    volume: Decimal = ZERO
    symbol: str = ""

    def __post_init__(self):
        """Coerce the price and volume fields to Decimal."""
        for name in ("open", "high", "low", "close", "volume"):
            raw = getattr(self, name)
            if isinstance(raw, Decimal):
                continue
            # Convert via str() so a float keeps its printed value instead of
            # its full binary expansion.
            setattr(self, name, Decimal(str(raw)))
|
|
|
|
|
|
@dataclass
class Signal:
    """A trading instruction consumed by the backtest engine.

    Attributes:
        timestamp: When the signal applies
        symbol: Symbol to trade
        side: Buy or sell
        quantity: Quantity to trade; leave at 0 to let the engine size the
            position
        price: Target price (used for limit orders)
        order_type: Order type (market by default)
        confidence: Signal confidence (0-1)
        metadata: Free-form additional signal data
    """

    timestamp: datetime
    symbol: str
    side: OrderSide
    quantity: Decimal = ZERO
    price: Decimal = ZERO
    order_type: OrderType = OrderType.MARKET
    confidence: Decimal = ONE
    # default_factory gives every signal its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
# ============================================================================
|
|
# Slippage Models
|
|
# ============================================================================
|
|
|
|
class SlippageModel(ABC):
    """Abstract interface for slippage estimation.

    Concrete models implement ``calculate``.
    """

    @abstractmethod
    def calculate(
        self,
        price: Decimal,
        quantity: Decimal,
        side: OrderSide,
        volume: Decimal,
    ) -> Decimal:
        """Return the slippage to apply to an order's price.

        Args:
            price: Order price
            quantity: Order quantity
            side: Order side
            volume: Bar volume

        Returns:
            Slippage amount (added to the price of a buy, subtracted from
            the price of a sell)
        """
|
|
|
|
|
|
class NoSlippage(SlippageModel):
    """Slippage model that never moves the execution price."""

    def calculate(
        self,
        price: Decimal,
        quantity: Decimal,
        side: OrderSide,
        volume: Decimal,
    ) -> Decimal:
        """Return zero slippage regardless of the order."""
        return ZERO
|
|
|
|
|
|
class FixedSlippage(SlippageModel):
    """Slippage of a constant amount per share.

    Attributes:
        amount: Fixed slippage amount per share
    """

    def __init__(self, amount: Decimal):
        """Store the per-share slippage.

        Args:
            amount: Slippage per share; converted to Decimal via its string
                form
        """
        self.amount = Decimal(str(amount))

    def calculate(
        self,
        price: Decimal,
        quantity: Decimal,
        side: OrderSide,
        volume: Decimal,
    ) -> Decimal:
        """Return the configured amount; the order details are not used."""
        return self.amount
|
|
|
|
|
|
class PercentageSlippage(SlippageModel):
    """Slippage expressed as a percentage of the order price.

    Attributes:
        percentage: Slippage as percentage of price (e.g., 0.1 = 0.1%)
    """

    def __init__(self, percentage: Decimal):
        """Store the slippage percentage.

        Args:
            percentage: Slippage percentage (0.1 = 0.1%); converted to
                Decimal via its string form
        """
        self.percentage = Decimal(str(percentage))

    def calculate(
        self,
        price: Decimal,
        quantity: Decimal,
        side: OrderSide,
        volume: Decimal,
    ) -> Decimal:
        """Return ``price * percentage / 100``."""
        scaled = price * self.percentage
        return scaled / HUNDRED
|
|
|
|
|
|
class VolumeSlippage(SlippageModel):
    """Volume-impact slippage model.

    Slippage grows with the order's size relative to the bar's volume and is
    capped at a maximum percentage.

    Attributes:
        base_percentage: Base slippage percentage
        volume_impact: Impact factor for volume (higher = more slippage)
        max_percentage: Maximum slippage percentage
    """

    def __init__(
        self,
        base_percentage: Decimal = Decimal("0.05"),
        volume_impact: Decimal = Decimal("0.1"),
        max_percentage: Decimal = Decimal("1.0"),
    ):
        """Initialize volume slippage model.

        Args:
            base_percentage: Base slippage (%)
            volume_impact: Volume impact factor
            max_percentage: Maximum slippage cap (%)
        """
        self.base_percentage = Decimal(str(base_percentage))
        self.volume_impact = Decimal(str(volume_impact))
        self.max_percentage = Decimal(str(max_percentage))

    def calculate(
        self,
        price: Decimal,
        quantity: Decimal,
        side: OrderSide,
        volume: Decimal,
    ) -> Decimal:
        """Calculate volume-based slippage.

        Args:
            price: Order price
            quantity: Order quantity; only its magnitude is used
            side: Order side (not used by this model)
            volume: Bar volume

        Returns:
            Slippage amount. With no volume data (``volume <= 0``) this is
            the base percentage of ``price``; otherwise the base percentage
            plus the participation impact, capped at ``max_percentage``.
        """
        if volume <= ZERO:
            # No volume data, use base slippage
            return price * self.base_percentage / HUNDRED

        # Use the magnitude of the order: a signed (negative) quantity would
        # otherwise yield negative participation and shrink, or even flip,
        # the slippage.
        participation = abs(quantity) / volume

        slippage_pct = self.base_percentage + (participation * self.volume_impact * HUNDRED)
        slippage_pct = min(slippage_pct, self.max_percentage)

        return price * slippage_pct / HUNDRED
|
|
|
|
|
|
# ============================================================================
|
|
# Commission Models
|
|
# ============================================================================
|
|
|
|
class CommissionModel(ABC):
    """Abstract interface for commission calculation.

    Concrete models implement ``calculate``.
    """

    @abstractmethod
    def calculate(
        self,
        price: Decimal,
        quantity: Decimal,
        trade_value: Decimal,
    ) -> Decimal:
        """Return the commission charged for a trade.

        Args:
            price: Trade price
            quantity: Trade quantity
            trade_value: Total trade value

        Returns:
            Commission amount
        """
|
|
|
|
|
|
class NoCommission(CommissionModel):
    """Commission model that charges nothing."""

    def calculate(
        self,
        price: Decimal,
        quantity: Decimal,
        trade_value: Decimal,
    ) -> Decimal:
        """Return zero commission regardless of the trade."""
        return ZERO
|
|
|
|
|
|
class FixedCommission(CommissionModel):
    """Flat commission charged once per trade.

    Attributes:
        amount: Fixed commission per trade
        minimum: Minimum commission
    """

    def __init__(
        self,
        amount: Decimal,
        minimum: Decimal = ZERO,
    ):
        """Store the flat fee and its floor.

        Args:
            amount: Commission per trade
            minimum: Minimum commission
        """
        self.amount = Decimal(str(amount))
        self.minimum = Decimal(str(minimum))

    def calculate(
        self,
        price: Decimal,
        quantity: Decimal,
        trade_value: Decimal,
    ) -> Decimal:
        """Return the larger of the flat amount and the minimum.

        The trade details are not used.
        """
        fee = max(self.amount, self.minimum)
        return fee
|
|
|
|
|
|
class PerShareCommission(CommissionModel):
    """Commission charged per share, with optional floor and cap.

    Attributes:
        per_share: Commission per share
        minimum: Minimum commission per trade
        maximum: Maximum commission per trade (None when uncapped)
    """

    def __init__(
        self,
        per_share: Decimal,
        minimum: Decimal = ZERO,
        maximum: Decimal = Decimal("Infinity"),
    ):
        """Store the per-share rate and its bounds.

        Args:
            per_share: Commission per share
            minimum: Minimum per trade
            maximum: Maximum per trade; the default of infinity means no cap
        """
        self.per_share = Decimal(str(per_share))
        self.minimum = Decimal(str(minimum))
        # An infinite cap is stored as None, which calculate() reads as
        # "no maximum".
        if maximum != Decimal("Infinity"):
            self.maximum = Decimal(str(maximum))
        else:
            self.maximum = None

    def calculate(
        self,
        price: Decimal,
        quantity: Decimal,
        trade_value: Decimal,
    ) -> Decimal:
        """Return ``per_share * |quantity|`` clamped to the configured bounds."""
        fee = max(self.per_share * abs(quantity), self.minimum)
        if self.maximum is None:
            return fee
        return min(fee, self.maximum)
|
|
|
|
|
|
class PercentageCommission(CommissionModel):
    """Commission charged as a percentage of trade value.

    Attributes:
        percentage: Commission as percentage of trade value
        minimum: Minimum commission
    """

    def __init__(
        self,
        percentage: Decimal,
        minimum: Decimal = ZERO,
    ):
        """Store the commission rate and its floor.

        Args:
            percentage: Commission percentage (e.g., 0.1 = 0.1%)
            minimum: Minimum commission
        """
        self.percentage = Decimal(str(percentage))
        self.minimum = Decimal(str(minimum))

    def calculate(
        self,
        price: Decimal,
        quantity: Decimal,
        trade_value: Decimal,
    ) -> Decimal:
        """Return ``|trade_value| * percentage / 100``, floored at the minimum."""
        notional = abs(trade_value)
        fee = notional * self.percentage / HUNDRED
        return max(fee, self.minimum)
|
|
|
|
|
|
class TieredCommission(CommissionModel):
    """Percentage commission whose rate depends on trade value.

    Attributes:
        tiers: List of (threshold, percentage) tuples, kept sorted by
            threshold
        minimum: Minimum commission
    """

    def __init__(
        self,
        tiers: list[tuple[Decimal, Decimal]],
        minimum: Decimal = ZERO,
    ):
        """Normalize and store the tier table.

        Args:
            tiers: List of (threshold, percentage); sorted ascending by
                threshold here, so callers may pass them in any order
            minimum: Minimum commission
        """
        normalized = [
            (Decimal(str(threshold)), Decimal(str(pct))) for threshold, pct in tiers
        ]
        normalized.sort(key=lambda tier: tier[0])
        self.tiers = normalized
        self.minimum = Decimal(str(minimum))

    def calculate(
        self,
        price: Decimal,
        quantity: Decimal,
        trade_value: Decimal,
    ) -> Decimal:
        """Return the commission at the highest tier the trade value reaches.

        A trade below every threshold is charged the lowest tier's rate; an
        empty tier table means a zero rate. The result is floored at the
        minimum.
        """
        notional = abs(trade_value)

        rate = self.tiers[0][1] if self.tiers else ZERO
        for floor, tier_rate in self.tiers:
            if notional < floor:
                # Tiers are sorted, so no later tier can apply either.
                break
            rate = tier_rate

        return max(notional * rate / HUNDRED, self.minimum)
|
|
|
|
|
|
# ============================================================================
|
|
# Backtest Data Classes
|
|
# ============================================================================
|
|
|
|
@dataclass
class BacktestConfig:
    """Settings that control a backtest run.

    Attributes:
        initial_capital: Starting capital
        start_date: Backtest start date (None = first available bar)
        end_date: Backtest end date (None = last available bar)
        slippage_model: Slippage model to use
        commission_model: Commission model to use
        position_sizing: Position sizing mode
        max_position_pct: Maximum position as % of portfolio
        min_trade_value: Minimum trade value
        allow_shorting: Whether to allow short positions
        margin_rate: Margin rate for leveraged trades
        risk_free_rate: Risk-free rate for Sharpe calculation
        benchmark_symbol: Benchmark symbol for comparison
        rebalance_frequency: Rebalance frequency in days (0 = no rebalance)
    """

    initial_capital: Decimal = Decimal("100000")
    start_date: Optional[datetime] = None
    end_date: Optional[datetime] = None
    # Factories give each config its own model instance.
    slippage_model: SlippageModel = field(default_factory=NoSlippage)
    commission_model: CommissionModel = field(default_factory=NoCommission)
    position_sizing: str = "equal"  # equal, risk_parity, kelly
    max_position_pct: Decimal = Decimal("20")  # 20% max per position
    min_trade_value: Decimal = Decimal("100")
    allow_shorting: bool = False
    margin_rate: Decimal = Decimal("50")  # 50% margin
    risk_free_rate: Decimal = Decimal("0.05")  # 5% annual
    benchmark_symbol: str = "SPY"
    rebalance_frequency: int = 0  # 0 = no automatic rebalance
|
|
|
|
|
|
@dataclass
class BacktestPosition:
    """Holding in a single symbol during a backtest.

    Attributes:
        symbol: Position symbol
        quantity: Current quantity (negative for short)
        average_cost: Average cost basis per unit
        current_price: Current market price
        unrealized_pnl: Unrealized P&L
        realized_pnl: Realized P&L from closed trades
        opened_at: Position open timestamp
        last_updated: Last update timestamp
    """

    symbol: str
    quantity: Decimal = ZERO
    average_cost: Decimal = ZERO
    current_price: Decimal = ZERO
    unrealized_pnl: Decimal = ZERO
    realized_pnl: Decimal = ZERO
    opened_at: Optional[datetime] = None
    last_updated: Optional[datetime] = None

    @property
    def market_value(self) -> Decimal:
        """Signed market value: quantity times current price."""
        return self.quantity * self.current_price

    @property
    def cost_basis(self) -> Decimal:
        """Signed total cost basis: quantity times average cost."""
        return self.quantity * self.average_cost

    @property
    def is_long(self) -> bool:
        """True when the quantity is strictly positive."""
        return self.quantity > ZERO

    @property
    def is_short(self) -> bool:
        """True when the quantity is strictly negative."""
        return self.quantity < ZERO

    def update_price(self, price: Decimal, timestamp: datetime) -> None:
        """Mark the position to a new price.

        Recomputes the unrealized P&L against the average cost and records
        the update time.

        Args:
            price: New price
            timestamp: Update timestamp
        """
        per_unit_gain = price - self.average_cost
        self.current_price = price
        self.unrealized_pnl = per_unit_gain * self.quantity
        self.last_updated = timestamp
|
|
|
|
|
|
@dataclass
class BacktestTrade:
    """Record of one executed trade.

    Attributes:
        trade_id: Unique trade ID
        timestamp: Trade timestamp (defaults to the creation time)
        symbol: Symbol traded
        side: Buy or sell
        quantity: Quantity traded
        price: Execution price (after slippage)
        base_price: Price before slippage
        slippage: Slippage amount
        commission: Commission paid
        trade_value: Total trade value
        signal_confidence: Original signal confidence
        position_after: Position quantity after trade
        cash_after: Cash balance after trade
        pnl: Realized P&L (for closing trades)
    """

    trade_id: str = ""
    timestamp: datetime = field(default_factory=datetime.now)
    symbol: str = ""
    side: OrderSide = OrderSide.BUY
    quantity: Decimal = ZERO
    price: Decimal = ZERO
    base_price: Decimal = ZERO
    slippage: Decimal = ZERO
    commission: Decimal = ZERO
    trade_value: Decimal = ZERO
    signal_confidence: Decimal = ONE
    position_after: Decimal = ZERO
    cash_after: Decimal = ZERO
    pnl: Decimal = ZERO
|
|
|
|
|
|
@dataclass
class BacktestSnapshot:
    """State of the portfolio at one point in time.

    Attributes:
        timestamp: Snapshot timestamp
        cash: Cash balance
        positions_value: Total value of positions
        total_value: Total portfolio value
        positions: Positions held at the snapshot, keyed by symbol
        drawdown: Current drawdown from peak
        peak_value: Peak portfolio value
    """

    timestamp: datetime
    cash: Decimal
    positions_value: Decimal
    total_value: Decimal
    positions: dict[str, BacktestPosition] = field(default_factory=dict)
    drawdown: Decimal = ZERO
    peak_value: Decimal = ZERO
|
|
|
|
|
|
@dataclass
class BacktestResult:
    """Outcome of a backtest run: summary metrics plus the detailed history.

    Attributes:
        config: Backtest configuration
        start_date: Actual start date
        end_date: Actual end date
        initial_capital: Starting capital
        final_value: Ending portfolio value
        total_return: Total return percentage
        annualized_return: Annualized return
        sharpe_ratio: Sharpe ratio
        sortino_ratio: Sortino ratio
        max_drawdown: Maximum drawdown
        win_rate: Win rate
        profit_factor: Profit factor
        total_trades: Number of trades
        winning_trades: Number of winning trades
        losing_trades: Number of losing trades
        avg_trade_pnl: Average P&L per trade
        avg_win: Average winning trade
        avg_loss: Average losing trade
        max_win: Largest winning trade
        max_loss: Largest losing trade
        total_commission: Total commission paid
        total_slippage: Total slippage cost
        trades: List of all trades
        snapshots: Portfolio snapshots over time
        daily_returns: Daily return series
        benchmark_return: Benchmark return (if available)
        alpha: Alpha vs benchmark
        beta: Beta vs benchmark
        errors: Any errors during backtest
    """

    # Run identification
    config: BacktestConfig = field(default_factory=BacktestConfig)
    start_date: Optional[datetime] = None
    end_date: Optional[datetime] = None

    # Return and risk metrics
    initial_capital: Decimal = ZERO
    final_value: Decimal = ZERO
    total_return: Decimal = ZERO
    annualized_return: Decimal = ZERO
    sharpe_ratio: Decimal = ZERO
    sortino_ratio: Decimal = ZERO
    max_drawdown: Decimal = ZERO

    # Trade statistics
    win_rate: Decimal = ZERO
    profit_factor: Decimal = ZERO
    total_trades: int = 0
    winning_trades: int = 0
    losing_trades: int = 0
    avg_trade_pnl: Decimal = ZERO
    avg_win: Decimal = ZERO
    avg_loss: Decimal = ZERO
    max_win: Decimal = ZERO
    max_loss: Decimal = ZERO
    total_commission: Decimal = ZERO
    total_slippage: Decimal = ZERO

    # Detailed history
    trades: list[BacktestTrade] = field(default_factory=list)
    snapshots: list[BacktestSnapshot] = field(default_factory=list)
    daily_returns: list[Decimal] = field(default_factory=list)

    # Benchmark comparison
    benchmark_return: Decimal = ZERO
    alpha: Decimal = ZERO
    beta: Decimal = ZERO

    errors: list[str] = field(default_factory=list)
|
|
|
|
|
|
# ============================================================================
|
|
# Backtest Engine
|
|
# ============================================================================
|
|
|
|
class BacktestEngine:
|
|
"""Main backtest engine for historical strategy replay.
|
|
|
|
Attributes:
|
|
config: Backtest configuration
|
|
cash: Current cash balance
|
|
positions: Current positions
|
|
trades: Trade history
|
|
snapshots: Portfolio snapshots
|
|
"""
|
|
|
|
def __init__(self, config: Optional[BacktestConfig] = None):
    """Create an engine and put it in its initial state.

    Args:
        config: Backtest configuration; a default BacktestConfig is used
            when omitted
    """
    if not config:
        config = BacktestConfig()
    self.config = config
    self.reset()
|
|
|
|
def reset(self) -> None:
    """Return the engine to its starting state.

    Cash and the tracked peak value go back to the configured initial
    capital; positions, trades, snapshots and cached price data are cleared.
    """
    starting_capital = self.config.initial_capital

    # Portfolio state
    self.cash = starting_capital
    self.positions: dict[str, BacktestPosition] = {}
    self.trades: list[BacktestTrade] = []
    self.snapshots: list[BacktestSnapshot] = []

    # Bookkeeping
    self._trade_counter = 0
    self._peak_value = starting_capital
    self._current_timestamp: Optional[datetime] = None

    # Market data for the current run
    self._price_data: dict[str, list[OHLCV]] = {}
    self._current_prices: dict[str, Decimal] = {}
|
|
|
|
def run(
    self,
    price_data: dict[str, list[OHLCV]],
    signals: list[Signal],
    strategy_callback: Optional[Callable[[datetime, dict[str, OHLCV]], list[Signal]]] = None,
) -> BacktestResult:
    """Run backtest.

    The engine is reset, then every bar timestamp inside the configured date
    range is replayed in ascending order: prices are updated, the signals
    for that timestamp (pre-generated plus any returned by
    ``strategy_callback``) are executed, and a portfolio snapshot is taken.

    Signals are matched to bars by exact timestamp, so a signal whose
    timestamp has no bar in ``price_data`` is never executed. Exceptions
    raised by the callback or while executing a signal do not abort the run;
    they are collected as messages and handed to the result.

    Args:
        price_data: Dict of symbol -> list of OHLCV bars
        signals: List of trading signals (pre-generated)
        strategy_callback: Optional callback for dynamic signal generation;
            called once per timestamp with the timestamp and that
            timestamp's bars

    Returns:
        BacktestResult with all metrics
    """
    self.reset()
    self._price_data = price_data

    # Index every bar by timestamp in a single pass. Looking bars up per
    # timestamp by rescanning each symbol's list makes the replay quadratic
    # in the number of bars.
    bars_by_timestamp: dict[datetime, dict[str, OHLCV]] = {}
    for symbol, bars in price_data.items():
        for bar in bars:
            # setdefault keeps the first bar seen for a (timestamp, symbol)
            # pair, the same bar a first-match scan of the list would pick.
            bars_by_timestamp.setdefault(bar.timestamp, {}).setdefault(symbol, bar)

    if not bars_by_timestamp:
        return self._create_result([])

    # Determine date range; unset config dates fall back to the data's span.
    sorted_timestamps = sorted(bars_by_timestamp)
    start_date = self.config.start_date or sorted_timestamps[0]
    end_date = self.config.end_date or sorted_timestamps[-1]

    timestamps = [t for t in sorted_timestamps if start_date <= t <= end_date]
    if not timestamps:
        return self._create_result([])

    # Index in-range signals by timestamp, preserving their input order.
    signal_index: dict[datetime, list[Signal]] = {}
    for signal in signals:
        if start_date <= signal.timestamp <= end_date:
            signal_index.setdefault(signal.timestamp, []).append(signal)

    # Main replay loop
    errors: list[str] = []
    for timestamp in timestamps:
        self._current_timestamp = timestamp

        current_bars = bars_by_timestamp[timestamp]
        self._update_prices(current_bars)

        # Copy, so callback signals are never appended to the indexed list.
        timestamp_signals = list(signal_index.get(timestamp, ()))

        if strategy_callback:
            try:
                callback_signals = strategy_callback(timestamp, current_bars)
                timestamp_signals.extend(callback_signals)
            except Exception as e:
                errors.append(f"Strategy callback error at {timestamp}: {e}")

        for signal in timestamp_signals:
            try:
                self._execute_signal(signal, current_bars)
            except Exception as e:
                errors.append(f"Signal execution error at {timestamp}: {e}")

        self._take_snapshot(timestamp)

    return self._create_result(errors)
|
|
|
|
def _get_bars_at(self, timestamp: datetime) -> dict[str, OHLCV]:
|
|
"""Get OHLCV bars at timestamp.
|
|
|
|
Args:
|
|
timestamp: Target timestamp
|
|
|
|
Returns:
|
|
Dict of symbol -> OHLCV
|
|
"""
|
|
bars = {}
|
|
for symbol, bar_list in self._price_data.items():
|
|
for bar in bar_list:
|
|
if bar.timestamp == timestamp:
|
|
bars[symbol] = bar
|
|
break
|
|
return bars
|
|
|
|
def _update_prices(self, bars: dict[str, OHLCV]) -> None:
|
|
"""Update current prices and position values.
|
|
|
|
Args:
|
|
bars: Current price bars
|
|
"""
|
|
for symbol, bar in bars.items():
|
|
self._current_prices[symbol] = bar.close
|
|
|
|
if symbol in self.positions:
|
|
self.positions[symbol].update_price(bar.close, bar.timestamp)
|
|
|
|
def _execute_signal(self, signal: Signal, bars: dict[str, OHLCV]) -> Optional[BacktestTrade]:
    """Execute a trading signal.

    The fill price is the bar close (or the signal's own price for LIMIT
    orders) adjusted by slippage: added for buys, subtracted for sells.
    Buys that exceed available cash are shrunk to the largest affordable
    whole quantity; sells are capped at the held long quantity. Selling
    without a long position is only possible when shorting is allowed.

    Args:
        signal: Signal to execute
        bars: Current price bars

    Returns:
        BacktestTrade if executed, None if rejected
    """
    symbol = signal.symbol

    # Check if we have price data
    if symbol not in bars:
        logger.warning(f"No price data for {symbol} at {signal.timestamp}")
        return None

    bar = bars[symbol]

    # Determine quantity
    quantity = self._calculate_quantity(signal, bar)
    if quantity == ZERO:
        return None

    # Get execution price with slippage
    base_price = bar.close
    if signal.order_type == OrderType.LIMIT:
        # NOTE(review): the limit price is used as-is; nothing checks that
        # it lies inside the bar's low/high range.
        base_price = signal.price

    slippage = self.config.slippage_model.calculate(
        base_price, quantity, signal.side, bar.volume
    )

    if signal.side == OrderSide.BUY:
        exec_price = base_price + slippage
    else:
        exec_price = base_price - slippage

    # A non-positive fill price (e.g. a LIMIT signal whose price was left at
    # its zero default, or slippage at least as large as the price) would
    # hand out free shares or make a sale cost money, so reject it.
    if exec_price <= ZERO:
        logger.warning(
            "Rejected %s signal for %s at %s: non-positive execution price %s",
            signal.side.value, symbol, signal.timestamp, exec_price,
        )
        return None

    # Calculate trade value and commission
    trade_value = exec_price * quantity
    commission = self.config.commission_model.calculate(
        exec_price, quantity, trade_value
    )

    if signal.side == OrderSide.BUY:
        total_cost = trade_value + commission
        if total_cost > self.cash:
            # Reduce quantity to what we can afford
            available = self.cash - commission
            if available <= ZERO:
                return None
            # Decimal floor division truncates toward zero. quantize() with
            # the default half-even rounding could round UP and leave the
            # resized order costing more than the available cash.
            quantity = available // exec_price
            if quantity <= ZERO:
                return None
            trade_value = exec_price * quantity
            commission = self.config.commission_model.calculate(
                exec_price, quantity, trade_value
            )
            total_cost = trade_value + commission
            # The commission was recomputed for the new size; never let it
            # push the balance below zero.
            if total_cost > self.cash:
                return None

        self.cash -= total_cost
    else:
        # Sell - check position
        current_position = self.positions.get(symbol)
        if current_position is None or current_position.quantity <= ZERO:
            if not self.config.allow_shorting:
                return None
        elif quantity > current_position.quantity:
            # Can only sell what we have
            quantity = current_position.quantity
            trade_value = exec_price * quantity
            commission = self.config.commission_model.calculate(
                exec_price, quantity, trade_value
            )

        self.cash += trade_value - commission

    # Update position
    pnl = self._update_position(signal, quantity, exec_price)

    # Create trade record
    self._trade_counter += 1
    trade = BacktestTrade(
        trade_id=f"BT-{self._trade_counter:06d}",
        timestamp=signal.timestamp,
        symbol=symbol,
        side=signal.side,
        quantity=quantity,
        price=exec_price,
        base_price=base_price,
        # Per-share slippage scaled to the traded quantity.
        slippage=slippage * quantity,
        commission=commission,
        trade_value=trade_value,
        signal_confidence=signal.confidence,
        # A fully closed position has been removed from self.positions, so
        # fall back to an empty position (quantity zero).
        position_after=self.positions.get(symbol, BacktestPosition(symbol)).quantity,
        cash_after=self.cash,
        pnl=pnl,
    )

    self.trades.append(trade)
    return trade
|
|
|
|
def _calculate_quantity(self, signal: Signal, bar: OHLCV) -> Decimal:
    """Decide how many units to trade for a signal.

    A positive ``signal.quantity`` is used as-is. Otherwise the size is
    derived from the portfolio value and the configured sizing mode, priced
    at the bar's close.

    Args:
        signal: Trading signal
        bar: Current price bar

    Returns:
        Quantity to trade; zero when the target value is below the
        configured minimum trade value
    """
    if signal.quantity > ZERO:
        return signal.quantity

    equity = self._get_portfolio_value()
    position_cap = equity * self.config.max_position_pct / HUNDRED

    target_value = position_cap
    if self.config.position_sizing == "equal":
        # Equal weighting, assuming the portfolio holds at least 5 positions.
        slots = max(len(self.positions), 5)
        target_value = min(equity / Decimal(slots), position_cap)

    if target_value < self.config.min_trade_value:
        return ZERO

    # quantize() rounds to a whole number using the active decimal context's
    # rounding mode.
    shares = (target_value / bar.close).quantize(Decimal("1"))
    return max(shares, ZERO)
|
|
|
|
def _update_position(
    self,
    signal: Signal,
    quantity: Decimal,
    price: Decimal,
) -> Decimal:
    """Update position after trade.

    Creates the position if needed, adjusts its quantity and average cost,
    books realized P&L on the part of the trade that closes existing
    exposure, re-marks the position, and removes it once flat.

    Args:
        signal: Trading signal
        quantity: Trade quantity
        price: Execution price

    Returns:
        Realized P&L
    """
    symbol = signal.symbol
    pnl = ZERO

    if symbol not in self.positions:
        self.positions[symbol] = BacktestPosition(
            symbol=symbol,
            opened_at=signal.timestamp,
        )

    position = self.positions[symbol]

    if signal.side == OrderSide.BUY:
        if position.quantity >= ZERO:
            # Adding to long or opening new long: blend the average cost.
            total_cost = position.quantity * position.average_cost + quantity * price
            new_quantity = position.quantity + quantity
            position.average_cost = total_cost / new_quantity if new_quantity > ZERO else ZERO
            position.quantity = new_quantity
        else:
            # Covering short: P&L only on the part that closes the short.
            pnl = (position.average_cost - price) * min(quantity, abs(position.quantity))
            position.realized_pnl += pnl
            position.quantity += quantity
            if position.quantity > ZERO:
                # Bought through flat: the remainder is a brand-new long
                # whose cost basis is this fill, not the old short's.
                position.average_cost = price
                position.opened_at = signal.timestamp
    else:
        if position.quantity > ZERO:
            # Closing long: P&L only on the part that closes the long.
            pnl = (price - position.average_cost) * min(quantity, position.quantity)
            position.realized_pnl += pnl
            position.quantity -= quantity
            if position.quantity < ZERO:
                # Sold through flat: the remainder is a brand-new short
                # opened at this fill.
                position.average_cost = price
                position.opened_at = signal.timestamp
        else:
            # Adding to short or opening new short: blend the average cost.
            total_cost = abs(position.quantity) * position.average_cost + quantity * price
            new_quantity = position.quantity - quantity
            position.average_cost = total_cost / abs(new_quantity) if new_quantity != ZERO else price
            position.quantity = new_quantity

    # Re-mark the position right away. A freshly created position has
    # current_price == 0, so a snapshot taken before the next price update
    # would value it at zero; existing positions also need their unrealized
    # P&L refreshed for the new quantity and cost. Prefer the latest known
    # market price and fall back to the fill price. update_price() also sets
    # last_updated to the signal timestamp.
    mark_price = self._current_prices.get(symbol, price)
    position.update_price(mark_price, signal.timestamp)

    # Clean up closed positions
    if position.quantity == ZERO:
        del self.positions[symbol]

    return pnl
|
|
|
|
def _get_portfolio_value(self) -> Decimal:
|
|
"""Get total portfolio value.
|
|
|
|
Returns:
|
|
Total value (cash + positions)
|
|
"""
|
|
positions_value = sum(p.market_value for p in self.positions.values())
|
|
return self.cash + positions_value
|
|
|
|
def _take_snapshot(self, timestamp: datetime) -> None:
    """Take portfolio snapshot.

    Records cash, position value, total value, the running peak and the
    percentage drawdown from that peak, together with an independent copy
    of every open position, and appends it to ``self.snapshots``.

    Args:
        timestamp: Snapshot timestamp
    """
    # Imported here because the module-level dataclasses import only brings
    # in `dataclass` and `field`.
    from dataclasses import replace

    # Start the sum at ZERO so positions_value is a Decimal even when there
    # are no open positions (a bare sum() would yield the int 0).
    positions_value = sum((p.market_value for p in self.positions.values()), ZERO)
    total_value = self.cash + positions_value

    # Update peak and drawdown
    if total_value > self._peak_value:
        self._peak_value = total_value

    if self._peak_value > ZERO:
        drawdown = (self._peak_value - total_value) / self._peak_value * HUNDRED
    else:
        drawdown = ZERO

    snapshot = BacktestSnapshot(
        timestamp=timestamp,
        cash=self.cash,
        positions_value=positions_value,
        total_value=total_value,
        # replace() with no overrides copies every field, so later trades do
        # not alter the snapshot and opened_at / last_updated are kept too.
        positions={symbol: replace(position) for symbol, position in self.positions.items()},
        drawdown=drawdown,
        peak_value=self._peak_value,
    )

    self.snapshots.append(snapshot)
|
|
|
|
def _create_result(self, errors: list[str]) -> BacktestResult:
    """Create backtest result with calculated metrics.

    Derives return, risk and trade statistics from ``self.snapshots`` and
    ``self.trades``. When no snapshot was recorded, the result carries only
    the config, the initial capital (also reported as the final value) and
    the errors.

    Args:
        errors: List of errors during backtest

    Returns:
        Complete BacktestResult
    """
    initial_capital = self.config.initial_capital

    if not self.snapshots:
        return BacktestResult(
            config=self.config,
            initial_capital=initial_capital,
            final_value=initial_capital,
            errors=errors,
        )

    # Basic metrics
    start_date = self.snapshots[0].timestamp
    end_date = self.snapshots[-1].timestamp
    final_value = self.snapshots[-1].total_value
    total_return = (final_value - initial_capital) / initial_capital * HUNDRED

    # Annualized return over the calendar span between first and last snapshot
    years = Decimal(str((end_date - start_date).days)) / Decimal("365")
    if years > ZERO:
        growth = final_value / initial_capital
        if growth > ZERO:
            annualized_return = (growth ** (ONE / years) - ONE) * HUNDRED
        else:
            # A negative Decimal raised to a fractional power signals
            # InvalidOperation, which would discard the whole result.
            # Capital at or below zero is reported as a complete loss.
            annualized_return = -HUNDRED
    else:
        annualized_return = ZERO

    # Snapshot-to-snapshot returns; a non-positive previous value yields zero
    daily_returns = []
    for previous, current in zip(self.snapshots, self.snapshots[1:]):
        prev_value = previous.total_value
        if prev_value > ZERO:
            daily_returns.append((current.total_value - prev_value) / prev_value)
        else:
            daily_returns.append(ZERO)

    # Sharpe and Sortino ratios (both stay zero without usable returns)
    sharpe_ratio = ZERO
    sortino_ratio = ZERO
    if daily_returns:
        avg_return = sum(daily_returns) / len(daily_returns)
        daily_rf = self.config.risk_free_rate / Decimal("252")
        annualization = Decimal("252").sqrt()

        # Sharpe ratio: excess return over the population standard deviation
        variance = sum((r - avg_return) ** 2 for r in daily_returns) / len(daily_returns)
        std_dev = variance ** Decimal("0.5")
        if std_dev > ZERO:
            sharpe_ratio = (avg_return - daily_rf) / std_dev * annualization

        # Sortino ratio: deviation is taken over the negative returns only
        negative_returns = [r for r in daily_returns if r < ZERO]
        if negative_returns:
            downside_variance = sum(r ** 2 for r in negative_returns) / len(negative_returns)
            downside_dev = downside_variance ** Decimal("0.5")
            if downside_dev > ZERO:
                sortino_ratio = (avg_return - daily_rf) / downside_dev * annualization

    # Maximum drawdown
    max_drawdown = max((s.drawdown for s in self.snapshots), default=ZERO)

    # Trade statistics
    total_trades = len(self.trades)
    wins = [t.pnl for t in self.trades if t.pnl > ZERO]
    losses = [t.pnl for t in self.trades if t.pnl < ZERO]
    winning_trades = len(wins)
    losing_trades = len(losses)
    win_rate = Decimal(winning_trades) / Decimal(total_trades) * HUNDRED if total_trades > 0 else ZERO

    # Sums start from ZERO so empty inputs still produce a Decimal, not int 0.
    total_wins = sum(wins, ZERO)
    total_losses = abs(sum(losses, ZERO))

    avg_win = total_wins / len(wins) if wins else ZERO
    avg_loss = sum(losses, ZERO) / len(losses) if losses else ZERO
    max_win = max(wins) if wins else ZERO
    max_loss = min(losses) if losses else ZERO  # Most negative

    profit_factor = total_wins / total_losses if total_losses > ZERO else ZERO

    avg_trade_pnl = sum((t.pnl for t in self.trades), ZERO) / total_trades if total_trades > 0 else ZERO

    # Total costs
    total_commission = sum((t.commission for t in self.trades), ZERO)
    total_slippage = sum((t.slippage for t in self.trades), ZERO)

    return BacktestResult(
        config=self.config,
        start_date=start_date,
        end_date=end_date,
        initial_capital=initial_capital,
        final_value=final_value,
        total_return=total_return,
        annualized_return=annualized_return,
        sharpe_ratio=sharpe_ratio,
        sortino_ratio=sortino_ratio,
        max_drawdown=max_drawdown,
        win_rate=win_rate,
        profit_factor=profit_factor,
        total_trades=total_trades,
        winning_trades=winning_trades,
        losing_trades=losing_trades,
        avg_trade_pnl=avg_trade_pnl,
        avg_win=avg_win,
        avg_loss=avg_loss,
        max_win=max_win,
        max_loss=max_loss,
        total_commission=total_commission,
        total_slippage=total_slippage,
        trades=self.trades,
        snapshots=self.snapshots,
        daily_returns=daily_returns,
        errors=errors,
    )
|
|
|
|
def get_position(self, symbol: str) -> Optional[BacktestPosition]:
    """Look up the position currently tracked for a symbol.

    Args:
        symbol: Symbol to look up

    Returns:
        The entry stored in ``self.positions`` for the symbol, or None
        when there is no such entry
    """
    tracked = self.positions
    return tracked.get(symbol)
|
|
|
|
def get_cash(self) -> Decimal:
    """Return the engine's current cash balance (``self.cash``).

    Returns:
        Cash balance
    """
    balance = self.cash
    return balance
|
|
|
|
def get_portfolio_value(self) -> Decimal:
    """Public accessor for the total portfolio value.

    Delegates to ``_get_portfolio_value``.

    Returns:
        Total portfolio value
    """
    portfolio_value = self._get_portfolio_value()
    return portfolio_value
|
|
|
|
|
|
# ============================================================================
|
|
# Factory Functions
|
|
# ============================================================================
|
|
|
|
def create_backtest_engine(
    initial_capital: Decimal = Decimal("100000"),
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    slippage: Optional[SlippageModel] = None,
    commission: Optional[CommissionModel] = None,
    **kwargs,
) -> BacktestEngine:
    """Create a configured backtest engine.

    Extra keyword arguments are forwarded to ``BacktestConfig`` only when
    the config accepts them. Anything else is left out and reported through
    a warning on the module logger, so a misspelled option does not vanish
    silently.

    Args:
        initial_capital: Starting capital
        start_date: Backtest start date
        end_date: Backtest end date
        slippage: Slippage model; ``NoSlippage()`` is used when None
        commission: Commission model; ``NoCommission()`` is used when None
        **kwargs: Additional config options

    Returns:
        Configured BacktestEngine
    """
    # Local import: the module header only pulls in ``dataclass`` and ``field``.
    from dataclasses import fields, is_dataclass

    # Options set from the named parameters; these always win over **kwargs.
    explicit = {
        "initial_capital": initial_capital,
        "start_date": start_date,
        "end_date": end_date,
        "slippage_model": NoSlippage() if slippage is None else slippage,
        "commission_model": NoCommission() if commission is None else commission,
    }

    if is_dataclass(BacktestConfig):
        # hasattr() is unreliable on a dataclass: fields with no default or a
        # default_factory are not class attributes, while methods are.
        init_fields = {f.name for f in fields(BacktestConfig) if f.init}

        def is_supported(name: str) -> bool:
            return name in init_fields
    else:
        def is_supported(name: str) -> bool:
            return hasattr(BacktestConfig, name)

    extra = {}
    ignored = []
    for key, value in kwargs.items():
        if key not in explicit and is_supported(key):
            extra[key] = value
        else:
            ignored.append(key)

    if ignored:
        logger.warning(
            "create_backtest_engine ignored unsupported config options: %s",
            ", ".join(sorted(ignored)),
        )

    config = BacktestConfig(**explicit, **extra)
    return BacktestEngine(config)
|