# TradingAgents/tradingagents/backtest/data_handler.py
#
# 588 lines
# 18 KiB
# Python

"""
Historical data management for backtesting.
This module handles loading, validating, and managing historical price data
for backtesting, ensuring data quality and preventing look-ahead bias.
"""
import logging
import pickle
import warnings
from datetime import datetime, timedelta
from decimal import Decimal
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import yfinance as yf
from tqdm import tqdm

from tradingagents.security.validators import validate_ticker, validate_date
from .config import BacktestConfig, DataSource
from .exceptions import (
    DataError,
    DataNotFoundError,
    DataQualityError,
    DataAlignmentError,
    LookAheadBiasError,
)
logger = logging.getLogger(__name__)
class HistoricalDataHandler:
    """
    Manages historical price data for backtesting.

    This class provides point-in-time data access, ensuring no look-ahead bias
    and handling data quality issues.

    Attributes:
        config: Backtest configuration
        data: Dictionary mapping tickers to DataFrames with OHLCV data;
            starts empty and is filled by ``load_data``
        current_time: Current simulation time (for look-ahead bias prevention);
            ``None`` until ``set_current_time`` is called
    """
def __init__(self, config: BacktestConfig):
    """
    Set up an empty handler bound to a backtest configuration.

    When the configuration names a cache directory it is created up front
    (including missing parents) so later cache reads/writes can assume it
    exists.

    Args:
        config: Backtest configuration
    """
    self.config = config
    self.data: Dict[str, pd.DataFrame] = {}
    self.current_time: Optional[datetime] = None

    # A falsy cache_dir (None, "") disables on-disk caching entirely.
    if config.cache_dir:
        self._cache_dir = Path(config.cache_dir)
        self._cache_dir.mkdir(parents=True, exist_ok=True)
    else:
        self._cache_dir = None

    logger.info("HistoricalDataHandler initialized")
def load_data(
    self,
    tickers: Union[str, List[str]],
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    validate: bool = True,
) -> None:
    """
    Load, check, clean and store price history for the given ticker(s).

    For every ticker the cache is consulted first (when caching is enabled);
    on a miss the data is fetched from the configured source, optionally
    validated, cleaned, stored in ``self.data`` and written to the cache.

    Args:
        tickers: Ticker or list of tickers
        start_date: Start date (defaults to config start_date)
        end_date: End date (defaults to config end_date)
        validate: Whether to validate data quality

    Raises:
        DataNotFoundError: If data cannot be loaded
        DataQualityError: If data fails quality checks
    """
    symbols = [tickers] if isinstance(tickers, str) else tickers
    start_date = start_date or self.config.start_date
    end_date = end_date or self.config.end_date

    logger.info(f"Loading data for {len(symbols)} ticker(s) from {start_date} to {end_date}")

    progress = tqdm(symbols, desc="Loading data", disable=not self.config.progress_bar)
    for raw_symbol in progress:
        symbol = validate_ticker(raw_symbol)
        caching = self.config.cache_data and self._cache_dir

        # Cache hit: store as-is and move on (cached frames were already
        # validated and cleaned when they were first written).
        if caching:
            hit = self._load_from_cache(symbol, start_date, end_date)
            if hit is not None:
                self.data[symbol] = hit
                logger.debug(f"Loaded {symbol} from cache")
                continue

        # Cache miss: go to the configured source.
        try:
            frame = self._load_from_source(symbol, start_date, end_date)
        except Exception as e:
            logger.error(f"Failed to load data for {symbol}: {e}")
            raise DataNotFoundError(f"Could not load data for {symbol}: {e}")

        # Validation runs on the raw frame, before cleaning.
        if validate:
            self._validate_data(symbol, frame)

        frame = self._prepare_data(frame)
        self.data[symbol] = frame

        if caching:
            self._save_to_cache(symbol, frame, start_date, end_date)

    logger.info(f"Successfully loaded data for {len(self.data)} ticker(s)")
def _load_from_source(
    self,
    ticker: str,
    start_date: str,
    end_date: str
) -> pd.DataFrame:
    """
    Dispatch to the loader matching ``config.data_source``.

    Args:
        ticker: Ticker symbol
        start_date: Start date
        end_date: End date

    Returns:
        DataFrame with OHLCV data

    Raises:
        DataError: If the configured data source has no loader
    """
    source = self.config.data_source
    if source == DataSource.YFINANCE:
        loader = self._load_from_yfinance
    elif source == DataSource.CSV:
        loader = self._load_from_csv
    else:
        raise DataError(f"Unsupported data source: {self.config.data_source}")
    return loader(ticker, start_date, end_date)
