"""Interactive ``backtest`` CLI command: simple rule-based strategies plus result reporting."""

import datetime
from datetime import date as date_type
from decimal import Decimal
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import typer
from rich import box
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

from cli.display import create_question_box
from cli.utils import loading
from tradingagents.backtesting import SimpleBacktestEngine
from tradingagents.models.backtest import BacktestConfig, BacktestStatus
from tradingagents.models.portfolio import PortfolioConfig

console = Console()

# Signature shared by every buy/sell signal: (ticker, trading_date, ctx) -> bool.
SignalFn = Callable[[str, date_type, dict], bool]

# Minimum calendar-day window of history requested per signal evaluation.
# A sliding window (instead of a fixed start date) keeps each load bounded and
# works for backtest periods in any year; 120 calendar days is roughly 80
# trading days, well above the longest indicator window used below.
_HISTORY_LOOKBACK_DAYS = 120

_SMA_PERIOD = 20
_SMA_BUY_FACTOR = 1.02  # buy when the close is more than 2% above its SMA
_SMA_SELL_FACTOR = 0.98  # sell when the close is more than 2% below its SMA

_RSI_PERIOD = 14
_RSI_OVERSOLD = 30
_RSI_OVERBOUGHT = 70

# Day index on which the buy-and-hold strategy enters.
# NOTE(review): presumably chosen to match warmup_period=5 in run_backtest —
# confirm how SimpleBacktestEngine numbers ctx["day_index"].
_HOLD_ENTRY_DAY = 5


def _recent_closes(
    ticker: str, trading_date: date_type, ctx: dict, count: int
) -> Optional[List[float]]:
    """Return the last ``count`` closing prices up to ``trading_date``.

    Returns ``None`` when fewer than ``count`` bars are available.
    NOTE(review): assumes ``ctx["data_loader"].load_ohlcv`` returns bars in
    ascending date order (newest last) — confirm against the loader.
    """
    loader = ctx["data_loader"]
    # Scale the window with the request so larger periods still get enough bars.
    lookback_days = max(_HISTORY_LOOKBACK_DAYS, 3 * count)
    start = trading_date - datetime.timedelta(days=lookback_days)
    bars = loader.load_ohlcv(ticker, start, trading_date).bars
    if len(bars) < count:
        return None
    return [float(bar.close) for bar in bars[-count:]]


def _close_and_sma(
    ticker: str, trading_date: date_type, ctx: dict
) -> Optional[Tuple[float, float]]:
    """Return ``(latest_close, simple_moving_average)`` over ``_SMA_PERIOD`` bars, or ``None``."""
    closes = _recent_closes(ticker, trading_date, ctx, _SMA_PERIOD)
    if closes is None:
        return None
    return closes[-1], sum(closes) / len(closes)


def _rsi(closes: List[float]) -> float:
    """Simple-average RSI (Cutler's variant) over consecutive close-to-close changes.

    Uses plain means of gains and losses rather than Wilder's smoothing.
    Degenerate windows follow the usual convention: no losses -> 100,
    no gains -> 0, no movement at all (or fewer than two closes) -> 50.
    """
    deltas = [later - earlier for earlier, later in zip(closes, closes[1:])]
    if not deltas:
        return 50.0
    avg_gain = sum(d for d in deltas if d > 0) / len(deltas)
    avg_loss = sum(-d for d in deltas if d < 0) / len(deltas)
    if avg_loss == 0:
        return 50.0 if avg_gain == 0 else 100.0
    return 100.0 - 100.0 / (1.0 + avg_gain / avg_loss)


def _latest_rsi(ticker: str, trading_date: date_type, ctx: dict) -> Optional[float]:
    """Return the ``_RSI_PERIOD`` RSI as of ``trading_date``, or ``None`` if history is too short."""
    # N changes need N + 1 closes.
    closes = _recent_closes(ticker, trading_date, ctx, _RSI_PERIOD + 1)
    return None if closes is None else _rsi(closes)


def sma_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool:
    """Buy when the latest close is more than 2% above its 20-bar SMA."""
    point = _close_and_sma(ticker, trading_date, ctx)
    if point is None:
        return False
    close, sma = point
    return close > sma * _SMA_BUY_FACTOR


