TradingAgents/tradingagents/backtest/strategy.py

488 lines
14 KiB
Python

"""
Strategy interface for backtesting.
This module provides abstract base classes and utilities for implementing
trading strategies, including TradingAgents integration.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from decimal import ROUND_DOWN, Decimal
from typing import Dict, List, Optional, Any, Tuple

import pandas as pd

from .execution import Order, OrderSide, create_market_order
from .exceptions import StrategyError, StrategyInitializationError
# Module-level logger; the strategy classes below log their lifecycle events
# (created / initialized / finalized) through it.
logger = logging.getLogger(__name__)
@dataclass
class Signal:
    """
    Trading signal generated by a strategy.

    Attributes:
        ticker: Security ticker
        timestamp: Signal timestamp
        action: Action ('buy', 'sell', 'hold')
        quantity: Suggested quantity (None = let position sizer decide)
        confidence: Signal confidence (0.0 to 1.0)
        price_target: Optional price target
        stop_loss: Optional stop loss
        metadata: Additional signal metadata; None is replaced with a fresh
            empty dict during construction

    Raises:
        ValueError: On construction, if ``action`` is not 'buy', 'sell' or
            'hold', or if ``confidence`` lies outside [0.0, 1.0].
    """

    ticker: str
    timestamp: datetime
    action: str  # 'buy', 'sell', 'hold'
    quantity: Optional[Decimal] = None
    confidence: float = 1.0
    price_target: Optional[Decimal] = None
    stop_loss: Optional[Decimal] = None
    # Defaults to None rather than {} because dataclasses reject mutable
    # defaults; __post_init__ gives every instance its own dict.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Validate the signal and normalize ``metadata``."""
        if self.action not in ['buy', 'sell', 'hold']:
            raise ValueError(f"Invalid action: {self.action}")
        if not (0.0 <= self.confidence <= 1.0):
            raise ValueError(f"Confidence must be between 0 and 1: {self.confidence}")
        if self.metadata is None:
            self.metadata = {}

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert signal to a plain dictionary.

        Decimal fields are converted to float and None stays None. A zero
        value (e.g. ``quantity=Decimal('0')``) is kept as ``0.0``; testing
        against None instead of truthiness keeps it from collapsing to None.
        """
        return {
            'ticker': self.ticker,
            'timestamp': self.timestamp,
            'action': self.action,
            'quantity': None if self.quantity is None else float(self.quantity),
            'confidence': self.confidence,
            'price_target': None if self.price_target is None else float(self.price_target),
            'stop_loss': None if self.stop_loss is None else float(self.stop_loss),
            'metadata': self.metadata,
        }
@dataclass
class Position:
    """
    Snapshot of the holding in a single security.

    Attributes:
        ticker: Security ticker
        quantity: Signed size (> 0 long, < 0 short, == 0 flat)
        avg_entry_price: Average entry price
        current_price: Latest market price
        unrealized_pnl: Unrealized P&L as supplied by the creator (not
            recomputed here)
        entry_timestamp: Timestamp of the first entry
    """

    ticker: str
    quantity: Decimal
    avg_entry_price: Decimal
    current_price: Decimal
    unrealized_pnl: Decimal
    entry_timestamp: datetime

    @property
    def market_value(self) -> Decimal:
        """Signed market value, ``quantity * current_price`` (negative for shorts)."""
        value = self.quantity * self.current_price
        return value

    @property
    def is_long(self) -> bool:
        """True when the quantity is positive."""
        return 0 < self.quantity

    @property
    def is_short(self) -> bool:
        """True when the quantity is negative."""
        return 0 > self.quantity

    @property
    def is_flat(self) -> bool:
        """True when there is no position (quantity equals zero)."""
        return self.quantity == 0

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view with numeric fields converted to float."""
        return dict(
            ticker=self.ticker,
            quantity=float(self.quantity),
            avg_entry_price=float(self.avg_entry_price),
            current_price=float(self.current_price),
            unrealized_pnl=float(self.unrealized_pnl),
            market_value=float(self.market_value),
            entry_timestamp=self.entry_timestamp,
        )
