252 lines
8.3 KiB
Python
252 lines
8.3 KiB
Python
import datetime
|
|
from datetime import date as date_type
|
|
from decimal import Decimal
|
|
|
|
import typer
|
|
from rich import box
|
|
from rich.console import Console
|
|
from rich.panel import Panel
|
|
from rich.table import Table
|
|
|
|
from cli.display import create_question_box
|
|
from cli.utils import loading
|
|
from tradingagents.backtesting import SimpleBacktestEngine
|
|
from tradingagents.models.backtest import BacktestConfig, BacktestStatus
|
|
from tradingagents.models.portfolio import PortfolioConfig
|
|
|
|
console = Console()
|
|
|
|
|
|
def sma_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool:
|
|
loader = ctx["data_loader"]
|
|
ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date)
|
|
if len(ohlcv.bars) < 20:
|
|
return False
|
|
prices = [float(b.close) for b in ohlcv.bars[-20:]]
|
|
sma = sum(prices) / len(prices)
|
|
current = float(ohlcv.bars[-1].close)
|
|
return current > sma * 1.02
|
|
|
|
|
|
def sma_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool:
|
|
loader = ctx["data_loader"]
|
|
ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date)
|
|
if len(ohlcv.bars) < 20:
|
|
return False
|
|
prices = [float(b.close) for b in ohlcv.bars[-20:]]
|
|
sma = sum(prices) / len(prices)
|
|
current = float(ohlcv.bars[-1].close)
|
|
return current < sma * 0.98
|
|
|
|
|
|
def rsi_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool:
|
|
loader = ctx["data_loader"]
|
|
ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date)
|
|
if len(ohlcv.bars) < 15:
|
|
return False
|
|
changes = []
|
|
for i in range(1, min(15, len(ohlcv.bars))):
|
|
changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i - 1].close))
|
|
gains = [c for c in changes if c > 0]
|
|
losses = [-c for c in changes if c < 0]
|
|
avg_gain = sum(gains) / 14 if gains else 0.001
|
|
avg_loss = sum(losses) / 14 if losses else 0.001
|
|
rs = avg_gain / avg_loss if avg_loss else 100
|
|
rsi = 100 - (100 / (1 + rs))
|
|
return rsi < 30
|
|
|
|
|
|
def rsi_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool:
|
|
loader = ctx["data_loader"]
|
|
ohlcv = loader.load_ohlcv(ticker, date_type(2020, 1, 1), trading_date)
|
|
if len(ohlcv.bars) < 15:
|
|
return False
|
|
changes = []
|
|
for i in range(1, min(15, len(ohlcv.bars))):
|
|
changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i - 1].close))
|
|
gains = [c for c in changes if c > 0]
|
|
losses = [-c for c in changes if c < 0]
|
|
avg_gain = sum(gains) / 14 if gains else 0.001
|
|
avg_loss = sum(losses) / 14 if losses else 0.001
|
|
rs = avg_gain / avg_loss if avg_loss else 100
|
|
rsi = 100 - (100 / (1 + rs))
|
|
return rsi > 70
|
|
|
|
|
|
def hold_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool:
|
|
return ctx.get("day_index", 0) == 5
|
|
|
|
|
|
def hold_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool:
|
|
return False
|
|
|
|
|
|
STRATEGIES = {
|
|
"sma": (sma_buy, sma_sell),
|
|
"rsi": (rsi_buy, rsi_sell),
|
|
"hold": (hold_buy, hold_sell),
|
|
}
|
|
|
|
|
|
def run_backtest(
|
|
ticker: str = None,
|
|
start_date: str = None,
|
|
end_date: str = None,
|
|
initial_cash: float = 100000.0,
|
|
strategy: str = "sma",
|
|
) -> None:
|
|
if not ticker:
|
|
console.print(
|
|
create_question_box(
|
|
"Ticker Symbol", "Enter the ticker symbol to backtest", "AAPL"
|
|
)
|
|
)
|
|
ticker = typer.prompt("", default="AAPL")
|
|
|
|
if not start_date:
|
|
default_start = (
|
|
datetime.datetime.now() - datetime.timedelta(days=365)
|
|
).strftime("%Y-%m-%d")
|
|
console.print(
|
|
create_question_box(
|
|
"Start Date", "Enter backtest start date (YYYY-MM-DD)", default_start
|
|
)
|
|
)
|
|
start_date = typer.prompt("", default=default_start)
|
|
|
|
if not end_date:
|
|
default_end = datetime.datetime.now().strftime("%Y-%m-%d")
|
|
console.print(
|
|
create_question_box(
|
|
"End Date", "Enter backtest end date (YYYY-MM-DD)", default_end
|
|
)
|
|
)
|
|
end_date = typer.prompt("", default=default_end)
|
|
|
|
try:
|
|
start = datetime.datetime.strptime(start_date, "%Y-%m-%d").date()
|
|
end = datetime.datetime.strptime(end_date, "%Y-%m-%d").date()
|
|
except ValueError:
|
|
console.print("[red]Invalid date format. Use YYYY-MM-DD[/red]")
|
|
return
|
|
|
|
if start >= end:
|
|
console.print("[red]Start date must be before end date[/red]")
|
|
return
|
|
|
|
console.print()
|
|
console.print(
|
|
Panel(
|
|
f"[bold]Backtest Configuration[/bold]\n\n"
|
|
f"Ticker: [cyan]{ticker.upper()}[/cyan]\n"
|
|
f"Period: [cyan]{start_date}[/cyan] to [cyan]{end_date}[/cyan]\n"
|
|
f"Initial Cash: [cyan]${initial_cash:,.2f}[/cyan]\n"
|
|
f"Strategy: [cyan]{strategy}[/cyan]",
|
|
title="Configuration",
|
|
border_style="blue",
|
|
)
|
|
)
|
|
console.print()
|
|
|
|
if strategy not in STRATEGIES:
|
|
console.print(
|
|
f"[red]Unknown strategy: {strategy}. Use: sma, rsi, or hold[/red]"
|
|
)
|
|
return
|
|
|
|
buy_fn, sell_fn = STRATEGIES[strategy]
|
|
|
|
config = BacktestConfig(
|
|
name=f"{strategy.upper()} Backtest - {ticker.upper()}",
|
|
tickers=[ticker.upper()],
|
|
start_date=start,
|
|
end_date=end,
|
|
portfolio_config=PortfolioConfig(
|
|
initial_cash=Decimal(str(initial_cash)),
|
|
commission_per_trade=Decimal("1"),
|
|
slippage_percent=Decimal("0.05"),
|
|
),
|
|
warmup_period=5,
|
|
)
|
|
|
|
with loading("Running backtest...", show_elapsed=True):
|
|
engine = SimpleBacktestEngine(config, buy_signal=buy_fn, sell_signal=sell_fn)
|
|
result = engine.run()
|
|
|
|
console.print()
|
|
|
|
if result.status == BacktestStatus.FAILED:
|
|
console.print(f"[red]Backtest failed: {result.error_message}[/red]")
|
|
return
|
|
|
|
metrics = result.metrics
|
|
trade_log = result.trade_log
|
|
|
|
performance_table = Table(title="Performance Metrics", box=box.ROUNDED)
|
|
performance_table.add_column("Metric", style="cyan")
|
|
performance_table.add_column("Value", style="green")
|
|
|
|
performance_table.add_row("Total Return", f"${float(metrics.total_return):,.2f}")
|
|
performance_table.add_row(
|
|
"Total Return %", f"{float(metrics.total_return_percent):.2f}%"
|
|
)
|
|
performance_table.add_row(
|
|
"Annualized Return", f"{float(metrics.annualized_return):.2f}%"
|
|
)
|
|
performance_table.add_row(
|
|
"Sharpe Ratio",
|
|
f"{float(metrics.sharpe_ratio):.2f}" if metrics.sharpe_ratio else "N/A",
|
|
)
|
|
performance_table.add_row(
|
|
"Sortino Ratio",
|
|
f"{float(metrics.sortino_ratio):.2f}" if metrics.sortino_ratio else "N/A",
|
|
)
|
|
performance_table.add_row(
|
|
"Max Drawdown", f"{float(metrics.max_drawdown_percent):.2f}%"
|
|
)
|
|
performance_table.add_row(
|
|
"Volatility (Ann.)", f"{float(metrics.annualized_volatility):.2f}%"
|
|
)
|
|
|
|
console.print(performance_table)
|
|
console.print()
|
|
|
|
trading_table = Table(title="Trading Statistics", box=box.ROUNDED)
|
|
trading_table.add_column("Metric", style="cyan")
|
|
trading_table.add_column("Value", style="green")
|
|
|
|
trading_table.add_row("Total Trades", str(trade_log.total_trades))
|
|
trading_table.add_row("Winning Trades", str(trade_log.winning_trades))
|
|
trading_table.add_row("Losing Trades", str(trade_log.losing_trades))
|
|
trading_table.add_row(
|
|
"Win Rate", f"{float(trade_log.win_rate):.1f}%" if trade_log.win_rate else "N/A"
|
|
)
|
|
trading_table.add_row(
|
|
"Profit Factor",
|
|
f"{float(trade_log.profit_factor):.2f}" if trade_log.profit_factor else "N/A",
|
|
)
|
|
trading_table.add_row(
|
|
"Avg Win", f"${float(trade_log.avg_win):,.2f}" if trade_log.avg_win else "N/A"
|
|
)
|
|
trading_table.add_row(
|
|
"Avg Loss",
|
|
f"${float(trade_log.avg_loss):,.2f}" if trade_log.avg_loss else "N/A",
|
|
)
|
|
|
|
console.print(trading_table)
|
|
console.print()
|
|
|
|
summary_table = Table(title="Portfolio Summary", box=box.ROUNDED)
|
|
summary_table.add_column("Metric", style="cyan")
|
|
summary_table.add_column("Value", style="green")
|
|
|
|
summary_table.add_row("Start Equity", f"${float(metrics.start_equity):,.2f}")
|
|
summary_table.add_row("End Equity", f"${float(metrics.end_equity):,.2f}")
|
|
summary_table.add_row("Trading Days", str(metrics.trading_days))
|
|
summary_table.add_row("Duration", f"{result.duration_seconds:.1f} seconds")
|
|
|
|
console.print(summary_table)
|
|
console.print()
|
|
|
|
console.print("[green]Backtest completed successfully![/green]")
|