588 lines
18 KiB
Python
588 lines
18 KiB
Python
"""
|
|
Historical data management for backtesting.
|
|
|
|
This module handles loading, validating, and managing historical price data
|
|
for backtesting, ensuring data quality and preventing look-ahead bias.
|
|
"""
|
|
|
|
import logging
|
|
import warnings
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Union, Tuple
|
|
from decimal import Decimal
|
|
import pickle
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
import yfinance as yf
|
|
from tqdm import tqdm
|
|
|
|
from tradingagents.security.validators import validate_ticker, validate_date
|
|
from .config import BacktestConfig, DataSource
|
|
from .exceptions import (
|
|
DataError,
|
|
DataNotFoundError,
|
|
DataQualityError,
|
|
DataAlignmentError,
|
|
LookAheadBiasError,
|
|
)
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class HistoricalDataHandler:
|
|
"""
|
|
Manages historical price data for backtesting.
|
|
|
|
This class provides point-in-time data access, ensuring no look-ahead bias
|
|
and handling data quality issues.
|
|
|
|
Attributes:
|
|
config: Backtest configuration
|
|
data: Dictionary mapping tickers to DataFrames with OHLCV data
|
|
current_time: Current simulation time (for look-ahead bias prevention)
|
|
"""
|
|
|
|
def __init__(self, config: BacktestConfig):
|
|
"""
|
|
Initialize the data handler.
|
|
|
|
Args:
|
|
config: Backtest configuration
|
|
"""
|
|
self.config = config
|
|
self.data: Dict[str, pd.DataFrame] = {}
|
|
self.current_time: Optional[datetime] = None
|
|
self._cache_dir = Path(config.cache_dir) if config.cache_dir else None
|
|
|
|
if self._cache_dir:
|
|
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
logger.info("HistoricalDataHandler initialized")
|
|
|
|
def load_data(
|
|
self,
|
|
tickers: Union[str, List[str]],
|
|
start_date: Optional[str] = None,
|
|
end_date: Optional[str] = None,
|
|
validate: bool = True,
|
|
) -> None:
|
|
"""
|
|
Load historical data for one or more tickers.
|
|
|
|
Args:
|
|
tickers: Ticker or list of tickers
|
|
start_date: Start date (defaults to config start_date)
|
|
end_date: End date (defaults to config end_date)
|
|
validate: Whether to validate data quality
|
|
|
|
Raises:
|
|
DataNotFoundError: If data cannot be loaded
|
|
DataQualityError: If data fails quality checks
|
|
"""
|
|
if isinstance(tickers, str):
|
|
tickers = [tickers]
|
|
|
|
start_date = start_date or self.config.start_date
|
|
end_date = end_date or self.config.end_date
|
|
|
|
logger.info(f"Loading data for {len(tickers)} ticker(s) from {start_date} to {end_date}")
|
|
|
|
for ticker in tqdm(tickers, desc="Loading data", disable=not self.config.progress_bar):
|
|
# Validate ticker
|
|
ticker = validate_ticker(ticker)
|
|
|
|
# Check cache first
|
|
if self.config.cache_data and self._cache_dir:
|
|
cached_data = self._load_from_cache(ticker, start_date, end_date)
|
|
if cached_data is not None:
|
|
self.data[ticker] = cached_data
|
|
logger.debug(f"Loaded {ticker} from cache")
|
|
continue
|
|
|
|
# Load from source
|
|
try:
|
|
data = self._load_from_source(ticker, start_date, end_date)
|
|
except Exception as e:
|
|
logger.error(f"Failed to load data for {ticker}: {e}")
|
|
raise DataNotFoundError(f"Could not load data for {ticker}: {e}")
|
|
|
|
# Validate data quality
|
|
if validate:
|
|
self._validate_data(ticker, data)
|
|
|
|
# Clean and prepare data
|
|
data = self._prepare_data(data)
|
|
|
|
# Store
|
|
self.data[ticker] = data
|
|
|
|
# Cache if enabled
|
|
if self.config.cache_data and self._cache_dir:
|
|
self._save_to_cache(ticker, data, start_date, end_date)
|
|
|
|
logger.info(f"Successfully loaded data for {len(self.data)} ticker(s)")
|
|
|
|
def _load_from_source(
|
|
self,
|
|
ticker: str,
|
|
start_date: str,
|
|
end_date: str
|
|
) -> pd.DataFrame:
|
|
"""
|
|
Load data from the configured data source.
|
|
|
|
Args:
|
|
ticker: Ticker symbol
|
|
start_date: Start date
|
|
end_date: End date
|
|
|
|
Returns:
|
|
DataFrame with OHLCV data
|
|
"""
|
|
if self.config.data_source == DataSource.YFINANCE:
|
|
return self._load_from_yfinance(ticker, start_date, end_date)
|
|
elif self.config.data_source == DataSource.CSV:
|
|
return self._load_from_csv(ticker, start_date, end_date)
|
|
else:
|
|
raise DataError(f"Unsupported data source: {self.config.data_source}")
|
|
|
|
def _load_from_yfinance(
|
|
self,
|
|
ticker: str,
|
|
start_date: str,
|
|
end_date: str
|
|
) -> pd.DataFrame:
|
|
"""Load data from Yahoo Finance."""
|
|
# Add buffer to account for data availability
|
|
buffer_start = (datetime.strptime(start_date, "%Y-%m-%d") - timedelta(days=5)).strftime("%Y-%m-%d")
|
|
|
|
try:
|
|
stock = yf.Ticker(ticker)
|
|
data = stock.history(start=buffer_start, end=end_date, auto_adjust=False)
|
|
|
|
if data.empty:
|
|
raise DataNotFoundError(f"No data returned for {ticker}")
|
|
|
|
# Standardize column names
|
|
data.columns = [col.lower().replace(' ', '_') for col in data.columns]
|
|
|
|
# Ensure we have required columns
|
|
required_cols = ['open', 'high', 'low', 'close', 'volume']
|
|
if not all(col in data.columns for col in required_cols):
|
|
raise DataQualityError(f"Missing required columns for {ticker}")
|
|
|
|
return data
|
|
|
|
except Exception as e:
|
|
raise DataError(f"Error loading data from yfinance for {ticker}: {e}")
|
|
|
|
def _load_from_csv(
|
|
self,
|
|
ticker: str,
|
|
start_date: str,
|
|
end_date: str
|
|
) -> pd.DataFrame:
|
|
"""Load data from CSV file."""
|
|
csv_path = Path(self.config.custom_params.get('csv_dir', 'data')) / f"{ticker}.csv"
|
|
|
|
if not csv_path.exists():
|
|
raise DataNotFoundError(f"CSV file not found: {csv_path}")
|
|
|
|
try:
|
|
data = pd.read_csv(csv_path, index_col=0, parse_dates=True)
|
|
|
|
# Filter date range
|
|
data = data[(data.index >= start_date) & (data.index <= end_date)]
|
|
|
|
# Standardize column names
|
|
data.columns = [col.lower().replace(' ', '_') for col in data.columns]
|
|
|
|
return data
|
|
|
|
except Exception as e:
|
|
raise DataError(f"Error loading CSV for {ticker}: {e}")
|
|
|
|
def _validate_data(self, ticker: str, data: pd.DataFrame) -> None:
|
|
"""
|
|
Validate data quality.
|
|
|
|
Args:
|
|
ticker: Ticker symbol
|
|
data: DataFrame to validate
|
|
|
|
Raises:
|
|
DataQualityError: If data fails validation
|
|
"""
|
|
if data.empty:
|
|
raise DataQualityError(f"Empty data for {ticker}")
|
|
|
|
# Check for required columns
|
|
required_cols = ['open', 'high', 'low', 'close', 'volume']
|
|
missing_cols = [col for col in required_cols if col not in data.columns]
|
|
if missing_cols:
|
|
raise DataQualityError(f"Missing columns for {ticker}: {missing_cols}")
|
|
|
|
# Check for excessive missing data
|
|
missing_pct = data[required_cols].isnull().sum() / len(data) * 100
|
|
high_missing = missing_pct[missing_pct > 10]
|
|
if not high_missing.empty:
|
|
warnings.warn(
|
|
f"High missing data percentage for {ticker}: {high_missing.to_dict()}",
|
|
UserWarning
|
|
)
|
|
|
|
# Check for price anomalies
|
|
for col in ['open', 'high', 'low', 'close']:
|
|
if (data[col] <= 0).any():
|
|
raise DataQualityError(f"Non-positive prices found in {col} for {ticker}")
|
|
|
|
# Check OHLC relationship
|
|
invalid_ohlc = (
|
|
(data['high'] < data['low']) |
|
|
(data['high'] < data['open']) |
|
|
(data['high'] < data['close']) |
|
|
(data['low'] > data['open']) |
|
|
(data['low'] > data['close'])
|
|
)
|
|
|
|
if invalid_ohlc.any():
|
|
warnings.warn(
|
|
f"Invalid OHLC relationships found for {ticker} on {invalid_ohlc.sum()} days",
|
|
UserWarning
|
|
)
|
|
|
|
# Check for suspicious price movements
|
|
returns = data['close'].pct_change()
|
|
extreme_returns = returns.abs() > 0.5 # 50% in one day
|
|
if extreme_returns.any():
|
|
warnings.warn(
|
|
f"Extreme price movements (>50%) detected for {ticker} on {extreme_returns.sum()} days",
|
|
UserWarning
|
|
)
|
|
|
|
logger.debug(f"Data validation passed for {ticker}")
|
|
|
|
def _prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
|
|
"""
|
|
Clean and prepare data for backtesting.
|
|
|
|
Args:
|
|
data: Raw data
|
|
|
|
Returns:
|
|
Cleaned data
|
|
"""
|
|
# Ensure datetime index
|
|
if not isinstance(data.index, pd.DatetimeIndex):
|
|
data.index = pd.to_datetime(data.index)
|
|
|
|
# Sort by date
|
|
data = data.sort_index()
|
|
|
|
# Remove duplicates
|
|
data = data[~data.index.duplicated(keep='first')]
|
|
|
|
# Forward fill missing data (conservative approach)
|
|
data = data.fillna(method='ffill')
|
|
|
|
# Handle any remaining NaNs
|
|
data = data.dropna()
|
|
|
|
return data
|
|
|
|
def _load_from_cache(
|
|
self,
|
|
ticker: str,
|
|
start_date: str,
|
|
end_date: str
|
|
) -> Optional[pd.DataFrame]:
|
|
"""Load data from cache if available."""
|
|
cache_file = self._cache_dir / f"{ticker}_{start_date}_{end_date}.pkl"
|
|
|
|
if cache_file.exists():
|
|
try:
|
|
with open(cache_file, 'rb') as f:
|
|
return pickle.load(f)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to load cache for {ticker}: {e}")
|
|
|
|
return None
|
|
|
|
def _save_to_cache(
|
|
self,
|
|
ticker: str,
|
|
data: pd.DataFrame,
|
|
start_date: str,
|
|
end_date: str
|
|
) -> None:
|
|
"""Save data to cache."""
|
|
cache_file = self._cache_dir / f"{ticker}_{start_date}_{end_date}.pkl"
|
|
|
|
try:
|
|
with open(cache_file, 'wb') as f:
|
|
pickle.dump(data, f)
|
|
logger.debug(f"Cached data for {ticker}")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to save cache for {ticker}: {e}")
|
|
|
|
def get_data_at(
|
|
self,
|
|
ticker: str,
|
|
timestamp: datetime,
|
|
lookback: Optional[int] = None
|
|
) -> pd.DataFrame:
|
|
"""
|
|
Get historical data up to a specific point in time.
|
|
|
|
This method ensures no look-ahead bias by only returning data
|
|
available at the specified timestamp.
|
|
|
|
Args:
|
|
ticker: Ticker symbol
|
|
timestamp: Point in time
|
|
lookback: Number of periods to look back (None = all available)
|
|
|
|
Returns:
|
|
DataFrame with historical data up to timestamp
|
|
|
|
Raises:
|
|
LookAheadBiasError: If timestamp is in the future
|
|
DataNotFoundError: If ticker not loaded
|
|
"""
|
|
if ticker not in self.data:
|
|
raise DataNotFoundError(f"Data not loaded for {ticker}")
|
|
|
|
if self.current_time and timestamp > self.current_time:
|
|
raise LookAheadBiasError(
|
|
f"Requested timestamp {timestamp} is in the future (current: {self.current_time})"
|
|
)
|
|
|
|
# Get data up to timestamp
|
|
data = self.data[ticker]
|
|
historical = data[data.index <= timestamp]
|
|
|
|
if lookback:
|
|
historical = historical.tail(lookback)
|
|
|
|
return historical.copy()
|
|
|
|
def get_price_at(
|
|
self,
|
|
ticker: str,
|
|
timestamp: datetime,
|
|
price_type: str = 'close'
|
|
) -> Decimal:
|
|
"""
|
|
Get price at a specific point in time.
|
|
|
|
Args:
|
|
ticker: Ticker symbol
|
|
timestamp: Point in time
|
|
price_type: Type of price ('open', 'high', 'low', 'close')
|
|
|
|
Returns:
|
|
Price as Decimal
|
|
|
|
Raises:
|
|
DataNotFoundError: If data not available
|
|
"""
|
|
data = self.get_data_at(ticker, timestamp, lookback=1)
|
|
|
|
if data.empty:
|
|
raise DataNotFoundError(f"No data available for {ticker} at {timestamp}")
|
|
|
|
price = data.iloc[-1][price_type]
|
|
return Decimal(str(price))
|
|
|
|
def set_current_time(self, timestamp: datetime) -> None:
|
|
"""
|
|
Set the current simulation time.
|
|
|
|
This is critical for preventing look-ahead bias.
|
|
|
|
Args:
|
|
timestamp: Current simulation timestamp
|
|
"""
|
|
if self.current_time and timestamp < self.current_time:
|
|
logger.warning(f"Time moving backwards: {self.current_time} -> {timestamp}")
|
|
|
|
self.current_time = timestamp
|
|
logger.debug(f"Current time set to {timestamp}")
|
|
|
|
def align_data(
|
|
self,
|
|
tickers: Optional[List[str]] = None,
|
|
method: str = 'inner'
|
|
) -> pd.DataFrame:
|
|
"""
|
|
Align data across multiple tickers.
|
|
|
|
Args:
|
|
tickers: List of tickers to align (None = all loaded)
|
|
method: Alignment method ('inner', 'outer', 'left', 'right')
|
|
|
|
Returns:
|
|
DataFrame with aligned close prices
|
|
|
|
Raises:
|
|
DataAlignmentError: If alignment fails
|
|
"""
|
|
if tickers is None:
|
|
tickers = list(self.data.keys())
|
|
|
|
if not tickers:
|
|
raise DataAlignmentError("No tickers to align")
|
|
|
|
try:
|
|
# Get close prices for all tickers
|
|
prices = pd.DataFrame()
|
|
|
|
for ticker in tickers:
|
|
if ticker not in self.data:
|
|
raise DataNotFoundError(f"Data not loaded for {ticker}")
|
|
|
|
prices[ticker] = self.data[ticker]['close']
|
|
|
|
# Align using specified method
|
|
if method == 'inner':
|
|
prices = prices.dropna()
|
|
elif method == 'outer':
|
|
prices = prices.fillna(method='ffill').fillna(method='bfill')
|
|
elif method in ['left', 'right']:
|
|
raise NotImplementedError(f"Alignment method '{method}' not implemented")
|
|
else:
|
|
raise ValueError(f"Unknown alignment method: {method}")
|
|
|
|
logger.info(f"Aligned {len(tickers)} tickers with {len(prices)} periods")
|
|
return prices
|
|
|
|
except Exception as e:
|
|
raise DataAlignmentError(f"Failed to align data: {e}")
|
|
|
|
def get_trading_days(
|
|
self,
|
|
start_date: Optional[str] = None,
|
|
end_date: Optional[str] = None
|
|
) -> pd.DatetimeIndex:
|
|
"""
|
|
Get trading days in the backtest period.
|
|
|
|
Args:
|
|
start_date: Start date (defaults to config start_date)
|
|
end_date: End date (defaults to config end_date)
|
|
|
|
Returns:
|
|
DatetimeIndex of trading days
|
|
"""
|
|
start_date = start_date or self.config.start_date
|
|
end_date = end_date or self.config.end_date
|
|
|
|
if not self.data:
|
|
raise DataError("No data loaded")
|
|
|
|
# Use first ticker's index as reference
|
|
reference_ticker = list(self.data.keys())[0]
|
|
all_dates = self.data[reference_ticker].index
|
|
|
|
# Filter to date range
|
|
trading_days = all_dates[
|
|
(all_dates >= start_date) & (all_dates <= end_date)
|
|
]
|
|
|
|
return trading_days
|
|
|
|
def check_survivor_bias(self, tickers: List[str]) -> None:
|
|
"""
|
|
Warn if using current constituents for historical backtest.
|
|
|
|
Args:
|
|
tickers: List of tickers being tested
|
|
"""
|
|
warnings.warn(
|
|
"SURVIVOR BIAS WARNING: Ensure the ticker list represents "
|
|
"securities that existed throughout the backtest period. "
|
|
"Using current index constituents for historical backtests "
|
|
"can lead to survivorship bias.",
|
|
UserWarning
|
|
)
|
|
logger.warning("Survivor bias check performed")
|
|
|
|
def get_corporate_actions(
|
|
self,
|
|
ticker: str,
|
|
start_date: Optional[str] = None,
|
|
end_date: Optional[str] = None
|
|
) -> pd.DataFrame:
|
|
"""
|
|
Get corporate actions (splits, dividends) for a ticker.
|
|
|
|
Args:
|
|
ticker: Ticker symbol
|
|
start_date: Start date
|
|
end_date: End date
|
|
|
|
Returns:
|
|
DataFrame with corporate actions
|
|
"""
|
|
if ticker not in self.data:
|
|
raise DataNotFoundError(f"Data not loaded for {ticker}")
|
|
|
|
# For yfinance, dividends and splits are included in history
|
|
start_date = start_date or self.config.start_date
|
|
end_date = end_date or self.config.end_date
|
|
|
|
try:
|
|
stock = yf.Ticker(ticker)
|
|
|
|
# Get splits and dividends
|
|
splits = stock.splits
|
|
dividends = stock.dividends
|
|
|
|
# Filter date range
|
|
if not splits.empty:
|
|
splits = splits[(splits.index >= start_date) & (splits.index <= end_date)]
|
|
|
|
if not dividends.empty:
|
|
dividends = dividends[(dividends.index >= start_date) & (dividends.index <= end_date)]
|
|
|
|
# Combine into single DataFrame
|
|
actions = pd.DataFrame()
|
|
if not splits.empty:
|
|
actions['splits'] = splits
|
|
if not dividends.empty:
|
|
actions['dividends'] = dividends
|
|
|
|
return actions
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Failed to get corporate actions for {ticker}: {e}")
|
|
return pd.DataFrame()
|
|
|
|
def summary(self) -> Dict[str, Any]:
|
|
"""
|
|
Get summary of loaded data.
|
|
|
|
Returns:
|
|
Dictionary with data summary
|
|
"""
|
|
if not self.data:
|
|
return {"tickers": 0, "message": "No data loaded"}
|
|
|
|
summary_dict = {
|
|
"tickers": len(self.data),
|
|
"ticker_list": list(self.data.keys()),
|
|
}
|
|
|
|
for ticker, data in self.data.items():
|
|
summary_dict[ticker] = {
|
|
"start_date": str(data.index.min().date()),
|
|
"end_date": str(data.index.max().date()),
|
|
"periods": len(data),
|
|
"missing_data_pct": data.isnull().sum().sum() / (len(data) * len(data.columns)) * 100,
|
|
}
|
|
|
|
return summary_dict
|