def _load_from_yfinance(
    self,
    ticker: str,
    start_date: str,
    end_date: str
) -> pd.DataFrame:
    """
    Load daily OHLCV data from Yahoo Finance.

    The request starts five calendar days before ``start_date`` (buffer for
    data availability) and is extended one day past ``end_date`` because
    yfinance treats its ``end`` argument as exclusive; without that the
    configured end date itself would never be returned, unlike the CSV
    loader which filters with ``<= end_date``.

    Args:
        ticker: Ticker symbol
        start_date: Start date as ``YYYY-MM-DD``
        end_date: End date (inclusive) as ``YYYY-MM-DD``

    Returns:
        DataFrame with lower-cased, underscore-separated column names that
        include at least open/high/low/close/volume

    Raises:
        DataError: If the download fails, returns nothing, or lacks a
            required column (the original error is chained as the cause)
    """
    date_format = "%Y-%m-%d"

    # Add buffer to account for data availability
    buffer_start = (
        datetime.strptime(start_date, date_format) - timedelta(days=5)
    ).strftime(date_format)

    # Make the end date inclusive (yfinance's ``end`` is exclusive).
    try:
        fetch_end = (
            datetime.strptime(end_date, date_format) + timedelta(days=1)
        ).strftime(date_format)
    except (TypeError, ValueError):
        # Not a plain YYYY-MM-DD string; hand it to yfinance unchanged.
        fetch_end = end_date

    try:
        stock = yf.Ticker(ticker)
        data = stock.history(start=buffer_start, end=fetch_end, auto_adjust=False)
        if data.empty:
            raise DataNotFoundError(f"No data returned for {ticker}")

        # Standardize column names
        data.columns = [col.lower().replace(' ', '_') for col in data.columns]

        # Ensure we have required columns
        required_cols = ['open', 'high', 'low', 'close', 'volume']
        if not all(col in data.columns for col in required_cols):
            raise DataQualityError(f"Missing required columns for {ticker}")
        return data
    except Exception as e:
        # Everything, including the two errors raised above, surfaces as
        # DataError; chain the cause so the original traceback is kept.
        raise DataError(f"Error loading data from yfinance for {ticker}: {e}") from e
def _load_from_csv(
    self,
    ticker: str,
    start_date: str,
    end_date: str
) -> pd.DataFrame:
    """
    Load data for ``ticker`` from ``<csv_dir>/<ticker>.csv``.

    ``csv_dir`` comes from ``config.custom_params`` and defaults to
    ``data``. The first CSV column is used as a date-parsed index, rows are
    restricted to ``[start_date, end_date]`` (both inclusive) and column
    names are lower-cased with spaces replaced by underscores.

    Raises:
        DataNotFoundError: If the CSV file does not exist
        DataError: If reading or filtering the file fails
    """
    csv_dir = self.config.custom_params.get('csv_dir', 'data')
    csv_path = Path(csv_dir) / f"{ticker}.csv"
    if not csv_path.exists():
        raise DataNotFoundError(f"CSV file not found: {csv_path}")

    try:
        frame = pd.read_csv(csv_path, index_col=0, parse_dates=True)

        # Keep only rows inside the requested window.
        in_window = (frame.index >= start_date) & (frame.index <= end_date)
        frame = frame[in_window]

        # Normalise headers such as "Adj Close" -> "adj_close".
        frame.columns = [name.lower().replace(' ', '_') for name in frame.columns]
        return frame
    except Exception as e:
        raise DataError(f"Error loading CSV for {ticker}: {e}")
def _validate_data(self, ticker: str, data: pd.DataFrame) -> None:
    """
    Run quality checks on a raw OHLCV frame.

    Hard failures (empty frame, missing required column, non-positive
    price) raise; softer findings (more than 10% missing values in a
    column, inconsistent OHLC rows, one-period close moves above 50%) only
    emit a ``UserWarning``.

    Args:
        ticker: Ticker symbol
        data: DataFrame to validate

    Raises:
        DataQualityError: If data fails validation
    """
    if data.empty:
        raise DataQualityError(f"Empty data for {ticker}")

    # Required columns must all be present before anything else is checked.
    required_cols = ['open', 'high', 'low', 'close', 'volume']
    missing_cols = [name for name in required_cols if name not in data.columns]
    if missing_cols:
        raise DataQualityError(f"Missing columns for {ticker}: {missing_cols}")

    # Share of missing values per required column, in percent.
    missing_pct = data[required_cols].isnull().sum() / len(data) * 100
    high_missing = missing_pct[missing_pct > 10]
    if not high_missing.empty:
        warnings.warn(
            f"High missing data percentage for {ticker}: {high_missing.to_dict()}",
            UserWarning
        )

    # Prices must be strictly positive.
    for col in ('open', 'high', 'low', 'close'):
        if data[col].le(0).any():
            raise DataQualityError(f"Non-positive prices found in {col} for {ticker}")

    # High must bound open/close/low from above, low from below.
    open_, high, low, close = data['open'], data['high'], data['low'], data['close']
    invalid_ohlc = (
        (high < low)
        | (high < open_)
        | (high < close)
        | (low > open_)
        | (low > close)
    )
    if invalid_ohlc.any():
        warnings.warn(
            f"Invalid OHLC relationships found for {ticker} on {invalid_ohlc.sum()} days",
            UserWarning
        )

    # Flag close-to-close moves larger than 50% in a single period.
    extreme_returns = data['close'].pct_change().abs() > 0.5
    if extreme_returns.any():
        warnings.warn(
            f"Extreme price movements (>50%) detected for {ticker} on {extreme_returns.sum()} days",
            UserWarning
        )

    logger.debug(f"Data validation passed for {ticker}")
