367 lines
13 KiB
Python
367 lines
13 KiB
Python
import logging
|
|
from datetime import date, datetime, timedelta
|
|
from decimal import Decimal
|
|
from typing import Optional, Callable
|
|
from uuid import uuid4
|
|
|
|
from tradingagents.models.backtest import (
|
|
BacktestConfig,
|
|
BacktestResult,
|
|
EquityCurvePoint,
|
|
TradeLog,
|
|
)
|
|
from tradingagents.models.decisions import SignalType, TradingDecision
|
|
from tradingagents.models.portfolio import PortfolioSnapshot
|
|
from tradingagents.models.trading import Order, OrderSide, OrderStatus, Fill, Trade
|
|
|
|
from .data_loader import DataLoader
|
|
from .metrics import MetricsCalculator
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BacktestEngine:
    """Day-by-day backtest driver for the tickers listed in a ``BacktestConfig``.

    For every trading day after the warm-up period the engine asks for one
    decision per ticker (from ``decision_callback`` when given, otherwise from
    ``_simple_strategy``), executes it as an immediately filled order against
    an in-memory ``PortfolioSnapshot`` and records an equity-curve point.
    Positions are long-only and all-or-nothing: a buy is only acted on while
    the position is flat, and a sell always liquidates the whole position.
    """

    def __init__(
        self,
        config: BacktestConfig,
        decision_callback: Optional[Callable[[str, date, dict], TradingDecision]] = None,
    ):
        """Store the configuration and create the collaborators.

        Args:
            config: Backtest configuration (tickers, date range, warm-up
                period, portfolio settings, risk-free rate).
            decision_callback: Optional ``(ticker, trading_date, context)``
                callable returning a ``TradingDecision`` or ``None``.  The
                context dict carries ``day_index``, ``portfolio`` and
                ``open_trade``.
        """
        self.config = config
        self.decision_callback = decision_callback
        self.data_loader = DataLoader()
        self.metrics_calculator = MetricsCalculator(config.risk_free_rate)

        # Per-run state; (re)created by _initialize() at the start of run().
        self.portfolio: Optional[PortfolioSnapshot] = None
        self.trade_log: Optional[TradeLog] = None
        self.equity_curve: list[EquityCurvePoint] = []
        self.daily_returns: list[Decimal] = []
        self.decisions: list[TradingDecision] = []
        # Currently open trades, keyed by ticker (at most one per ticker).
        self.open_trades: dict[str, Trade] = {}

    def run(self) -> BacktestResult:
        """Execute the backtest and return its result.

        Returns:
            A ``BacktestResult`` with status ``"completed"`` on success.  If
            one of the handled exception types is raised, the result has
            status ``"failed"``, empty metrics, whatever state had been
            accumulated so far, and ``error_message`` set to the exception
            text.  Other exception types propagate to the caller.
        """
        started_at = datetime.now()

        try:
            self._initialize()
            self._preload_data()
            trading_days = self._get_trading_days()

            for i, trading_date in enumerate(trading_days):
                # Warm-up days are skipped entirely: no decisions, no equity point.
                if i < self.config.warmup_period:
                    continue

                self._process_day(trading_date, i)

            # NOTE(review): the closing fills below are applied after the last
            # equity point has been recorded, so the curve handed to the
            # metrics calculator does not include closing slippage/commission.
            # Confirm this is intended.
            self._close_all_positions(trading_days[-1] if trading_days else self.config.end_date)

            metrics = self.metrics_calculator.calculate_metrics(
                self.equity_curve,
                self.trade_log,
            )

            completed_at = datetime.now()

            return BacktestResult(
                config=self.config,
                metrics=metrics,
                trade_log=self.trade_log,
                equity_curve=self.equity_curve,
                daily_returns=self.daily_returns,
                started_at=started_at,
                completed_at=completed_at,
                status="completed",
            )

        except (ValueError, KeyError, RuntimeError, FileNotFoundError, OSError) as e:
            logger.exception("Backtest failed: %s", e)
            completed_at = datetime.now()

            return BacktestResult(
                config=self.config,
                metrics=self._empty_metrics(),
                # Explicit None check so an existing log object is always kept.
                trade_log=self.trade_log if self.trade_log is not None else TradeLog(),
                equity_curve=self.equity_curve,
                daily_returns=self.daily_returns,
                started_at=started_at,
                completed_at=completed_at,
                status="failed",
                error_message=str(e),
            )

    def _initialize(self):
        """Reset all per-run state so the engine can be run repeatedly."""
        self.portfolio = PortfolioSnapshot(
            cash=self.config.portfolio_config.initial_cash,
        )
        self.trade_log = TradeLog()
        self.equity_curve = []
        self.daily_returns = []
        self.decisions = []
        self.open_trades = {}

    def _preload_data(self):
        """Ask the data loader for OHLCV data of every configured ticker."""
        logger.info("Preloading data for %s tickers", len(self.config.tickers))
        # Start earlier than start_date by warmup_period + 10 calendar days
        # (presumably a margin for non-trading days -- TODO confirm).
        load_from = self.config.start_date - timedelta(days=self.config.warmup_period + 10)
        for ticker in self.config.tickers:
            self.data_loader.load_ohlcv(
                ticker,
                load_from,
                self.config.end_date,
            )

    def _get_trading_days(self) -> list[date]:
        """Return the trading days of the first configured ticker.

        Raises:
            ValueError: If no tickers are configured.  ``run()`` handles
                ``ValueError``, so this yields a failed result instead of an
                uncaught ``IndexError``.
        """
        if not self.config.tickers:
            raise ValueError("BacktestConfig.tickers must contain at least one ticker")

        primary_ticker = self.config.tickers[0]
        return self.data_loader.get_trading_days(
            primary_ticker,
            self.config.start_date,
            self.config.end_date,
        )

    def _process_day(self, trading_date: date, day_index: int):
        """Collect and execute decisions for one day, then record equity.

        Args:
            trading_date: The day being simulated.
            day_index: Index of ``trading_date`` in the trading-day list.
        """
        prices = self.data_loader.get_prices_dict(self.config.tickers, trading_date)

        if not prices:
            # Nothing is recorded for such a day, not even an equity point.
            logger.debug("No prices available for %s", trading_date)
            return

        for ticker in self.config.tickers:
            if ticker not in prices:
                continue

            decision = self._get_decision(ticker, trading_date, day_index)
            if decision:
                # Every decision is kept, whether or not it leads to a fill.
                self.decisions.append(decision)
                self._execute_decision(decision, prices[ticker], trading_date)

        self._record_equity(trading_date, prices)

    def _get_decision(
        self,
        ticker: str,
        trading_date: date,
        day_index: int,
    ) -> Optional[TradingDecision]:
        """Return the decision for ``ticker`` on ``trading_date``, if any.

        Delegates to ``decision_callback`` when one was supplied, otherwise
        to ``_simple_strategy``.
        """
        if self.decision_callback:
            context = {
                "day_index": day_index,
                "portfolio": self.portfolio,
                "open_trade": self.open_trades.get(ticker),
            }
            return self.decision_callback(ticker, trading_date, context)

        return self._simple_strategy(ticker, trading_date)

    def _simple_strategy(
        self,
        ticker: str,
        trading_date: date,
    ) -> Optional[TradingDecision]:
        """Fallback strategy used without a callback; never trades."""
        return None

    @staticmethod
    def _as_datetime(trading_date: date) -> datetime:
        """Return midnight of ``trading_date`` as a naive ``datetime``."""
        return datetime.combine(trading_date, datetime.min.time())

    def _fill_order(
        self,
        ticker: str,
        side: OrderSide,
        quantity,
        execution_price: Decimal,
        commission,
        trading_date: date,
    ) -> Order:
        """Create a fully filled order plus its fill and apply it to the portfolio.

        Returns:
            The created ``Order`` (callers need its ``id`` for the trade record).
        """
        filled_at = self._as_datetime(trading_date)

        order = Order(
            ticker=ticker,
            side=side,
            quantity=quantity,
            status=OrderStatus.FILLED,
            filled_quantity=quantity,
            filled_avg_price=execution_price,
            filled_at=filled_at,
            commission=commission,
        )

        fill = Fill(
            order_id=order.id,
            ticker=ticker,
            side=side,
            quantity=quantity,
            price=execution_price,
            commission=commission,
            timestamp=filled_at,
        )

        self.portfolio.apply_fill(fill)
        return order

    def _execute_decision(
        self,
        decision: TradingDecision,
        price: Decimal,
        trading_date: date,
    ):
        """Turn a decision into a fill when the current position allows it.

        A buy is executed only while the position is flat and a sell only
        while shares are held; any other combination is ignored.

        Args:
            decision: The decision to act on.
            price: Market price of ``decision.ticker`` on ``trading_date``.
            trading_date: Day the order is filled on.
        """
        position = self.portfolio.get_position(decision.ticker)

        if decision.is_buy and position.quantity == 0:
            self._open_position(decision, price, trading_date)
        elif decision.is_sell and position.quantity > 0:
            self._close_position(decision.ticker, position.quantity, price, trading_date)

    def _open_position(
        self,
        decision: TradingDecision,
        price: Decimal,
        trading_date: date,
    ):
        """Size and fill a buy order and register the resulting open trade."""
        ticker = decision.ticker
        config = self.config.portfolio_config
        execution_price = config.calculate_slippage(price, OrderSide.BUY)

        if execution_price <= 0:
            # A non-positive price cannot be sized against (division by zero
            # below) and is not a meaningful fill, so the buy is skipped.
            logger.warning(
                "Skipping BUY %s on %s: non-positive execution price %s",
                ticker, trading_date, execution_price
            )
            return

        if decision.recommended_quantity:
            quantity = decision.recommended_quantity
        else:
            # Default sizing: a configured percentage of the available cash.
            max_position_value = self.portfolio.cash * (config.max_position_size_percent / 100)
            quantity = int(max_position_value / execution_price)

        if quantity <= 0:
            return

        # Scale down to what the portfolio can pay for rather than rejecting.
        if not self.portfolio.can_afford(ticker, quantity, execution_price, config):
            quantity = self.portfolio.max_shares_affordable(ticker, execution_price, config)

        if quantity <= 0:
            return

        commission = config.calculate_commission(quantity, execution_price)
        order = self._fill_order(
            ticker, OrderSide.BUY, quantity, execution_price, commission, trading_date
        )

        self.open_trades[ticker] = Trade(
            ticker=ticker,
            side=OrderSide.BUY,
            entry_price=execution_price,
            entry_quantity=quantity,
            entry_time=self._as_datetime(trading_date),
            entry_order_id=order.id,
        )

        logger.debug(
            "BUY %s: %d shares @ $%.2f on %s",
            ticker, quantity, execution_price, trading_date
        )

    def _close_position(
        self,
        ticker: str,
        quantity,
        price: Decimal,
        trading_date: date,
    ):
        """Sell ``quantity`` shares of ``ticker`` and log the finished trade."""
        config = self.config.portfolio_config
        execution_price = config.calculate_slippage(price, OrderSide.SELL)
        commission = config.calculate_commission(quantity, execution_price)

        order = self._fill_order(
            ticker, OrderSide.SELL, quantity, execution_price, commission, trading_date
        )

        # pop() both looks up and removes the open trade; a position without
        # a tracked trade is still sold, it just produces no trade-log entry.
        trade = self.open_trades.pop(ticker, None)
        if trade is not None:
            trade.exit_price = execution_price
            trade.exit_quantity = quantity
            trade.exit_time = self._as_datetime(trading_date)
            trade.exit_order_id = order.id
            # Entry commission is recomputed from the entry fill, then the
            # exit commission is added.
            trade.commission = (
                config.calculate_commission(trade.entry_quantity, trade.entry_price) +
                commission
            )
            self.trade_log.add_trade(trade)

        logger.debug(
            "SELL %s: %d shares @ $%.2f on %s",
            ticker, quantity, execution_price, trading_date
        )

    def _record_equity(self, trading_date: date, prices: dict[str, Decimal]):
        """Append an equity-curve point and, when possible, a daily return.

        Args:
            trading_date: Day the point belongs to.
            prices: Ticker-to-price mapping used to value the portfolio.
        """
        equity = self.portfolio.total_equity(prices)
        positions_value = self.portfolio.positions_value(prices)

        point = EquityCurvePoint(
            timestamp=self._as_datetime(trading_date),
            equity=equity,
            cash=self.portfolio.cash,
            positions_value=positions_value,
        )
        self.equity_curve.append(point)

        if len(self.equity_curve) > 1:
            prev_equity = self.equity_curve[-2].equity
            # No return is recorded when the previous equity is not positive,
            # which avoids dividing by zero.
            if prev_equity > 0:
                daily_return = (equity - prev_equity) / prev_equity
                self.daily_returns.append(daily_return)

    def _close_all_positions(self, final_date: date):
        """Sell every open trade at the prices of ``final_date``.

        Trades whose ticker has no price on ``final_date`` cannot be closed;
        they stay open and are reported with a warning.
        """
        prices = self.data_loader.get_prices_dict(self.config.tickers, final_date)

        # Iterate over a snapshot: executing the sell removes entries.
        for ticker in list(self.open_trades):
            if ticker not in prices:
                logger.warning(
                    "Cannot close %s at end of backtest: no price on %s",
                    ticker, final_date
                )
                continue

            decision = TradingDecision(
                ticker=ticker,
                timestamp=datetime.now(),
                decision_date=self._as_datetime(final_date),
                signal=SignalType.SELL,
                confidence=Decimal("1.0"),
                recommended_action="SELL",
                final_decision="SELL - End of backtest",
                rationale="Closing position at end of backtest period",
            )
            self._execute_decision(decision, prices[ticker], final_date)

    def _empty_metrics(self):
        """Build placeholder metrics for a failed run.

        Start equity is the configured initial cash; end equity is the
        portfolio's current cash when a portfolio exists, otherwise the
        initial cash as well.
        """
        from tradingagents.models.backtest import BacktestMetrics

        initial_cash = self.config.portfolio_config.initial_cash
        return BacktestMetrics(
            start_equity=initial_cash,
            end_equity=self.portfolio.cash if self.portfolio else initial_cash,
        )