class BaseStrategy(ABC):
    """
    Common base for every backtest trading strategy.

    Subclasses must implement ``generate_signals``. The other lifecycle
    hooks (``initialize``, ``on_fill``, ``on_bar``, ``finalize``) have
    default implementations that subclasses may override.
    """

    def __init__(self, name: str = "BaseStrategy", params: Optional[Dict[str, Any]] = None):
        """
        Record the strategy's name and parameters.

        Args:
            name: Strategy name, used in log messages
            params: Strategy parameters; a falsy value is replaced by a new
                empty dict
        """
        self.name = name
        self.params = params if params else {}
        # Set to True by initialize().
        self._is_initialized = False
        logger.info(f"Strategy '{self.name}' created")

    @abstractmethod
    def generate_signals(
        self,
        timestamp: datetime,
        data: Dict[str, pd.DataFrame],
        positions: Dict[str, Position],
        portfolio_value: Decimal,
    ) -> List[Signal]:
        """
        Produce the trading signals for the current timestamp.

        Args:
            timestamp: Current timestamp
            data: Historical data per ticker (ticker -> DataFrame)
            positions: Open positions per ticker (ticker -> Position)
            portfolio_value: Current portfolio value

        Returns:
            The list of signals to act on (concrete subclasses define it).
        """

    def initialize(self, tickers: List[str], start_date: datetime) -> None:
        """
        Mark the strategy as ready before the backtest starts.

        Args:
            tickers: Tickers that will be traded (only their count is logged)
            start_date: Backtest start date (not used by this default)
        """
        self._is_initialized = True
        logger.info(f"Strategy '{self.name}' initialized with {len(tickers)} tickers")

    def on_fill(self, fill: 'Fill') -> None:
        """
        Hook invoked when an order is filled; does nothing by default.

        Args:
            fill: Fill information
        """
        # NOTE(review): 'Fill' is a string forward reference that is not
        # imported in this module; presumably it comes from .execution - confirm.
        return None

    def on_bar(
        self,
        timestamp: datetime,
        data: Dict[str, pd.DataFrame],
    ) -> None:
        """
        Hook invoked on every bar/period; does nothing by default.

        Args:
            timestamp: Current timestamp
            data: Current bar data
        """
        return None

    def finalize(self) -> None:
        """Hook invoked once the backtest ends; only logs by default."""
        logger.info(f"Strategy '{self.name}' finalized")
class BuyAndHoldStrategy(BaseStrategy):
    """
    Simple buy-and-hold strategy for benchmarking.

    Each ticker receives a single 'buy' signal, on the first call where it
    appears in ``data`` with a non-empty DataFrame and has no open position.
    After that the ticker is held and never bought again.
    """

    def __init__(self):
        """Initialize buy-and-hold strategy."""
        super().__init__(name="BuyAndHold")
        # Kept for backward compatibility: becomes True once any ticker has
        # been handled. Buying is gated per ticker (below), not by this flag.
        self._has_bought = False
        # Tickers already bought, or already held when first seen. They are
        # never bought again.
        self._handled_tickers = set()

    def generate_signals(
        self,
        timestamp: datetime,
        data: Dict[str, pd.DataFrame],
        positions: Dict[str, Position],
        portfolio_value: Decimal,
    ) -> List[Signal]:
        """
        Emit a 'buy' for every ticker seen with data for the first time.

        Args:
            timestamp: Current timestamp
            data: Historical data per ticker
            positions: Open positions per ticker
            portfolio_value: Current portfolio value (unused)

        Returns:
            One 'buy' signal per newly seen, flat ticker (possibly empty).
        """
        signals = []
        for ticker, df in data.items():
            if ticker in self._handled_tickers:
                continue
            # No rows yet (e.g. this ticker's history starts later): defer
            # to a later bar instead of skipping the ticker forever, which a
            # single "already bought" flag would do.
            if df is None or df.empty:
                continue
            self._handled_tickers.add(ticker)
            self._has_bought = True
            position = positions.get(ticker)
            if position is None or position.is_flat:
                signals.append(Signal(
                    ticker=ticker,
                    timestamp=timestamp,
                    action='buy',
                    confidence=1.0,
                ))
        return signals