def _prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
    """
    Clean and prepare data for backtesting.

    Steps, in order: coerce the index to a ``DatetimeIndex``, sort by date,
    drop duplicated dates (keeping the first occurrence in input order),
    forward-fill gaps, then drop rows that still contain NaN (i.e. leading
    rows with nothing to fill from).

    The input frame is left untouched; a cleaned copy is returned.

    Args:
        data: Raw data

    Returns:
        Cleaned data
    """
    # Work on a copy so re-indexing does not leak back to the caller.
    cleaned = data.copy()

    # Ensure datetime index
    if not isinstance(cleaned.index, pd.DatetimeIndex):
        cleaned.index = pd.to_datetime(cleaned.index)

    # A stable sort keeps duplicated dates in their original relative order,
    # which makes ``keep='first'`` below deterministic.
    cleaned = cleaned.sort_index(kind='mergesort')

    # Remove duplicates
    cleaned = cleaned[~cleaned.index.duplicated(keep='first')]

    # Forward fill missing data (conservative approach: uses past values
    # only). ``ffill()`` replaces the deprecated ``fillna(method='ffill')``.
    cleaned = cleaned.ffill()

    # Handle any remaining NaNs
    return cleaned.dropna()
def _load_from_cache(
    self,
    ticker: str,
    start_date: str,
    end_date: str
) -> Optional[pd.DataFrame]:
    """
    Return the cached object for this ticker/date range, or ``None``.

    The cache file is ``<cache_dir>/<ticker>_<start>_<end>.pkl``. A missing
    file or any failure while reading it yields ``None`` (failures are
    logged as warnings), so callers fall back to the data source.
    """
    cache_file = self._cache_dir / f"{ticker}_{start_date}_{end_date}.pkl"
    if not cache_file.exists():
        return None

    try:
        # NOTE(review): pickle executes code on load; the cache directory
        # must only ever contain files written by this process/user.
        with cache_file.open('rb') as handle:
            return pickle.load(handle)
    except Exception as e:
        logger.warning(f"Failed to load cache for {ticker}: {e}")
    return None
def _save_to_cache(
    self,
    ticker: str,
    data: pd.DataFrame,
    start_date: str,
    end_date: str
) -> None:
    """
    Pickle ``data`` to ``<cache_dir>/<ticker>_<start>_<end>.pkl``.

    Caching is best-effort: any failure is logged as a warning and
    otherwise ignored.
    """
    target = self._cache_dir / f"{ticker}_{start_date}_{end_date}.pkl"
    try:
        with target.open('wb') as handle:
            pickle.dump(data, handle)
        logger.debug(f"Cached data for {ticker}")
    except Exception as e:
        logger.warning(f"Failed to save cache for {ticker}: {e}")
