TradingAgents/tradingagents/backtest/config.py

364 lines
12 KiB
Python

"""
Configuration management for the backtesting framework.
This module provides configuration classes and utilities for managing
backtest parameters, ensuring type safety and validation.
"""
from dataclasses import dataclass, field, asdict
from decimal import Decimal
from datetime import datetime, time
from typing import Optional, Dict, Any, List
from enum import Enum
import json
import logging
from .exceptions import InvalidConfigError, MissingConfigError
logger = logging.getLogger(__name__)
class OrderType(Enum):
    """
    Order types supported by the backtester.

    Each member's value is its lowercase string name, so a member can be
    recovered from text with ``OrderType("limit")``.
    """

    MARKET = "market"
    LIMIT = "limit"
    STOP = "stop"
    STOP_LIMIT = "stop_limit"
class DataSource(Enum):
    """
    Historical data sources a backtest can be configured to use.

    Member values are the lowercase strings accepted in serialized configs
    (``BacktestConfig.to_dict`` emits ``.value`` for enum fields).
    """

    YFINANCE = "yfinance"
    CSV = "csv"
    ALPHA_VANTAGE = "alpha_vantage"
    LOCAL = "local"
    CUSTOM = "custom"
class SlippageModel(Enum):
    """
    Approaches available for modeling slippage.

    Member values are the lowercase strings accepted in serialized configs.
    """

    FIXED = "fixed"  # fixed percentage
    VOLUME_BASED = "volume_based"  # derived from traded volume
    SPREAD_BASED = "spread_based"  # derived from the bid-ask spread
    CUSTOM = "custom"  # user-supplied function
class CommissionModel(Enum):
    """
    Approaches available for computing trading commissions.

    Member values are the lowercase strings accepted in serialized configs.
    """

    FIXED_PER_TRADE = "fixed_per_trade"  # flat amount charged per trade
    PER_SHARE = "per_share"  # amount charged per share traded
    PERCENTAGE = "percentage"  # fraction of the trade's value
    TIERED = "tiered"  # tiers keyed on volume
    CUSTOM = "custom"  # user-supplied function
