488 lines
14 KiB
Python
488 lines
14 KiB
Python
"""
Strategy interface for backtesting.

This module provides abstract base classes and utilities for implementing
trading strategies, including TradingAgents integration.
"""
|
|
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from decimal import Decimal
|
|
from typing import Dict, List, Optional, Any, Tuple
|
|
|
|
import pandas as pd
|
|
|
|
from .execution import Order, OrderSide, create_market_order
|
|
from .exceptions import StrategyError, StrategyInitializationError
|
|
|
|
|
|
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@dataclass
class Signal:
    """
    Trading signal generated by a strategy.

    Attributes:
        ticker: Security ticker
        timestamp: Signal timestamp
        action: Action ('buy', 'sell', 'hold')
        quantity: Suggested quantity (None = let position sizer decide)
        confidence: Signal confidence (0.0 to 1.0)
        price_target: Optional price target
        stop_loss: Optional stop loss
        metadata: Additional signal metadata (None is replaced by an empty dict)
    """

    ticker: str
    timestamp: datetime
    action: str  # 'buy', 'sell', 'hold'
    quantity: Optional[Decimal] = None
    confidence: float = 1.0
    price_target: Optional[Decimal] = None
    stop_loss: Optional[Decimal] = None
    # None is a sentinel: __post_init__ swaps in a fresh dict per instance, so
    # signals never share one mutable default.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """
        Validate signal.

        Raises:
            ValueError: If ``action`` is not 'buy', 'sell' or 'hold', or if
                ``confidence`` is outside [0, 1].
        """
        if self.action not in ('buy', 'sell', 'hold'):
            raise ValueError(f"Invalid action: {self.action}")

        # Chained comparison also rejects NaN (every NaN comparison is False).
        if not (0.0 <= self.confidence <= 1.0):
            raise ValueError(f"Confidence must be between 0 and 1: {self.confidence}")

        if self.metadata is None:
            self.metadata = {}

    @staticmethod
    def _optional_float(value: Optional[Decimal]) -> Optional[float]:
        """Convert *value* to float, mapping only None (never zero) to None."""
        return None if value is None else float(value)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert signal to dictionary.

        Decimal fields become floats. A field is None only when it was not set,
        so an explicit zero quantity/price stays distinguishable from
        "let the position sizer decide".
        """
        return {
            'ticker': self.ticker,
            'timestamp': self.timestamp,
            'action': self.action,
            'quantity': self._optional_float(self.quantity),
            'confidence': self.confidence,
            'price_target': self._optional_float(self.price_target),
            'stop_loss': self._optional_float(self.stop_loss),
            'metadata': self.metadata,
        }