def get_data_at(
    self,
    ticker: str,
    timestamp: datetime,
    lookback: Optional[int] = None
) -> pd.DataFrame:
    """
    Get historical data up to a specific point in time.

    This method ensures no look-ahead bias by only returning rows whose
    index is at or before ``timestamp``, and by refusing timestamps later
    than the current simulation time (when one has been set).

    Args:
        ticker: Ticker symbol
        timestamp: Point in time. A timezone-naive value is interpreted in
            the data's own timezone when the stored index is tz-aware.
        lookback: Number of most recent periods to return
            (None = all available, 0 = empty frame)

    Returns:
        A copy of the historical rows up to ``timestamp``

    Raises:
        LookAheadBiasError: If timestamp is in the future
        DataNotFoundError: If ticker not loaded
        ValueError: If lookback is negative
    """
    if ticker not in self.data:
        raise DataNotFoundError(f"Data not loaded for {ticker}")

    if self.current_time is not None and timestamp > self.current_time:
        raise LookAheadBiasError(
            f"Requested timestamp {timestamp} is in the future (current: {self.current_time})"
        )

    # A negative value would make ``tail`` silently drop the oldest rows.
    if lookback is not None and lookback < 0:
        raise ValueError(f"lookback must be non-negative, got {lookback}")

    data = self.data[ticker]

    # Comparing a naive timestamp with a tz-aware index raises TypeError in
    # pandas (downloaded frames, e.g. from yfinance, may carry an exchange
    # timezone), so localize the cutoff to the index timezone first.
    cutoff = pd.Timestamp(timestamp)
    index_tz = getattr(data.index, "tz", None)
    if index_tz is not None and cutoff.tzinfo is None:
        cutoff = cutoff.tz_localize(index_tz)

    # Get data up to timestamp
    historical = data[data.index <= cutoff]

    # ``is not None`` so that lookback=0 means "no rows", not "all rows".
    if lookback is not None:
        historical = historical.tail(lookback)

    # Copy so callers cannot modify the stored history.
    return historical.copy()
def get_price_at(
    self,
    ticker: str,
    timestamp: datetime,
    price_type: str = 'close'
) -> Decimal:
    """
    Return the most recent price known at ``timestamp``.

    The lookup goes through ``get_data_at`` with a one-row lookback, so the
    same look-ahead protections apply.

    Args:
        ticker: Ticker symbol
        timestamp: Point in time
        price_type: Type of price ('open', 'high', 'low', 'close')

    Returns:
        Price as Decimal (built from the value's string form)

    Raises:
        DataNotFoundError: If data not available
    """
    latest = self.get_data_at(ticker, timestamp, lookback=1)
    if latest.empty:
        raise DataNotFoundError(f"No data available for {ticker} at {timestamp}")

    # Select the row first, then the field, then go through str() so the
    # Decimal reflects the printed value rather than the binary float.
    last_row = latest.iloc[-1]
    return Decimal(str(last_row[price_type]))
def set_current_time(self, timestamp: datetime) -> None:
    """
    Advance the simulation clock to ``timestamp``.

    The clock is what ``get_data_at`` compares requests against, so it is
    critical for preventing look-ahead bias. Moving the clock backwards is
    allowed but logged as a warning.

    Args:
        timestamp: Current simulation timestamp
    """
    moving_back = bool(self.current_time) and timestamp < self.current_time
    if moving_back:
        logger.warning(f"Time moving backwards: {self.current_time} -> {timestamp}")

    self.current_time = timestamp
    logger.debug(f"Current time set to {timestamp}")
def align_data(
    self,
    tickers: Optional[List[str]] = None,
    method: str = 'inner'
) -> pd.DataFrame:
    """
    Align close prices across multiple tickers.

    The close series are joined on the union of all dates, then:

    * ``'inner'`` keeps only dates on which every ticker has a price;
    * ``'outer'`` keeps every date and fills gaps forward, then backward.

    Args:
        tickers: List of tickers to align (None = all loaded)
        method: Alignment method ('inner', 'outer'; 'left' and 'right'
            are recognised but not implemented)

    Returns:
        DataFrame with aligned close prices, one column per ticker,
        sorted by date

    Raises:
        DataAlignmentError: If there is nothing to align, a ticker is not
            loaded, the method is unsupported, or alignment itself fails
            (the underlying error is chained as the cause)
    """
    if tickers is None:
        tickers = list(self.data.keys())
    if not tickers:
        raise DataAlignmentError("No tickers to align")

    try:
        # Get close prices for all tickers
        closes = {}
        for ticker in tickers:
            if ticker not in self.data:
                raise DataNotFoundError(f"Data not loaded for {ticker}")
            closes[ticker] = self.data[ticker]['close']

        if method in ('left', 'right'):
            raise NotImplementedError(f"Alignment method '{method}' not implemented")
        if method not in ('inner', 'outer'):
            raise ValueError(f"Unknown alignment method: {method}")

        # Join on the union of dates. Assigning the series column by column
        # into an empty frame would reindex every later ticker to the first
        # ticker's dates, silently dropping dates only the others have.
        prices = pd.concat(
            list(closes.values()), axis=1, keys=list(closes)
        ).sort_index()

        if method == 'inner':
            prices = prices.dropna()
        else:
            # NOTE(review): the backward fill copies the first known price
            # into earlier dates, i.e. it uses future information for
            # tickers that start trading late.
            prices = prices.ffill().bfill()

        logger.info(f"Aligned {len(tickers)} tickers with {len(prices)} periods")
        return prices
    except Exception as e:
        raise DataAlignmentError(f"Failed to align data: {e}") from e