@dataclass
class BacktestConfig:
    """
    Configuration for backtesting.

    Attributes:
        initial_capital: Starting capital for the backtest
        start_date: Start date for the backtest (YYYY-MM-DD)
        end_date: End date for the backtest (YYYY-MM-DD)
        commission: Commission rate (as decimal, e.g., 0.001 for 0.1%)
        slippage: Slippage rate (as decimal, e.g., 0.0005 for 0.05%)
        benchmark: Benchmark ticker for comparison (e.g., 'SPY')
        data_source: Source for historical data
        commission_model: Commission calculation model
        slippage_model: Slippage calculation model
        max_position_size: Maximum position size as fraction of portfolio (None = unlimited)
        max_leverage: Maximum leverage allowed (1.0 = no leverage)
        allow_short: Whether to allow short positions
        margin_requirement: Margin requirement for positions (as decimal)
        risk_free_rate: Annual risk-free rate for metrics (as decimal)
        trading_hours: Trading hours enforcement (None = 24/7)
        market_impact: Whether to model market impact
        partial_fills: Whether to allow partial fills
        time_zone: Time zone for timestamps
        cache_data: Whether to cache historical data
        cache_dir: Directory for data cache
        log_level: Logging level
        progress_bar: Whether to show progress bar
        random_seed: Random seed for reproducibility
        custom_params: Free-form extra parameters (not validated here)
    """

    # Core parameters
    initial_capital: Decimal
    start_date: str
    end_date: str
    # Costs
    commission: Decimal = Decimal("0.0")
    slippage: Decimal = Decimal("0.0")
    commission_model: CommissionModel = CommissionModel.PERCENTAGE
    slippage_model: SlippageModel = SlippageModel.FIXED
    # Benchmark
    benchmark: Optional[str] = None
    # Data
    data_source: DataSource = DataSource.YFINANCE
    cache_data: bool = True
    cache_dir: Optional[str] = None
    # Risk controls
    max_position_size: Optional[Decimal] = None
    max_leverage: Decimal = Decimal("1.0")
    allow_short: bool = False
    margin_requirement: Decimal = Decimal("0.5")
    # Performance metrics
    risk_free_rate: Decimal = Decimal("0.02")  # 2% annual
    # Execution
    trading_hours: Optional[Dict[str, Any]] = None
    market_impact: bool = False
    partial_fills: bool = False
    # System
    time_zone: str = "America/New_York"
    log_level: str = "INFO"
    progress_bar: bool = True
    random_seed: Optional[int] = None
    # Custom parameters
    custom_params: Dict[str, Any] = field(default_factory=dict)

    # Unannotated class attributes, so dataclass does not treat them as fields.
    # Fields that from_dict converts to Decimal.
    _DECIMAL_FIELDS = (
        'initial_capital', 'commission', 'slippage',
        'max_position_size', 'max_leverage', 'margin_requirement',
        'risk_free_rate',
    )
    # Enum-typed fields that may be supplied as their string values.
    _ENUM_FIELDS = (
        ('commission_model', CommissionModel),
        ('slippage_model', SlippageModel),
        ('data_source', DataSource),
    )

    def __post_init__(self):
        """Validate configuration after initialization."""
        self._validate()

    def _validate(self):
        """
        Validate configuration parameters and normalize enum fields.

        Enum fields given as their string values (e.g. ``"percentage"``)
        are replaced with the matching enum member.

        Raises:
            InvalidConfigError: If a value is out of range, a date is not in
                YYYY-MM-DD format or start is not before end, or an enum
                string does not name a known member.
        """
        # Validate capital
        if self.initial_capital <= 0:
            raise InvalidConfigError("Initial capital must be positive")

        # Validate dates
        try:
            start = datetime.strptime(self.start_date, "%Y-%m-%d")
            end = datetime.strptime(self.end_date, "%Y-%m-%d")
        except ValueError as e:
            raise InvalidConfigError(f"Invalid date format: {e}") from e
        if start >= end:
            raise InvalidConfigError("Start date must be before end date")

        # Validate rates
        if self.commission < 0:
            raise InvalidConfigError("Commission cannot be negative")
        if self.slippage < 0:
            raise InvalidConfigError("Slippage cannot be negative")
        if self.risk_free_rate < 0:
            raise InvalidConfigError("Risk-free rate cannot be negative")

        # Validate leverage and margin
        if self.max_leverage < Decimal("1.0"):
            raise InvalidConfigError("Max leverage must be >= 1.0")
        if not (Decimal("0.0") < self.margin_requirement <= Decimal("1.0")):
            raise InvalidConfigError("Margin requirement must be between 0 and 1")

        # Validate position size (None means unlimited)
        if self.max_position_size is not None:
            if not (Decimal("0.0") < self.max_position_size <= Decimal("1.0")):
                raise InvalidConfigError("Max position size must be between 0 and 1")

        # Convert enum strings (e.g. values coming from from_dict/from_json).
        # Unknown strings are reported as config errors, as date errors are.
        for attr, enum_cls in self._ENUM_FIELDS:
            value = getattr(self, attr)
            if isinstance(value, str):
                try:
                    setattr(self, attr, enum_cls(value))
                except ValueError as e:
                    raise InvalidConfigError(f"Invalid {attr}: {value!r}") from e

        logger.info("Backtest config validated: %s to %s", self.start_date, self.end_date)

    @staticmethod
    def _json_safe(value: Any) -> Any:
        """Recursively convert Decimal -> float and Enum -> its value."""
        if isinstance(value, Decimal):
            return float(value)
        if isinstance(value, Enum):
            return value.value
        if isinstance(value, dict):
            return {k: BacktestConfig._json_safe(v) for k, v in value.items()}
        if isinstance(value, list):
            return [BacktestConfig._json_safe(v) for v in value]
        if isinstance(value, tuple):
            return tuple(BacktestConfig._json_safe(v) for v in value)
        return value

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert configuration to a JSON-serializable dictionary.

        Decimal values become floats and Enum values become their ``.value``.
        This also applies inside nested containers such as ``custom_params``
        and ``trading_hours``.

        Returns:
            Dictionary of field name -> plain Python value
        """
        return {key: self._json_safe(value) for key, value in asdict(self).items()}

    def to_json(self, filepath: Optional[str] = None) -> str:
        """
        Serialize configuration to JSON.

        Args:
            filepath: Optional file path to save JSON (written as UTF-8)

        Returns:
            JSON string representation
        """
        json_str = json.dumps(self.to_dict(), indent=2)
        if filepath:
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(json_str)
            logger.info("Config saved to %s", filepath)
        return json_str

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> 'BacktestConfig':
        """
        Create configuration from dictionary.

        The caller's dictionary is not modified; a shallow copy is converted.
        Values for Decimal-typed fields go through ``str`` first, so a float
        such as ``0.1`` becomes ``Decimal("0.1")`` rather than its exact
        binary expansion. Enum fields may be given as their string values;
        they are converted during validation.

        Args:
            config_dict: Dictionary of configuration parameters

        Returns:
            BacktestConfig instance

        Raises:
            InvalidConfigError: If a numeric field cannot be parsed as a
                Decimal, or if validation of the resulting config fails.
        """
        from decimal import InvalidOperation

        # Copy so the caller's dict is left untouched.
        params = dict(config_dict)
        for name in cls._DECIMAL_FIELDS:
            value = params.get(name)
            if value is None:
                continue
            try:
                params[name] = Decimal(str(value))
            except InvalidOperation as e:
                raise InvalidConfigError(
                    f"Invalid numeric value for {name}: {value!r}"
                ) from e
        return cls(**params)

    @classmethod
    def from_json(cls, filepath: str) -> 'BacktestConfig':
        """
        Load configuration from a UTF-8 JSON file.

        Args:
            filepath: Path to JSON configuration file

        Returns:
            BacktestConfig instance
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict)
@dataclass
class WalkForwardConfig:
    """
    Settings for a walk-forward analysis.

    Attributes:
        in_sample_months: Length of each in-sample (training) window, in months
        out_sample_months: Length of each out-of-sample (testing) window, in months
        step_months: Months to advance between windows; when left as None it
            is set to ``out_sample_months`` during ``__post_init__``
        optimization_metric: Name of the metric to optimize ('sharpe', 'return', 'sortino', etc.)
        min_periods: Minimum number of periods required
        anchored: Whether to use anchored walk-forward (growing window)
    """

    in_sample_months: int
    out_sample_months: int
    step_months: Optional[int] = None
    optimization_metric: str = "sharpe"
    min_periods: int = 20
    anchored: bool = False

    def __post_init__(self):
        """
        Resolve the default step size, then require every count to be positive.

        Raises:
            InvalidConfigError: On the first count that is zero or negative,
                checked in the order in-sample, out-of-sample, step, min periods.
        """
        if self.step_months is None:
            self.step_months = self.out_sample_months

        requirements = (
            (self.in_sample_months, "In-sample months must be positive"),
            (self.out_sample_months, "Out-of-sample months must be positive"),
            (self.step_months, "Step months must be positive"),
            (self.min_periods, "Min periods must be positive"),
        )
        for amount, complaint in requirements:
            if amount <= 0:
                raise InvalidConfigError(complaint)