class SimpleMovingAverageStrategy(BaseStrategy):
    """
    Simple moving average crossover strategy.

    Emits 'buy' when the short MA crosses above the long MA and the ticker
    has no open position. Emits 'sell' when the short MA crosses below the
    long MA while a long position is held. Only long positions are traded.
    """

    def __init__(self, short_window: int = 50, long_window: int = 200):
        """
        Initialize SMA strategy.

        Args:
            short_window: Short moving average window (number of rows)
            long_window: Long moving average window (number of rows);
                expected to exceed ``short_window`` (not enforced)
        """
        super().__init__(
            name="SMA_Crossover",
            params={'short_window': short_window, 'long_window': long_window}
        )
        self.short_window = short_window
        self.long_window = long_window

    def generate_signals(
        self,
        timestamp: datetime,
        data: Dict[str, pd.DataFrame],
        positions: Dict[str, Position],
        portfolio_value: Decimal,
    ) -> List[Signal]:
        """
        Generate signals based on SMA crossover.

        Each DataFrame in ``data`` must have a 'close' column. The last row
        is treated as the current bar and the row before it as the previous
        bar. Tickers without enough rows to compute both averages on both of
        those bars are skipped.

        Args:
            timestamp: Current timestamp
            data: Historical data per ticker
            positions: Open positions per ticker
            portfolio_value: Current portfolio value (unused)

        Returns:
            Crossover signals (possibly empty).
        """
        signals = []
        # A crossover compares both MAs on the current and previous row, so
        # only the last max(window) + 1 closes matter. Rolling over just that
        # tail avoids recomputing the whole history on every bar, which made
        # a full backtest quadratic in history length.
        lookback = max(self.short_window, self.long_window) + 1
        for ticker, df in data.items():
            # With fewer rows the previous-bar MA would be NaN, so no
            # crossover could be detected anyway.
            if len(df) < lookback:
                continue
            closes = df['close'].iloc[-lookback:]
            short_ma = closes.rolling(self.short_window).mean()
            long_ma = closes.rolling(self.long_window).mean()
            prev_short, current_short = short_ma.iloc[-2], short_ma.iloc[-1]
            prev_long, current_long = long_ma.iloc[-2], long_ma.iloc[-1]

            position = positions.get(ticker)
            if prev_short <= prev_long and current_short > current_long:
                # Bullish crossover: open a position only when flat.
                if position is None or position.is_flat:
                    signals.append(Signal(
                        ticker=ticker,
                        timestamp=timestamp,
                        action='buy',
                        confidence=0.8,
                        metadata={'signal_type': 'bullish_crossover'}
                    ))
            elif prev_short >= prev_long and current_short < current_long:
                # Bearish crossover: exit longs only. This strategy never
                # opens shorts, and a 'sell' against a short would enlarge it.
                if position is not None and position.is_long:
                    signals.append(Signal(
                        ticker=ticker,
                        timestamp=timestamp,
                        action='sell',
                        confidence=0.8,
                        metadata={'signal_type': 'bearish_crossover'}
                    ))
        return signals
class PositionSizer:
    """
    Position sizing logic.

    Determines how much capital to allocate to each trade and converts that
    allocation into a whole number of shares.
    """

    def __init__(self, method: str = 'equal_weight', params: Optional[Dict[str, Any]] = None):
        """
        Initialize position sizer.

        Args:
            method: Sizing method: 'equal_weight', 'fixed_amount' or
                'confidence_weighted'. An unknown name raises ValueError when
                ``calculate_position_size`` is called.
            params: Method-specific parameters:
                'num_positions' (equal_weight, default 10);
                'amount' (fixed_amount, default Decimal('10000')).
        """
        self.method = method
        self.params = params or {}

    def calculate_position_size(
        self,
        signal: Signal,
        portfolio_value: Decimal,
        current_price: Decimal,
        max_position_size: Optional[Decimal] = None,
    ) -> Decimal:
        """
        Calculate position size for a signal.

        Args:
            signal: Trading signal; an explicit ``signal.quantity`` wins
            portfolio_value: Current portfolio value
            current_price: Current price
            max_position_size: Maximum position size as fraction of portfolio
                (a falsy value means no cap)

        Returns:
            ``signal.quantity`` unchanged if set. Otherwise a whole number of
            shares, rounded toward zero so the cost never exceeds the
            allocation or the cap.

        Raises:
            ValueError: If ``self.method`` is not a supported method.
        """
        if signal.quantity is not None:
            return signal.quantity
        if self.method == 'equal_weight':
            return self._equal_weight(portfolio_value, current_price, max_position_size)
        if self.method == 'fixed_amount':
            return self._fixed_amount(portfolio_value, current_price, max_position_size)
        if self.method == 'confidence_weighted':
            return self._confidence_weighted(signal, portfolio_value, current_price, max_position_size)
        raise ValueError(f"Unknown position sizing method: {self.method}")

    def _equal_weight(
        self,
        portfolio_value: Decimal,
        current_price: Decimal,
        max_position_size: Optional[Decimal],
    ) -> Decimal:
        """Allocate ``portfolio_value / num_positions`` (default 10) to the trade."""
        num_positions = self.params.get('num_positions', 10)
        allocation = portfolio_value / Decimal(str(num_positions))
        return self._to_shares(allocation, portfolio_value, current_price, max_position_size)

    def _fixed_amount(
        self,
        portfolio_value: Decimal,
        current_price: Decimal,
        max_position_size: Optional[Decimal],
    ) -> Decimal:
        """Allocate the fixed ``amount`` parameter (default 10000) to the trade."""
        # Normalize via str() so int/float amounts from config also work
        # (float / Decimal would raise TypeError).
        amount = Decimal(str(self.params.get('amount', Decimal('10000'))))
        return self._to_shares(amount, portfolio_value, current_price, max_position_size)

    def _confidence_weighted(
        self,
        signal: Signal,
        portfolio_value: Decimal,
        current_price: Decimal,
        max_position_size: Optional[Decimal],
    ) -> Decimal:
        """Allocate 10% of the portfolio scaled by ``signal.confidence``."""
        base_allocation = portfolio_value * Decimal('0.1')  # 10% base
        weighted_allocation = base_allocation * Decimal(str(signal.confidence))
        return self._to_shares(weighted_allocation, portfolio_value, current_price, max_position_size)

    @staticmethod
    def _to_shares(
        allocation: Decimal,
        portfolio_value: Decimal,
        current_price: Decimal,
        max_position_size: Optional[Decimal],
    ) -> Decimal:
        """Cap ``allocation`` at the max fraction of the portfolio and convert to whole shares."""
        if max_position_size:
            allocation = min(allocation, portfolio_value * max_position_size)
        # Round toward zero: the default context rounding (ROUND_HALF_EVEN)
        # can round up and buy more than the allocation or cap allows.
        return (allocation / current_price).quantize(Decimal('1'), rounding=ROUND_DOWN)