|
|
|
|
|
|
class SimpleBacktestEngine(BacktestEngine):
    """Backtest engine driven by two boolean signal callables.

    ``buy_signal`` is consulted only while the position in a ticker is flat,
    ``sell_signal`` only while shares are held.  A truthy result is turned
    into a BUY or SELL ``TradingDecision`` with a fixed confidence of 0.7.
    """

    def __init__(
        self,
        config: BacktestConfig,
        buy_signal: Optional[Callable[[str, date, dict], bool]] = None,
        sell_signal: Optional[Callable[[str, date, dict], bool]] = None,
    ):
        """Store the signal callables; no decision callback is used.

        Args:
            config: Backtest configuration, passed to the base engine.
            buy_signal: Optional ``(ticker, trading_date, context)``
                predicate; truthy means "open a position".
            sell_signal: Optional predicate with the same signature; truthy
                means "close the position".
        """
        super().__init__(config)
        self.buy_signal = buy_signal
        self.sell_signal = sell_signal

    def _signal_decision(
        self,
        ticker: str,
        trading_date: date,
        signal: SignalType,
        action: str,
        rationale: str,
    ) -> TradingDecision:
        """Build the decision emitted when one of the signals fires."""
        return TradingDecision(
            ticker=ticker,
            timestamp=datetime.now(),
            decision_date=datetime.combine(trading_date, datetime.min.time()),
            signal=signal,
            confidence=Decimal("0.7"),
            recommended_action=action,
            final_decision=action,
            rationale=rationale,
        )

    def _get_decision(
        self,
        ticker: str,
        trading_date: date,
        day_index: int,
    ) -> Optional[TradingDecision]:
        """Evaluate the signal matching the current position state.

        The context handed to the signals contains ``day_index``,
        ``portfolio``, ``data_loader`` and ``open_trade``.

        Returns:
            A BUY decision when flat and ``buy_signal`` fires, a SELL
            decision when holding and ``sell_signal`` fires, else ``None``.
        """
        context = {
            "day_index": day_index,
            "portfolio": self.portfolio,
            "data_loader": self.data_loader,
            "open_trade": self.open_trades.get(ticker),
        }

        position = self.portfolio.get_position(ticker)

        if position.quantity == 0 and self.buy_signal and self.buy_signal(ticker, trading_date, context):
            return self._signal_decision(
                ticker, trading_date, SignalType.BUY, "BUY", "Buy signal triggered"
            )

        if position.quantity > 0 and self.sell_signal and self.sell_signal(ticker, trading_date, context):
            return self._signal_decision(
                ticker, trading_date, SignalType.SELL, "SELL", "Sell signal triggered"
            )

        return None
|