def get_trading_days(
    self,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None
) -> pd.DatetimeIndex:
    """
    Return the dates inside the requested window on which data exists.

    The index of the first loaded ticker serves as the reference calendar;
    other tickers are not consulted.

    Args:
        start_date: Start date (defaults to config start_date)
        end_date: End date (defaults to config end_date)

    Returns:
        DatetimeIndex of trading days (both bounds inclusive)

    Raises:
        DataError: If no data has been loaded
    """
    start_date = start_date or self.config.start_date
    end_date = end_date or self.config.end_date

    if not self.data:
        raise DataError("No data loaded")

    # First ticker in insertion order acts as the calendar.
    reference_ticker = next(iter(self.data))
    calendar = self.data[reference_ticker].index

    in_window = (calendar >= start_date) & (calendar <= end_date)
    return calendar[in_window]
def check_survivor_bias(self, tickers: List[str]) -> None:
    """
    Emit a survivorship-bias reminder.

    The warning is issued unconditionally; ``tickers`` is accepted but not
    inspected.

    Args:
        tickers: List of tickers being tested
    """
    reminder = (
        "SURVIVOR BIAS WARNING: Ensure the ticker list represents "
        "securities that existed throughout the backtest period. "
        "Using current index constituents for historical backtests "
        "can lead to survivorship bias."
    )
    warnings.warn(reminder, UserWarning)
    logger.warning("Survivor bias check performed")
def get_corporate_actions(
    self,
    ticker: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None
) -> pd.DataFrame:
    """
    Get corporate actions (splits, dividends) for a ticker.

    Actions are always fetched from Yahoo Finance and restricted to
    ``[start_date, end_date]`` (both inclusive). The lookup is best-effort:
    any failure is logged and an empty frame is returned.

    Args:
        ticker: Ticker symbol (must already be loaded)
        start_date: Start date (defaults to config start_date)
        end_date: End date (defaults to config end_date)

    Returns:
        DataFrame indexed by action date with a ``splits`` and/or
        ``dividends`` column (a column is present only when that kind of
        action occurred in the window; NaN marks dates where only the other
        kind occurred). Empty when there are no actions or the lookup fails.

    Raises:
        DataNotFoundError: If ticker not loaded
    """
    if ticker not in self.data:
        raise DataNotFoundError(f"Data not loaded for {ticker}")

    start_date = start_date or self.config.start_date
    end_date = end_date or self.config.end_date

    try:
        stock = yf.Ticker(ticker)

        # Get splits and dividends, keeping only those inside the window.
        in_range = {}
        for label, series in (('splits', stock.splits), ('dividends', stock.dividends)):
            if series.empty:
                continue
            window = (series.index >= start_date) & (series.index <= end_date)
            series = series[window]
            if not series.empty:
                in_range[label] = series

        if not in_range:
            return pd.DataFrame()

        # Join on the union of dates. Assigning the columns one after the
        # other into an empty frame would reindex dividends to the split
        # dates and discard every dividend paid on a non-split day.
        return pd.concat(
            list(in_range.values()), axis=1, keys=list(in_range)
        ).sort_index()
    except Exception as e:
        logger.warning(f"Failed to get corporate actions for {ticker}: {e}")
        return pd.DataFrame()
def summary(self) -> Dict[str, Any]:
    """
    Get summary of loaded data.

    Returns:
        Dictionary with the number of tickers, the ticker list and, keyed
        by ticker, the covered date range, row count and percentage of
        missing cells. With nothing loaded, a two-key dictionary carrying a
        "No data loaded" message is returned instead. Tickers whose frame
        has no rows report ``None`` dates and ``0.0`` missing percent.
    """
    if not self.data:
        return {"tickers": 0, "message": "No data loaded"}

    summary_dict: Dict[str, Any] = {
        "tickers": len(self.data),
        "ticker_list": list(self.data.keys()),
    }

    for ticker, data in self.data.items():
        periods = len(data)
        total_cells = periods * len(data.columns)
        missing_cells = int(data.isnull().sum().sum())

        summary_dict[ticker] = {
            # min()/max() of an empty index is NaT, which has no .date().
            "start_date": str(data.index.min().date()) if periods else None,
            "end_date": str(data.index.max().date()) if periods else None,
            "periods": periods,
            # Guard the division for frames without rows or columns.
            "missing_data_pct": (
                missing_cells / total_cells * 100 if total_cells else 0.0
            ),
        }

    return summary_dict