""" Strategy interface for backtesting. This module provides abstract base classes and utilities for implementing trading strategies, including TradingAgents integration. """ import logging from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime from decimal import Decimal from typing import Dict, List, Optional, Any, Tuple import pandas as pd from .execution import Order, OrderSide, create_market_order from .exceptions import StrategyError, StrategyInitializationError logger = logging.getLogger(__name__) @dataclass class Signal: """ Trading signal generated by a strategy. Attributes: ticker: Security ticker timestamp: Signal timestamp action: Action ('buy', 'sell', 'hold') quantity: Suggested quantity (None = let position sizer decide) confidence: Signal confidence (0.0 to 1.0) price_target: Optional price target stop_loss: Optional stop loss metadata: Additional signal metadata """ ticker: str timestamp: datetime action: str # 'buy', 'sell', 'hold' quantity: Optional[Decimal] = None confidence: float = 1.0 price_target: Optional[Decimal] = None stop_loss: Optional[Decimal] = None metadata: Dict[str, Any] = None def __post_init__(self): """Validate signal.""" if self.action not in ['buy', 'sell', 'hold']: raise ValueError(f"Invalid action: {self.action}") if not (0.0 <= self.confidence <= 1.0): raise ValueError(f"Confidence must be between 0 and 1: {self.confidence}") if self.metadata is None: self.metadata = {} def to_dict(self) -> Dict[str, Any]: """Convert signal to dictionary.""" return { 'ticker': self.ticker, 'timestamp': self.timestamp, 'action': self.action, 'quantity': float(self.quantity) if self.quantity else None, 'confidence': self.confidence, 'price_target': float(self.price_target) if self.price_target else None, 'stop_loss': float(self.stop_loss) if self.stop_loss else None, 'metadata': self.metadata, } @dataclass class Position: """ Current position in a security. Attributes: ticker: Security ticker quantity: Position size (positive = long, negative = short) avg_entry_price: Average entry price current_price: Current market price unrealized_pnl: Unrealized P&L entry_timestamp: First entry timestamp """ ticker: str quantity: Decimal avg_entry_price: Decimal current_price: Decimal unrealized_pnl: Decimal entry_timestamp: datetime @property def market_value(self) -> Decimal: """Get current market value of position.""" return self.quantity * self.current_price @property def is_long(self) -> bool: """Check if position is long.""" return self.quantity > 0 @property def is_short(self) -> bool: """Check if position is short.""" return self.quantity < 0 @property def is_flat(self) -> bool: """Check if position is flat (no position).""" return self.quantity == 0 def to_dict(self) -> Dict[str, Any]: """Convert position to dictionary.""" return { 'ticker': self.ticker, 'quantity': float(self.quantity), 'avg_entry_price': float(self.avg_entry_price), 'current_price': float(self.current_price), 'unrealized_pnl': float(self.unrealized_pnl), 'market_value': float(self.market_value), 'entry_timestamp': self.entry_timestamp, } class BaseStrategy(ABC): """ Abstract base class for trading strategies. All strategies must implement the generate_signals method. """ def __init__(self, name: str = "BaseStrategy", params: Optional[Dict[str, Any]] = None): """ Initialize strategy. Args: name: Strategy name params: Strategy parameters """ self.name = name self.params = params or {} self._is_initialized = False logger.info(f"Strategy '{self.name}' created") @abstractmethod def generate_signals( self, timestamp: datetime, data: Dict[str, pd.DataFrame], positions: Dict[str, Position], portfolio_value: Decimal, ) -> List[Signal]: """ Generate trading signals. Args: timestamp: Current timestamp data: Historical data for all tickers (ticker -> DataFrame) positions: Current positions (ticker -> Position) portfolio_value: Current portfolio value Returns: List of signals """ pass def initialize(self, tickers: List[str], start_date: datetime) -> None: """ Initialize strategy before backtesting. Args: tickers: List of tickers to trade start_date: Backtest start date """ self._is_initialized = True logger.info(f"Strategy '{self.name}' initialized with {len(tickers)} tickers") def on_fill(self, fill: 'Fill') -> None: """ Called when an order is filled. Args: fill: Fill information """ pass def on_bar( self, timestamp: datetime, data: Dict[str, pd.DataFrame], ) -> None: """ Called on each bar/period. Args: timestamp: Current timestamp data: Current bar data """ pass def finalize(self) -> None: """Called at the end of backtesting.""" logger.info(f"Strategy '{self.name}' finalized") class BuyAndHoldStrategy(BaseStrategy): """Simple buy-and-hold strategy for benchmarking.""" def __init__(self): """Initialize buy-and-hold strategy.""" super().__init__(name="BuyAndHold") self._has_bought = False def generate_signals( self, timestamp: datetime, data: Dict[str, pd.DataFrame], positions: Dict[str, Position], portfolio_value: Decimal, ) -> List[Signal]: """Generate buy signals on first bar, then hold.""" if self._has_bought: return [] signals = [] for ticker in data.keys(): if ticker not in positions or positions[ticker].is_flat: signals.append(Signal( ticker=ticker, timestamp=timestamp, action='buy', confidence=1.0, )) self._has_bought = True return signals class SimpleMovingAverageStrategy(BaseStrategy): """ Simple moving average crossover strategy. Buys when short MA crosses above long MA, sells when it crosses below. """ def __init__(self, short_window: int = 50, long_window: int = 200): """ Initialize SMA strategy. Args: short_window: Short moving average window long_window: Long moving average window """ super().__init__( name="SMA_Crossover", params={'short_window': short_window, 'long_window': long_window} ) self.short_window = short_window self.long_window = long_window def generate_signals( self, timestamp: datetime, data: Dict[str, pd.DataFrame], positions: Dict[str, Position], portfolio_value: Decimal, ) -> List[Signal]: """Generate signals based on SMA crossover.""" signals = [] for ticker, df in data.items(): if len(df) < self.long_window: continue # Calculate moving averages short_ma = df['close'].rolling(self.short_window).mean() long_ma = df['close'].rolling(self.long_window).mean() # Get current and previous values current_short = short_ma.iloc[-1] current_long = long_ma.iloc[-1] prev_short = short_ma.iloc[-2] if len(short_ma) > 1 else None prev_long = long_ma.iloc[-2] if len(long_ma) > 1 else None if prev_short is None or prev_long is None: continue # Check for crossover current_position = positions.get(ticker) # Bullish crossover if prev_short <= prev_long and current_short > current_long: if not current_position or current_position.is_flat: signals.append(Signal( ticker=ticker, timestamp=timestamp, action='buy', confidence=0.8, metadata={'signal_type': 'bullish_crossover'} )) # Bearish crossover elif prev_short >= prev_long and current_short < current_long: if current_position and not current_position.is_flat: signals.append(Signal( ticker=ticker, timestamp=timestamp, action='sell', confidence=0.8, metadata={'signal_type': 'bearish_crossover'} )) return signals class PositionSizer: """ Position sizing logic. Determines how much capital to allocate to each trade. """ def __init__(self, method: str = 'equal_weight', params: Optional[Dict[str, Any]] = None): """ Initialize position sizer. Args: method: Sizing method ('equal_weight', 'fixed_amount', 'risk_parity', etc.) params: Method-specific parameters """ self.method = method self.params = params or {} def calculate_position_size( self, signal: Signal, portfolio_value: Decimal, current_price: Decimal, max_position_size: Optional[Decimal] = None, ) -> Decimal: """ Calculate position size for a signal. Args: signal: Trading signal portfolio_value: Current portfolio value current_price: Current price max_position_size: Maximum position size as fraction of portfolio Returns: Position size (number of shares) """ if signal.quantity is not None: return signal.quantity if self.method == 'equal_weight': return self._equal_weight(portfolio_value, current_price, max_position_size) elif self.method == 'fixed_amount': fixed_amount = self.params.get('amount', Decimal('10000')) return fixed_amount / current_price elif self.method == 'confidence_weighted': return self._confidence_weighted(signal, portfolio_value, current_price, max_position_size) else: raise ValueError(f"Unknown position sizing method: {self.method}") def _equal_weight( self, portfolio_value: Decimal, current_price: Decimal, max_position_size: Optional[Decimal], ) -> Decimal: """Equal weight position sizing.""" num_positions = self.params.get('num_positions', 10) allocation = portfolio_value / Decimal(str(num_positions)) if max_position_size: allocation = min(allocation, portfolio_value * max_position_size) return (allocation / current_price).quantize(Decimal('1')) def _confidence_weighted( self, signal: Signal, portfolio_value: Decimal, current_price: Decimal, max_position_size: Optional[Decimal], ) -> Decimal: """Confidence-weighted position sizing.""" base_allocation = portfolio_value * Decimal('0.1') # 10% base weighted_allocation = base_allocation * Decimal(str(signal.confidence)) if max_position_size: weighted_allocation = min(weighted_allocation, portfolio_value * max_position_size) return (weighted_allocation / current_price).quantize(Decimal('1')) class RiskManager: """ Risk management logic. Enforces risk controls like stop losses, position limits, etc. """ def __init__( self, max_position_size: Optional[Decimal] = None, max_leverage: Decimal = Decimal('1.0'), stop_loss_pct: Optional[Decimal] = None, ): """ Initialize risk manager. Args: max_position_size: Maximum position size as fraction of portfolio max_leverage: Maximum leverage allowed stop_loss_pct: Stop loss percentage (e.g., 0.05 for 5%) """ self.max_position_size = max_position_size self.max_leverage = max_leverage self.stop_loss_pct = stop_loss_pct def check_signal( self, signal: Signal, positions: Dict[str, Position], portfolio_value: Decimal, ) -> Tuple[bool, Optional[str]]: """ Check if signal passes risk checks. Args: signal: Trading signal positions: Current positions portfolio_value: Current portfolio value Returns: (approved, reason) tuple """ # Check position limit if self.max_position_size: position = positions.get(signal.ticker) if position and not position.is_flat: position_pct = abs(position.market_value) / portfolio_value if position_pct >= self.max_position_size: return False, "Position size limit reached" # Check leverage total_exposure = sum( abs(pos.market_value) for pos in positions.values() ) leverage = total_exposure / portfolio_value if leverage >= self.max_leverage: return False, "Leverage limit reached" return True, None def check_stop_loss( self, position: Position, ) -> bool: """ Check if position hit stop loss. Args: position: Position to check Returns: True if stop loss triggered """ if not self.stop_loss_pct or position.is_flat: return False loss_pct = (position.current_price - position.avg_entry_price) / position.avg_entry_price if position.is_long and loss_pct <= -self.stop_loss_pct: return True if position.is_short and loss_pct >= self.stop_loss_pct: return True return False