264 lines
8.9 KiB
Python
264 lines
8.9 KiB
Python
from datetime import date, datetime
|
|
from decimal import Decimal
|
|
from enum import Enum
|
|
from typing import Optional
|
|
from uuid import UUID, uuid4
|
|
|
|
from pydantic import BaseModel, Field, computed_field, field_validator
|
|
|
|
from .portfolio import PortfolioConfig
|
|
from .trading import Trade
|
|
|
|
|
|
class BacktestStatus(str, Enum):
    """Status values a backtest can carry.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
|
|
|
|
|
class BacktestConfig(BaseModel):
    """Input parameters for a single backtest.

    Validation performed by this model:

    * ``tickers`` must contain at least one entry. Every symbol is
      upper-cased and stripped of surrounding whitespace; a symbol that is
      empty after stripping is rejected.
    * ``end_date`` must be strictly later than ``start_date``.
    * ``warmup_period`` and ``risk_free_rate`` must be non-negative.
    """

    id: UUID = Field(default_factory=uuid4)
    name: str = Field(default="Backtest")
    description: str | None = None

    tickers: list[str] = Field(min_length=1)
    start_date: date
    end_date: date
    interval: str = Field(default="1d")

    portfolio_config: PortfolioConfig = Field(default_factory=PortfolioConfig)

    # NOTE(review): unit of warmup_period (bars vs. days) is not visible
    # here -- confirm against the engine that consumes it.
    warmup_period: int = Field(default=20, ge=0)
    rebalance_frequency: str | None = Field(default=None)

    use_agent_pipeline: bool = Field(default=True)
    agent_config: dict = Field(default_factory=dict)

    benchmark_ticker: str | None = Field(default="SPY")
    # NOTE(review): presumably an annual rate expressed as a fraction
    # (0.05 == 5%) -- confirm against the metrics calculation.
    risk_free_rate: Decimal = Field(default=Decimal("0.05"), ge=0)

    # datetime.now() without a tz argument yields a naive local timestamp.
    created_at: datetime = Field(default_factory=datetime.now)

    @field_validator("end_date")
    @classmethod
    def end_after_start(cls, v: date, info) -> date:
        """Reject an ``end_date`` that is not strictly after ``start_date``.

        Raises:
            ValueError: if ``end_date <= start_date``.
        """
        # start_date is missing from info.data when it failed its own
        # validation; in that case there is nothing to compare against.
        start = info.data.get("start_date")
        if start is not None and v <= start:
            raise ValueError("end_date must be > start_date")
        return v

    @field_validator("tickers")
    @classmethod
    def validate_tickers(cls, v: list[str]) -> list[str]:
        """Normalize ticker symbols to stripped upper case.

        Raises:
            ValueError: if any symbol is empty after stripping whitespace.
        """
        normalized = [t.upper().strip() for t in v]
        # A whitespace-only entry would otherwise survive as "" and satisfy
        # the list's min_length=1 constraint without naming any instrument.
        if not all(normalized):
            raise ValueError("tickers must not contain empty symbols")
        return normalized

    @computed_field
    @property
    def trading_days_estimate(self) -> int:
        """Calendar-day span scaled by 252/365, truncated to an int."""
        delta = self.end_date - self.start_date
        return int(delta.days * 252 / 365)
|
|
|
|
|
|
class EquityCurvePoint(BaseModel):
    """A single timestamped sample of an equity curve."""

    timestamp: datetime

    equity: Decimal
    cash: Decimal
    positions_value: Decimal

    # Optional; stays None when no benchmark value is supplied.
    benchmark_value: Decimal | None = None

    # Both drawdown figures default to zero when not supplied.
    drawdown: Decimal = Decimal("0")
    drawdown_percent: Decimal = Decimal("0")
|
|
|
|
|
|
class TradeLog(BaseModel):
    """List of trades plus running win/loss/break-even counters.

    The counters are ordinary fields; within this class only ``add_trade``
    updates them. The profit and average properties are recomputed from
    ``trades`` on every access.
    """

    trades: list[Trade] = Field(default_factory=list)
    total_trades: int = Field(default=0, ge=0)
    winning_trades: int = Field(default=0, ge=0)
    losing_trades: int = Field(default=0, ge=0)
    break_even_trades: int = Field(default=0, ge=0)

    @computed_field
    @property
    def win_rate(self) -> Decimal | None:
        """Winning trades as a percentage of ``total_trades``.

        Returns None when no trades have been recorded.
        """
        if not self.total_trades:
            return None
        wins = Decimal(self.winning_trades)
        return wins / Decimal(self.total_trades) * 100

    @computed_field
    @property
    def loss_rate(self) -> Decimal | None:
        """Losing trades as a percentage of ``total_trades``.

        Returns None when no trades have been recorded.
        """
        if not self.total_trades:
            return None
        losses = Decimal(self.losing_trades)
        return losses / Decimal(self.total_trades) * 100

    def add_trade(self, trade: Trade) -> None:
        """Append ``trade`` and update the counters.

        ``total_trades`` grows for every trade; the win/loss/break-even
        counters only change for a closed trade whose ``pnl`` is set.
        """
        self.trades.append(trade)
        self.total_trades += 1

        # Open trades and trades without a pnl are not classified.
        if not trade.is_closed or trade.pnl is None:
            return

        if trade.pnl > 0:
            self.winning_trades += 1
        elif trade.pnl < 0:
            self.losing_trades += 1
        else:
            self.break_even_trades += 1

    @property
    def gross_profit(self) -> Decimal:
        """Sum of pnl over closed trades with positive pnl (zero if none)."""
        profits = [
            entry.pnl
            for entry in self.trades
            if entry.is_closed and entry.pnl and entry.pnl > 0
        ]
        if not profits:
            return Decimal("0")
        return sum(profits)

    @property
    def gross_loss(self) -> Decimal:
        """Absolute sum of pnl over closed trades with negative pnl."""
        deficits = [
            entry.pnl
            for entry in self.trades
            if entry.is_closed and entry.pnl and entry.pnl < 0
        ]
        if not deficits:
            return Decimal("0")
        return abs(sum(deficits))

    @property
    def profit_factor(self) -> Decimal | None:
        """``gross_profit / gross_loss``, or None when there are no losses."""
        total_loss = self.gross_loss
        if total_loss == 0:
            return None
        return self.gross_profit / total_loss

    @property
    def avg_win(self) -> Decimal | None:
        """Mean pnl of closed winning trades, or None when there are none."""
        winners = [
            entry.pnl
            for entry in self.trades
            if entry.is_closed and entry.pnl and entry.pnl > 0
        ]
        if winners:
            return sum(winners) / len(winners)
        return None

    @property
    def avg_loss(self) -> Decimal | None:
        """Mean pnl of closed losing trades, or None when there are none.

        The result keeps the sign of the underlying pnl values.
        """
        losers = [
            entry.pnl
            for entry in self.trades
            if entry.is_closed and entry.pnl and entry.pnl < 0
        ]
        if losers:
            return sum(losers) / len(losers)
        return None

    @property
    def avg_holding_period(self) -> float | None:
        """Mean ``holding_period`` over trades that report one, else None."""
        durations = [
            entry.holding_period
            for entry in self.trades
            if entry.holding_period is not None
        ]
        if durations:
            return sum(durations) / len(durations)
        return None
|
|
|
|
|
|
class BacktestMetrics(BaseModel):
    """Numeric results of a backtest, grouped by theme.

    Optional metrics default to None; ``to_summary_dict`` renders None as
    ``"N/A"``. ``start_equity`` and ``end_equity`` are required and must be
    strictly positive.
    """

    # Returns
    total_return: Decimal = Field(default=Decimal("0"))
    total_return_percent: Decimal = Field(default=Decimal("0"))
    annualized_return: Decimal = Field(default=Decimal("0"))

    # Benchmark comparison
    benchmark_return: Decimal | None = None
    benchmark_return_percent: Decimal | None = None
    alpha: Decimal | None = None
    beta: Decimal | None = None

    # Volatility (all constrained to be non-negative)
    volatility: Decimal = Field(default=Decimal("0"), ge=0)
    annualized_volatility: Decimal = Field(default=Decimal("0"), ge=0)
    downside_volatility: Decimal = Field(default=Decimal("0"), ge=0)

    # Risk-adjusted ratios
    sharpe_ratio: Decimal | None = None
    sortino_ratio: Decimal | None = None
    calmar_ratio: Decimal | None = None
    information_ratio: Decimal | None = None

    # Drawdown (stored as non-negative magnitudes)
    max_drawdown: Decimal = Field(default=Decimal("0"), ge=0)
    max_drawdown_percent: Decimal = Field(default=Decimal("0"), ge=0, le=100)
    max_drawdown_duration: int | None = None
    avg_drawdown: Decimal = Field(default=Decimal("0"), ge=0)

    # Trade statistics
    total_trades: int = Field(default=0, ge=0)
    win_rate: Decimal | None = Field(default=None, ge=0, le=100)
    profit_factor: Decimal | None = None
    avg_trade_pnl: Decimal | None = None
    avg_win: Decimal | None = None
    avg_loss: Decimal | None = None
    largest_win: Decimal | None = None
    largest_loss: Decimal | None = None
    avg_holding_period_days: float | None = None

    # Costs
    total_commission: Decimal = Field(default=Decimal("0"), ge=0)
    total_slippage: Decimal = Field(default=Decimal("0"), ge=0)

    # Span and equity endpoints
    trading_days: int = Field(default=0, ge=0)
    start_equity: Decimal = Field(gt=0)
    end_equity: Decimal = Field(gt=0)

    def to_summary_dict(self) -> dict:
        """Return a nested dict of display-ready metric strings.

        The top-level keys are ``"Performance"``, ``"Risk"``, ``"Trading"``
        and ``"Costs"``. Every value is a formatted string except
        ``"Total Trades"``, which is the raw int. Optional metrics render
        as ``"N/A"`` only when they are None.
        """

        def fmt(value, spec: str, suffix: str = "") -> str:
            # Test against None explicitly: a metric that is exactly zero
            # (e.g. a 0% win rate or a zero Sharpe ratio) is a real value
            # and must not be collapsed into "N/A" by a truthiness check.
            if value is None:
                return "N/A"
            return f"{value:{spec}}{suffix}"

        return {
            "Performance": {
                "Total Return": f"{self.total_return_percent:.2f}%",
                "Annualized Return": f"{self.annualized_return:.2f}%",
                "Sharpe Ratio": fmt(self.sharpe_ratio, ".2f"),
                "Sortino Ratio": fmt(self.sortino_ratio, ".2f"),
                "Max Drawdown": f"{self.max_drawdown_percent:.2f}%",
            },
            "Risk": {
                "Volatility (Ann.)": f"{self.annualized_volatility:.2f}%",
                "Calmar Ratio": fmt(self.calmar_ratio, ".2f"),
                "Beta": fmt(self.beta, ".2f"),
            },
            "Trading": {
                "Total Trades": self.total_trades,
                "Win Rate": fmt(self.win_rate, ".1f", "%"),
                "Profit Factor": fmt(self.profit_factor, ".2f"),
                "Avg Holding Period": fmt(
                    self.avg_holding_period_days, ".1f", " days"
                ),
            },
            "Costs": {
                "Total Commission": f"${self.total_commission:.2f}",
                "Total Slippage": f"${self.total_slippage:.2f}",
            },
        }
|
|
|
|
|
|
class BacktestResult(BaseModel):
    """Outcome of a backtest: its config, metrics, trades and timing."""

    id: UUID = Field(default_factory=uuid4)
    config: BacktestConfig
    metrics: BacktestMetrics
    trade_log: TradeLog
    equity_curve: list[EquityCurvePoint] = Field(default_factory=list)
    daily_returns: list[Decimal] = Field(default_factory=list)
    started_at: datetime
    completed_at: datetime
    status: BacktestStatus = Field(default=BacktestStatus.COMPLETED)
    error_message: str | None = None

    @computed_field
    @property
    def duration_seconds(self) -> float:
        """Seconds elapsed between ``started_at`` and ``completed_at``."""
        return (self.completed_at - self.started_at).total_seconds()

    def to_dict(self) -> dict:
        """Return a condensed dict view of the result.

        Contains the id as a string, a subset of the config (dates in ISO
        format, ``initial_cash`` as float), the metrics summary produced by
        ``BacktestMetrics.to_summary_dict``, trade counts, the duration and
        the status. ``trade_summary["win_rate"]`` is None only when the
        trade log reports no win rate; a rate of zero is emitted as ``0.0``.
        """
        # Read the property once; compare against None rather than relying
        # on truthiness so that a 0% win rate is not reported as missing.
        win_rate = self.trade_log.win_rate

        return {
            "id": str(self.id),
            "config": {
                "name": self.config.name,
                "tickers": self.config.tickers,
                "start_date": self.config.start_date.isoformat(),
                "end_date": self.config.end_date.isoformat(),
                "initial_cash": float(self.config.portfolio_config.initial_cash),
            },
            "metrics": self.metrics.to_summary_dict(),
            "trade_summary": {
                "total_trades": self.trade_log.total_trades,
                "winning_trades": self.trade_log.winning_trades,
                "losing_trades": self.trade_log.losing_trades,
                "win_rate": float(win_rate) if win_rate is not None else None,
            },
            "duration_seconds": self.duration_seconds,
            "status": self.status,
        }
|