@dataclass
class MonteCarloConfig:
    """
    Settings for a Monte Carlo simulation.

    Attributes:
        n_simulations: How many simulations to run
        method: One of 'resample_trades', 'resample_returns', 'parametric'
        confidence_levels: Confidence levels for intervals, each strictly
            between 0 and 1 (e.g., [0.90, 0.95, 0.99])
        random_seed: Random seed for reproducibility
        preserve_order: Whether to preserve trade order in resampling
    """

    n_simulations: int = 10000
    method: str = "resample_trades"
    # Fresh list per instance so defaults are never shared between configs.
    confidence_levels: List[float] = field(default_factory=lambda: [0.90, 0.95, 0.99])
    random_seed: Optional[int] = None
    preserve_order: bool = False

    def __post_init__(self):
        """
        Check the simulation count, method name, and confidence levels.

        Raises:
            InvalidConfigError: If any of the three checks fails.
        """
        if self.n_simulations <= 0:
            raise InvalidConfigError("Number of simulations must be positive")

        known_methods = ('resample_trades', 'resample_returns', 'parametric')
        if self.method not in known_methods:
            raise InvalidConfigError(f"Invalid Monte Carlo method: {self.method}")

        for candidate in self.confidence_levels:
            strictly_inside = 0 < candidate < 1
            if not strictly_inside:
                raise InvalidConfigError(f"Invalid confidence level: {candidate}")
def create_default_config(
    initial_capital: float = 100000.0,
    start_date: str = "2020-01-01",
    end_date: str = "2023-12-31",
    **kwargs
) -> BacktestConfig:
    """
    Build a BacktestConfig pre-filled with common defaults.

    The defaults are a 0.1% commission, 0.05% slippage and 'SPY' as the
    benchmark. Any keyword in ``kwargs`` overrides the default of the same
    name; everything is routed through ``BacktestConfig.from_dict``.

    Args:
        initial_capital: Starting capital
        start_date: Start date (YYYY-MM-DD)
        end_date: End date (YYYY-MM-DD)
        **kwargs: Additional or overriding configuration parameters

    Returns:
        BacktestConfig instance
    """
    settings = dict(
        initial_capital=Decimal(str(initial_capital)),
        start_date=start_date,
        end_date=end_date,
        commission=Decimal("0.001"),  # 0.1%
        slippage=Decimal("0.0005"),  # 0.05%
        benchmark='SPY',
    )
    # Caller-supplied values take precedence over the defaults above.
    settings.update(kwargs)
    return BacktestConfig.from_dict(settings)