|
|
|
|
|
|
@dataclass
class Position:
    """
    Snapshot of the holding in a single security.

    Attributes:
        ticker: Security ticker
        quantity: Signed position size; positive is long, negative is short
        avg_entry_price: Average price paid per unit
        current_price: Latest market price
        unrealized_pnl: Unrealized profit and loss
        entry_timestamp: Timestamp of the first entry
    """

    ticker: str
    quantity: Decimal
    avg_entry_price: Decimal
    current_price: Decimal
    unrealized_pnl: Decimal
    entry_timestamp: datetime

    @property
    def market_value(self) -> Decimal:
        """Signed value at the current price (quantity times current price)."""
        value = self.quantity * self.current_price
        return value

    @property
    def is_long(self) -> bool:
        """True when the quantity is positive."""
        return 0 < self.quantity

    @property
    def is_short(self) -> bool:
        """True when the quantity is negative."""
        return 0 > self.quantity

    @property
    def is_flat(self) -> bool:
        """True when nothing is held (quantity is zero)."""
        return not (self.quantity != 0)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the position; Decimal amounts are converted to float."""
        numeric_names = (
            'quantity',
            'avg_entry_price',
            'current_price',
            'unrealized_pnl',
            'market_value',
        )
        payload: Dict[str, Any] = {'ticker': self.ticker}
        for attr in numeric_names:
            payload[attr] = float(getattr(self, attr))
        payload['entry_timestamp'] = self.entry_timestamp
        return payload
|
|
|
|
|
|
class BaseStrategy(ABC):
    """
    Contract that every backtesting strategy fulfils.

    Only ``generate_signals`` is abstract and must be implemented by
    subclasses. ``initialize``, ``on_fill``, ``on_bar`` and ``finalize`` are
    lifecycle hooks with default implementations that subclasses may override.
    """

    def __init__(self, name: str = "BaseStrategy", params: Optional[Dict[str, Any]] = None):
        """
        Store the strategy's name and parameters.

        Args:
            name: Strategy name, also used in log messages
            params: Strategy parameters; any falsy value (None, {}) is
                replaced by a new empty dict
        """
        self.name = name
        self.params = params if params else {}
        # Becomes True once initialize() has run.
        self._is_initialized = False

        logger.info(f"Strategy '{self.name}' created")

    @abstractmethod
    def generate_signals(
        self,
        timestamp: datetime,
        data: Dict[str, pd.DataFrame],
        positions: Dict[str, Position],
        portfolio_value: Decimal,
    ) -> List[Signal]:
        """
        Produce the trading signals for the current bar.

        Args:
            timestamp: Timestamp of the current bar
            data: Historical data keyed by ticker (ticker -> DataFrame)
            positions: Open positions keyed by ticker (ticker -> Position)
            portfolio_value: Current total portfolio value

        Returns:
            The signals to act on
        """

    def initialize(self, tickers: List[str], start_date: datetime) -> None:
        """
        Prepare the strategy before the backtest starts.

        The default implementation only marks the strategy as initialized and
        logs how many tickers it will trade.

        Args:
            tickers: Tickers the strategy will trade
            start_date: First date of the backtest (unused by the default)
        """
        self._is_initialized = True
        ticker_count = len(tickers)
        logger.info(f"Strategy '{self.name}' initialized with {ticker_count} tickers")

    def on_fill(self, fill: 'Fill') -> None:
        """
        Hook invoked after an order is filled; does nothing by default.

        Args:
            fill: Fill information

        NOTE(review): ``Fill`` does not appear to be imported in this module
        (only Order, OrderSide and create_market_order come from
        ``.execution``), so this string annotation may not resolve via
        ``typing.get_type_hints``; presumably it means ``.execution.Fill`` --
        confirm.
        """

    def on_bar(
        self,
        timestamp: datetime,
        data: Dict[str, pd.DataFrame],
    ) -> None:
        """
        Hook invoked once per bar/period; does nothing by default.

        Args:
            timestamp: Timestamp of the current bar
            data: Bar data keyed by ticker
        """

    def finalize(self) -> None:
        """Hook invoked when the backtest ends; the default only logs."""
        logger.info(f"Strategy '{self.name}' finalized")
|
|
|
|
|
|
class BuyAndHoldStrategy(BaseStrategy):
    """
    Simple buy-and-hold strategy for benchmarking.

    Each ticker receives at most one ``buy`` signal: on the first bar where
    its frame in ``data`` has rows, and only if it is not already held.
    After that the ticker is simply held.
    """

    def __init__(self):
        """Initialize buy-and-hold strategy."""
        super().__init__(name="BuyAndHold")
        # True once at least one ticker has gone through the initial buy pass.
        self._has_bought = False
        # Tickers already bought (or found already held); never revisited.
        self._bought_tickers = set()

    def initialize(self, tickers: List[str], start_date: datetime) -> None:
        """
        Reset per-run state so one instance can be reused across backtests.

        Args:
            tickers: List of tickers to trade
            start_date: Backtest start date
        """
        super().initialize(tickers, start_date)
        self._has_bought = False
        self._bought_tickers = set()

    def generate_signals(
        self,
        timestamp: datetime,
        data: Dict[str, pd.DataFrame],
        positions: Dict[str, Position],
        portfolio_value: Decimal,
    ) -> List[Signal]:
        """
        Generate a one-time buy signal per ticker, then hold.

        Tickers are tracked individually rather than with a single
        "first bar" flag. A bar with no data therefore does not disable
        buying for the rest of the run. A ticker whose history is missing
        or empty at the start is still bought once rows for it appear.

        Args:
            timestamp: Current timestamp (stamped on each signal)
            data: Historical data per ticker
            positions: Current positions (ticker -> Position)
            portfolio_value: Current portfolio value (unused)

        Returns:
            Buy signals for tickers seen with data for the first time and
            not already held; empty otherwise
        """
        signals = []
        for ticker, df in data.items():
            # An empty frame has no bars for this ticker yet; retry later.
            if ticker in self._bought_tickers or df.empty:
                continue

            position = positions.get(ticker)
            if position is None or position.is_flat:
                signals.append(Signal(
                    ticker=ticker,
                    timestamp=timestamp,
                    action='buy',
                    confidence=1.0,
                ))
            # Bought now or already held: either way, hold from here on.
            self._bought_tickers.add(ticker)

        if self._bought_tickers:
            self._has_bought = True
        return signals
|
|
|
|
|
|
class SimpleMovingAverageStrategy(BaseStrategy):
    """
    Simple moving average crossover strategy.

    Buys when short MA crosses above long MA, sells when it crosses below.
    A buy is only emitted for a ticker with no (or a flat) position. A sell
    is only emitted for a ticker with a non-flat position.
    """

    def __init__(self, short_window: int = 50, long_window: int = 200):
        """
        Initialize SMA strategy.

        Args:
            short_window: Short moving average window (in bars)
            long_window: Long moving average window (in bars)
        """
        super().__init__(
            name="SMA_Crossover",
            params={'short_window': short_window, 'long_window': long_window},
        )
        self.short_window = short_window
        self.long_window = long_window

    def generate_signals(
        self,
        timestamp: datetime,
        data: Dict[str, pd.DataFrame],
        positions: Dict[str, Position],
        portfolio_value: Decimal,
    ) -> List[Signal]:
        """
        Generate signals based on SMA crossover.

        A crossover is detected by comparing both moving averages on the latest
        bar with their values on the previous bar.

        Args:
            timestamp: Current timestamp (stamped on every emitted signal)
            data: Per-ticker history; each frame must have a 'close' column
            positions: Current positions (ticker -> Position)
            portfolio_value: Current portfolio value (unused by this strategy)

        Returns:
            Buy/sell signals for tickers that crossed on this bar
        """
        signals = []
        # The current and previous MA values depend only on this many trailing
        # closes. Rolling over the whole history on every bar would cost
        # O(len(df)) per bar, i.e. quadratic over a backtest if `data` grows
        # bar by bar.
        lookback = max(self.short_window, self.long_window) + 1

        for ticker, df in data.items():
            if len(df) < self.long_window:
                continue

            closes = df['close'].iloc[-lookback:]
            if len(closes) < 2:
                continue  # need a previous bar to detect a crossover

            short_ma = closes.rolling(self.short_window).mean()
            long_ma = closes.rolling(self.long_window).mean()

            prev_short, current_short = short_ma.iloc[-2], short_ma.iloc[-1]
            prev_long, current_long = long_ma.iloc[-2], long_ma.iloc[-1]

            # A rolling mean is NaN until its window is full (e.g. the previous
            # long MA when len(df) == long_window) or if a close in the window
            # is missing. Skip explicitly rather than rely on NaN comparisons
            # all being False.
            if any(pd.isna(v) for v in (prev_short, current_short, prev_long, current_long)):
                continue

            current_position = positions.get(ticker)
            is_flat = current_position is None or current_position.is_flat

            if prev_short <= prev_long and current_short > current_long:
                # Bullish crossover: enter only when not already positioned.
                if is_flat:
                    signals.append(Signal(
                        ticker=ticker,
                        timestamp=timestamp,
                        action='buy',
                        confidence=0.8,
                        metadata={'signal_type': 'bullish_crossover'}
                    ))

            elif prev_short >= prev_long and current_short < current_long:
                # Bearish crossover: only meaningful with an open position.
                if not is_flat:
                    signals.append(Signal(
                        ticker=ticker,
                        timestamp=timestamp,
                        action='sell',
                        confidence=0.8,
                        metadata={'signal_type': 'bearish_crossover'}
                    ))

        return signals
|
|
|
|
|
|
class PositionSizer:
    """
    Position sizing logic.

    Determines how much capital to allocate to each trade and converts that
    allocation into a whole number of shares.
    """

    def __init__(self, method: str = 'equal_weight', params: Optional[Dict[str, Any]] = None):
        """
        Initialize position sizer.

        Args:
            method: Sizing method: 'equal_weight', 'fixed_amount' or
                'confidence_weighted'
            params: Method-specific parameters:
                - equal_weight: 'num_positions' (default 10)
                - fixed_amount: 'amount' of cash per trade (default 10000)
                - confidence_weighted: 'base_fraction' of the portfolio that
                  a confidence-1.0 signal receives (default 0.1)
        """
        self.method = method
        self.params = params or {}

    def calculate_position_size(
        self,
        signal: 'Signal',
        portfolio_value: Decimal,
        current_price: Decimal,
        max_position_size: Optional[Decimal] = None,
    ) -> Decimal:
        """
        Calculate position size for a signal.

        Args:
            signal: Trading signal
            portfolio_value: Current portfolio value
            current_price: Current price
            max_position_size: Maximum position size as fraction of portfolio
                (a falsy value means no cap)

        Returns:
            Position size (number of shares). An explicit ``signal.quantity``
            is returned unchanged. Otherwise the size is rounded down to whole
            shares, so it never exceeds the allocation or the cap.

        Raises:
            ValueError: If ``self.method`` is not a supported sizing method.
        """
        if signal.quantity is not None:
            return signal.quantity

        if self.method == 'equal_weight':
            return self._equal_weight(portfolio_value, current_price, max_position_size)
        if self.method == 'fixed_amount':
            return self._fixed_amount(portfolio_value, current_price, max_position_size)
        if self.method == 'confidence_weighted':
            return self._confidence_weighted(signal, portfolio_value, current_price, max_position_size)
        raise ValueError(f"Unknown position sizing method: {self.method}")

    @staticmethod
    def _apply_cap(
        allocation: Decimal,
        portfolio_value: Decimal,
        max_position_size: Optional[Decimal],
    ) -> Decimal:
        """Limit a cash allocation to ``max_position_size`` of the portfolio."""
        if max_position_size:
            allocation = min(allocation, portfolio_value * max_position_size)
        return allocation

    @staticmethod
    def _to_whole_shares(allocation: Decimal, current_price: Decimal) -> Decimal:
        """Convert a cash allocation into whole shares, rounding down."""
        from decimal import ROUND_DOWN

        # Truncate instead of the default half-even rounding. Rounding up
        # (e.g. 9.6 -> 10) would spend more than the allocation and could
        # breach the max_position_size cap.
        return (allocation / current_price).quantize(Decimal('1'), rounding=ROUND_DOWN)

    def _equal_weight(
        self,
        portfolio_value: Decimal,
        current_price: Decimal,
        max_position_size: Optional[Decimal],
    ) -> Decimal:
        """Equal weight position sizing: portfolio split across 'num_positions'."""
        num_positions = self.params.get('num_positions', 10)
        allocation = portfolio_value / Decimal(str(num_positions))
        allocation = self._apply_cap(allocation, portfolio_value, max_position_size)
        return self._to_whole_shares(allocation, current_price)

    def _fixed_amount(
        self,
        portfolio_value: Decimal,
        current_price: Decimal,
        max_position_size: Optional[Decimal],
    ) -> Decimal:
        """Fixed cash amount per trade, subject to the same cap as other methods."""
        # str() round-trip accepts int/float/Decimal amounts alike; dividing a
        # float directly by a Decimal would raise TypeError.
        amount = Decimal(str(self.params.get('amount', Decimal('10000'))))
        allocation = self._apply_cap(amount, portfolio_value, max_position_size)
        return self._to_whole_shares(allocation, current_price)

    def _confidence_weighted(
        self,
        signal: 'Signal',
        portfolio_value: Decimal,
        current_price: Decimal,
        max_position_size: Optional[Decimal],
    ) -> Decimal:
        """Confidence-weighted sizing: 'base_fraction' of portfolio scaled by confidence."""
        base_fraction = Decimal(str(self.params.get('base_fraction', Decimal('0.1'))))
        base_allocation = portfolio_value * base_fraction
        weighted_allocation = base_allocation * Decimal(str(signal.confidence))
        weighted_allocation = self._apply_cap(weighted_allocation, portfolio_value, max_position_size)
        return self._to_whole_shares(weighted_allocation, current_price)
|
|
|
|
|
|
class RiskManager:
    """
    Risk management logic.

    Enforces risk controls like stop losses, position limits, etc.
    """

    def __init__(
        self,
        max_position_size: Optional[Decimal] = None,
        max_leverage: Decimal = Decimal('1.0'),
        stop_loss_pct: Optional[Decimal] = None,
    ):
        """
        Initialize risk manager.

        Args:
            max_position_size: Maximum position size as fraction of portfolio
                (a falsy value such as None disables the check)
            max_leverage: Maximum leverage allowed (gross exposure divided by
                portfolio value)
            stop_loss_pct: Stop loss percentage (e.g., 0.05 for 5%); a falsy
                value disables stop-loss checks
        """
        self.max_position_size = max_position_size
        self.max_leverage = max_leverage
        self.stop_loss_pct = stop_loss_pct

    @staticmethod
    def _reduces_exposure(signal: 'Signal', position: Optional['Position']) -> bool:
        """
        Return True if *signal* trades against an existing position.

        That means a 'sell' against a long or a 'buy' against a short. An
        explicit quantity (if any) must not exceed the current size, because
        a larger order would pass through zero and open opposite exposure.
        NOTE(review): with ``quantity=None`` the amount is decided downstream;
        it is assumed here to be an exit -- confirm against the execution engine.
        """
        if position is None or position.is_flat:
            return False
        against_position = (
            (signal.action == 'sell' and position.is_long)
            or (signal.action == 'buy' and position.is_short)
        )
        if not against_position:
            return False
        return signal.quantity is None or abs(signal.quantity) <= abs(position.quantity)

    def check_signal(
        self,
        signal: 'Signal',
        positions: Dict[str, 'Position'],
        portfolio_value: Decimal,
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if signal passes risk checks.

        Signals that only shrink an existing position are always approved.
        Limits exist to stop risk from growing, and blocking exits would leave
        a portfolio at its limit unable to de-risk.

        Args:
            signal: Trading signal
            positions: Current positions
            portfolio_value: Current portfolio value

        Returns:
            (approved, reason) tuple; reason is None when approved
        """
        position = positions.get(signal.ticker)

        if self._reduces_exposure(signal, position):
            return True, None

        # Exposure ratios are undefined at zero equity (Decimal division fails)
        # and meaningless when negative (leverage would come out negative and
        # pass every limit), so refuse to add risk.
        if portfolio_value <= 0:
            return False, "Non-positive portfolio value"

        # Check position limit
        if self.max_position_size:
            if position is not None and not position.is_flat:
                position_pct = abs(position.market_value) / portfolio_value
                if position_pct >= self.max_position_size:
                    return False, "Position size limit reached"

        # Check leverage (gross exposure across all positions)
        total_exposure = sum(
            (abs(pos.market_value) for pos in positions.values()),
            Decimal(0),
        )
        leverage = total_exposure / portfolio_value
        if leverage >= self.max_leverage:
            return False, "Leverage limit reached"

        return True, None

    def check_stop_loss(
        self,
        position: 'Position',
    ) -> bool:
        """
        Check if position hit stop loss.

        The move is measured relative to ``avg_entry_price``. A long is
        stopped out when price has fallen by at least ``stop_loss_pct``, and a
        short when price has risen by at least ``stop_loss_pct``.

        Args:
            position: Position to check

        Returns:
            True if stop loss triggered; always False when stop losses are
            disabled or the position is flat
        """
        if not self.stop_loss_pct or position.is_flat:
            return False

        loss_pct = (position.current_price - position.avg_entry_price) / position.avg_entry_price

        if position.is_long and loss_pct <= -self.stop_loss_pct:
            return True

        if position.is_short and loss_pct >= self.stop_loss_pct:
            return True

        return False
|