class RiskManager:
    """
    Risk management logic.

    Enforces a per-position size limit, a gross leverage limit and a
    percentage stop loss.
    """

    def __init__(
        self,
        max_position_size: Optional[Decimal] = None,
        max_leverage: Decimal = Decimal('1.0'),
        stop_loss_pct: Optional[Decimal] = None,
    ):
        """
        Initialize risk manager.

        Args:
            max_position_size: Maximum position size as fraction of portfolio
                (a falsy value disables the check)
            max_leverage: Maximum gross exposure / portfolio value allowed
            stop_loss_pct: Stop loss percentage (e.g., 0.05 for 5%); a falsy
                value disables stop losses
        """
        self.max_position_size = max_position_size
        self.max_leverage = max_leverage
        self.stop_loss_pct = stop_loss_pct

    def check_signal(
        self,
        signal: Signal,
        positions: Dict[str, Position],
        portfolio_value: Decimal,
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if signal passes risk checks.

        Some signals can only reduce exposure: 'hold', a 'sell' against a long
        position, or a 'buy' against a short. These are always approved, so
        positions can still be exited after a limit has been reached. All
        other signals must pass the position-size and leverage limits.

        Args:
            signal: Trading signal
            positions: Current positions
            portfolio_value: Current portfolio value

        Returns:
            (approved, reason) tuple; reason is None when approved
        """
        position = positions.get(signal.ticker)
        if self._reduces_exposure(signal, position):
            return True, None

        # Both limits are ratios of portfolio value; zero would divide by
        # zero and a negative value would invert the comparisons.
        if portfolio_value <= 0:
            return False, "Non-positive portfolio value"

        # Check position limit
        if self.max_position_size and position is not None and not position.is_flat:
            position_pct = abs(position.market_value) / portfolio_value
            if position_pct >= self.max_position_size:
                return False, "Position size limit reached"

        # Check leverage (gross: longs and shorts both count)
        total_exposure = sum(
            abs(pos.market_value) for pos in positions.values()
        )
        leverage = total_exposure / portfolio_value
        if leverage >= self.max_leverage:
            return False, "Leverage limit reached"

        return True, None

    @staticmethod
    def _reduces_exposure(signal: Signal, position: Optional[Position]) -> bool:
        """Return True if ``signal`` can only shrink (or keep) the ticker's exposure."""
        if signal.action == 'hold':
            return True
        if position is None or position.is_flat:
            return False
        # NOTE(review): assumes a closing order does not exceed the open
        # quantity (signal.quantity is not inspected), i.e. it cannot flip
        # the position to the other side.
        return (
            (signal.action == 'sell' and position.is_long)
            or (signal.action == 'buy' and position.is_short)
        )

    def check_stop_loss(
        self,
        position: Position,
    ) -> bool:
        """
        Check if position hit stop loss.

        Args:
            position: Position to check

        Returns:
            True if the price has moved against the position by at least
            ``stop_loss_pct`` relative to the average entry price; False when
            stop losses are disabled, the position is flat, or the entry
            price is zero (the percentage move is undefined).
        """
        if not self.stop_loss_pct or position.is_flat:
            return False
        if position.avg_entry_price == 0:
            return False

        move_pct = (position.current_price - position.avg_entry_price) / position.avg_entry_price
        if position.is_long:
            return move_pct <= -self.stop_loss_pct
        if position.is_short:
            # A short loses when the price rises.
            return move_pct >= self.stop_loss_pct
        return False