def sma_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool:
    """Sell when the latest close is more than 2% below its 20-bar SMA."""
    point = _close_and_sma(ticker, trading_date, ctx)
    if point is None:
        return False
    close, sma = point
    return close < sma * _SMA_SELL_FACTOR


def rsi_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool:
    """Buy when the 14-period RSI is below 30 (oversold)."""
    value = _latest_rsi(ticker, trading_date, ctx)
    return value is not None and value < _RSI_OVERSOLD


def rsi_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool:
    """Sell when the 14-period RSI is above 70 (overbought)."""
    value = _latest_rsi(ticker, trading_date, ctx)
    return value is not None and value > _RSI_OVERBOUGHT


def hold_buy(ticker: str, trading_date: date_type, ctx: dict) -> bool:
    """Buy-and-hold entry: signal once, when ``ctx["day_index"]`` equals ``_HOLD_ENTRY_DAY``."""
    return ctx.get("day_index", 0) == _HOLD_ENTRY_DAY


def hold_sell(ticker: str, trading_date: date_type, ctx: dict) -> bool:
    """Buy-and-hold exit: never sells."""
    return False


# Strategy name -> (buy_signal, sell_signal).
STRATEGIES: Dict[str, Tuple[SignalFn, SignalFn]] = {
    "sma": (sma_buy, sma_sell),
    "rsi": (rsi_buy, rsi_sell),
    "hold": (hold_buy, hold_sell),
}


def _prompt_if_missing(value: Optional[str], title: str, question: str, default: str) -> str:
    """Return ``value`` if it is non-empty; otherwise show a question box and prompt for it."""
    if value:
        return value
    console.print(create_question_box(title, question, default))
    return typer.prompt("", default=default)


def _parse_date(text: str) -> date_type:
    """Parse a ``YYYY-MM-DD`` string; raises ``ValueError`` on bad input."""
    return datetime.datetime.strptime(text.strip(), "%Y-%m-%d").date()


def _fmt(value, spec: str, prefix: str = "", suffix: str = "") -> str:
    """Format a numeric (possibly ``Decimal``) value with ``spec``, or ``"N/A"`` when it is ``None``.

    An explicit ``None`` check is used so that a legitimate zero (e.g. a 0%
    win rate) is displayed as a number rather than as "N/A".
    """
    if value is None:
        return "N/A"
    return f"{prefix}{float(value):{spec}}{suffix}"


def _result_table(title: str, rows: Iterable[Tuple[str, str]]) -> Table:
    """Build a two-column Metric/Value table from ``(label, value)`` rows."""
    table = Table(title=title, box=box.ROUNDED)
    table.add_column("Metric", style="cyan")
    table.add_column("Value", style="green")
    for label, value in rows:
        table.add_row(label, value)
    return table


def _print_results(result) -> None:
    """Print performance, trading and portfolio tables for a completed backtest result."""
    metrics = result.metrics
    trade_log = result.trade_log

    console.print(
        _result_table(
            "Performance Metrics",
            [
                ("Total Return", _fmt(metrics.total_return, ",.2f", prefix="$")),
                ("Total Return %", _fmt(metrics.total_return_percent, ".2f", suffix="%")),
                ("Annualized Return", _fmt(metrics.annualized_return, ".2f", suffix="%")),
                ("Sharpe Ratio", _fmt(metrics.sharpe_ratio, ".2f")),
                ("Sortino Ratio", _fmt(metrics.sortino_ratio, ".2f")),
                ("Max Drawdown", _fmt(metrics.max_drawdown_percent, ".2f", suffix="%")),
                ("Volatility (Ann.)", _fmt(metrics.annualized_volatility, ".2f", suffix="%")),
            ],
        )
    )
    console.print()

    console.print(
        _result_table(
            "Trading Statistics",
            [
                ("Total Trades", str(trade_log.total_trades)),
                ("Winning Trades", str(trade_log.winning_trades)),
                ("Losing Trades", str(trade_log.losing_trades)),
                ("Win Rate", _fmt(trade_log.win_rate, ".1f", suffix="%")),
                ("Profit Factor", _fmt(trade_log.profit_factor, ".2f")),
                ("Avg Win", _fmt(trade_log.avg_win, ",.2f", prefix="$")),
                ("Avg Loss", _fmt(trade_log.avg_loss, ",.2f", prefix="$")),
            ],
        )
    )
    console.print()

    console.print(
        _result_table(
            "Portfolio Summary",
            [
                ("Start Equity", _fmt(metrics.start_equity, ",.2f", prefix="$")),
                ("End Equity", _fmt(metrics.end_equity, ",.2f", prefix="$")),
                ("Trading Days", str(metrics.trading_days)),
                ("Duration", f"{result.duration_seconds:.1f} seconds"),
            ],
        )
    )
    console.print()


def run_backtest(
    ticker: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    initial_cash: float = 100000.0,
    strategy: str = "sma",
) -> None:
    """Run a single-ticker backtest interactively and print the results.

    Args:
        ticker: Symbol to backtest; prompted for when missing or empty.
        start_date: Start of the period as ``YYYY-MM-DD``; prompted for when
            missing (defaults to one year ago).
        end_date: End of the period as ``YYYY-MM-DD``; prompted for when
            missing (defaults to today).
        initial_cash: Starting portfolio cash; must be positive.
        strategy: One of the keys of ``STRATEGIES`` (case-insensitive).

    Validation errors and a failed backtest are reported on the console and
    the function returns early; it never returns a value.
    """
    # Validate cheap arguments first so the user is not prompted needlessly.
    strategy_key = strategy.strip().lower()
    if strategy_key not in STRATEGIES:
        console.print(
            f"[red]Unknown strategy: {strategy}. Use: sma, rsi, or hold[/red]"
        )
        return
    if initial_cash <= 0:
        console.print("[red]Initial cash must be greater than zero[/red]")
        return

    now = datetime.datetime.now()
    ticker = _prompt_if_missing(
        ticker, "Ticker Symbol", "Enter the ticker symbol to backtest", "AAPL"
    )
    start_date = _prompt_if_missing(
        start_date,
        "Start Date",
        "Enter backtest start date (YYYY-MM-DD)",
        (now - datetime.timedelta(days=365)).strftime("%Y-%m-%d"),
    )
    end_date = _prompt_if_missing(
        end_date,
        "End Date",
        "Enter backtest end date (YYYY-MM-DD)",
        now.strftime("%Y-%m-%d"),
    )

    ticker = ticker.strip().upper()
    if not ticker:
        console.print("[red]Ticker symbol must not be empty[/red]")
        return

    try:
        start = _parse_date(start_date)
        end = _parse_date(end_date)
    except ValueError:
        console.print("[red]Invalid date format. Use YYYY-MM-DD[/red]")
        return

    if start >= end:
        console.print("[red]Start date must be before end date[/red]")
        return

    console.print()
    console.print(
        Panel(
            f"[bold]Backtest Configuration[/bold]\n\n"
            f"Ticker: [cyan]{ticker}[/cyan]\n"
            f"Period: [cyan]{start_date}[/cyan] to [cyan]{end_date}[/cyan]\n"
            f"Initial Cash: [cyan]${initial_cash:,.2f}[/cyan]\n"
            f"Strategy: [cyan]{strategy_key}[/cyan]",
            title="Configuration",
            border_style="blue",
        )
    )
    console.print()

    buy_fn, sell_fn = STRATEGIES[strategy_key]
    config = BacktestConfig(
        name=f"{strategy_key.upper()} Backtest - {ticker}",
        tickers=[ticker],
        start_date=start,
        end_date=end,
        portfolio_config=PortfolioConfig(
            initial_cash=Decimal(str(initial_cash)),
            commission_per_trade=Decimal("1"),
            slippage_percent=Decimal("0.05"),
        ),
        warmup_period=5,
    )

    with loading("Running backtest...", show_elapsed=True):
        engine = SimpleBacktestEngine(config, buy_signal=buy_fn, sell_signal=sell_fn)
        result = engine.run()

    console.print()

    if result.status == BacktestStatus.FAILED:
        console.print(f"[red]Backtest failed: {result.error_message}[/red]")
        return

    _print_results(result)
    console.print("[green]Backtest completed successfully![/green]")