From f70874982a625dfa59ab6f0229f4141aa72ce5f7 Mon Sep 17 00:00:00 2001 From: Joseph O'Brien <98370624+89jobrien@users.noreply.github.com> Date: Wed, 3 Dec 2025 02:55:28 -0500 Subject: [PATCH] feat: add backtesting framework and fix code quality issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add complete backtesting engine with portfolio simulation, metrics calculation (Sharpe, Sortino, max drawdown), and agent integration - Add Pydantic data models for market data, trading, portfolio, and backtest results - Add backtest CLI command with SMA, RSI, and hold strategies - Fix 24+ bare exception handlers with specific exception types - Fix hardcoded path in default_config.py (use TRADINGAGENTS_DATA_DIR env var) - Fix unclosed file handle in local.py with context manager - Disable store=True in OpenAI API calls for data privacy - Fix typo: rename aggresive_debator.py to aggressive_debator.py - Add request timeouts (30s) to alpha_vantage_common.py and googlenews_utils.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- cli/main.py | 192 +++++++++ tests/models/__init__.py | 0 tests/models/test_backtest.py | 318 +++++++++++++++ tests/models/test_market_data.py | 216 +++++++++++ tests/models/test_portfolio.py | 207 ++++++++++ tests/models/test_trading.py | 297 ++++++++++++++ tradingagents/agents/__init__.py | 2 +- ...esive_debator.py => aggressive_debator.py} | 0 tradingagents/backtesting/__init__.py | 13 + .../backtesting/agent_integration.py | 249 ++++++++++++ tradingagents/backtesting/data_loader.py | 244 ++++++++++++ tradingagents/backtesting/engine.py | 366 ++++++++++++++++++ tradingagents/backtesting/metrics.py | 281 ++++++++++++++ .../dataflows/alpha_vantage_common.py | 4 +- .../dataflows/alpha_vantage_indicator.py | 3 +- tradingagents/dataflows/alpha_vantage_news.py | 3 +- tradingagents/dataflows/brave.py | 2 +- tradingagents/dataflows/google.py | 7 +- tradingagents/dataflows/googlenews_utils.py | 6 +- tradingagents/dataflows/interface.py | 7 +- tradingagents/dataflows/local.py | 4 +- tradingagents/dataflows/openai.py | 8 +- tradingagents/dataflows/tavily.py | 4 +- .../dataflows/trending/sector_classifier.py | 2 +- .../dataflows/trending/stock_resolver.py | 6 +- tradingagents/dataflows/y_finance.py | 12 +- tradingagents/default_config.py | 2 +- tradingagents/graph/trading_graph.py | 2 +- tradingagents/models/__init__.py | 69 ++++ tradingagents/models/backtest.py | 242 ++++++++++++ tradingagents/models/decisions.py | 157 ++++++++ tradingagents/models/market_data.py | 144 +++++++ tradingagents/models/portfolio.py | 172 ++++++++ tradingagents/models/trading.py | 201 ++++++++++ 34 files changed, 3409 insertions(+), 33 deletions(-) create mode 100644 tests/models/__init__.py create mode 100644 tests/models/test_backtest.py create mode 100644 tests/models/test_market_data.py create mode 100644 tests/models/test_portfolio.py create mode 100644 tests/models/test_trading.py rename tradingagents/agents/risk_mgmt/{aggresive_debator.py => aggressive_debator.py} (100%) create mode 100644 tradingagents/backtesting/__init__.py create mode 100644 tradingagents/backtesting/agent_integration.py create mode 100644 tradingagents/backtesting/data_loader.py create mode 100644 tradingagents/backtesting/engine.py create mode 100644 tradingagents/backtesting/metrics.py create mode 100644 tradingagents/models/__init__.py create mode 100644 tradingagents/models/backtest.py create mode 100644 tradingagents/models/decisions.py create mode 100644 tradingagents/models/market_data.py create mode 100644 tradingagents/models/portfolio.py create mode 100644 tradingagents/models/trading.py diff --git a/cli/main.py b/cli/main.py index ac22b40c..8f2211e8 100644 --- a/cli/main.py +++ b/cli/main.py @@ -34,6 +34,9 @@ from tradingagents.agents.discovery.models import ( EventCategory, ) from tradingagents.agents.discovery.persistence import save_discovery_result +from tradingagents.backtesting import SimpleBacktestEngine, DataLoader +from tradingagents.models.backtest import BacktestConfig +from tradingagents.models.portfolio import PortfolioConfig from cli.models import AnalystType from cli.utils import ( ANALYST_ORDER, @@ -1715,6 +1718,195 @@ def menu(): discover_trending_flow() +@app.command() +def backtest( + ticker: str = typer.Option(None, "--ticker", "-t", help="Ticker symbol to backtest"), + start_date: str = typer.Option(None, "--start", "-s", help="Start date (YYYY-MM-DD)"), + end_date: str = typer.Option(None, "--end", "-e", help="End date (YYYY-MM-DD)"), + initial_cash: float = typer.Option(100000.0, "--cash", "-c", help="Initial portfolio cash"), + strategy: str = typer.Option("sma", "--strategy", help="Strategy: sma, rsi, or hold"), +): + from decimal import Decimal + from datetime import date as date_type + + if not ticker: + console.print(create_question_box("Ticker Symbol", "Enter the ticker symbol to backtest", "AAPL")) + ticker = typer.prompt("", default="AAPL") + + if not start_date: + default_start = (datetime.datetime.now() - datetime.timedelta(days=365)).strftime("%Y-%m-%d") + console.print(create_question_box("Start Date", "Enter backtest start date (YYYY-MM-DD)", default_start)) + start_date = typer.prompt("", default=default_start) + + if not end_date: + default_end = datetime.datetime.now().strftime("%Y-%m-%d") + console.print(create_question_box("End Date", "Enter backtest end date (YYYY-MM-DD)", default_end)) + end_date = typer.prompt("", default=default_end) + + try: + start = datetime.datetime.strptime(start_date, "%Y-%m-%d").date() + end = datetime.datetime.strptime(end_date, "%Y-%m-%d").date() + except ValueError: + console.print("[red]Invalid date format. Use YYYY-MM-DD[/red]") + return + + if start >= end: + console.print("[red]Start date must be before end date[/red]") + return + + console.print() + console.print(Panel( + f"[bold]Backtest Configuration[/bold]\n\n" + f"Ticker: [cyan]{ticker.upper()}[/cyan]\n" + f"Period: [cyan]{start_date}[/cyan] to [cyan]{end_date}[/cyan]\n" + f"Initial Cash: [cyan]${initial_cash:,.2f}[/cyan]\n" + f"Strategy: [cyan]{strategy}[/cyan]", + title="Configuration", + border_style="blue", + )) + console.print() + + def sma_buy(t, trading_date, ctx): + loader = ctx["data_loader"] + ohlcv = loader.load_ohlcv(t, date_type(2020, 1, 1), trading_date) + if len(ohlcv.bars) < 20: + return False + prices = [float(b.close) for b in ohlcv.bars[-20:]] + sma = sum(prices) / len(prices) + current = float(ohlcv.bars[-1].close) + return current > sma * 1.02 + + def sma_sell(t, trading_date, ctx): + loader = ctx["data_loader"] + ohlcv = loader.load_ohlcv(t, date_type(2020, 1, 1), trading_date) + if len(ohlcv.bars) < 20: + return False + prices = [float(b.close) for b in ohlcv.bars[-20:]] + sma = sum(prices) / len(prices) + current = float(ohlcv.bars[-1].close) + return current < sma * 0.98 + + def rsi_buy(t, trading_date, ctx): + loader = ctx["data_loader"] + ohlcv = loader.load_ohlcv(t, date_type(2020, 1, 1), trading_date) + if len(ohlcv.bars) < 15: + return False + changes = [] + for i in range(1, min(15, len(ohlcv.bars))): + changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i-1].close)) + gains = [c for c in changes if c > 0] + losses = [-c for c in changes if c < 0] + avg_gain = sum(gains) / 14 if gains else 0.001 + avg_loss = sum(losses) / 14 if losses else 0.001 + rs = avg_gain / avg_loss if avg_loss else 100 + rsi = 100 - (100 / (1 + rs)) + return rsi < 30 + + def rsi_sell(t, trading_date, ctx): + loader = ctx["data_loader"] + ohlcv = loader.load_ohlcv(t, date_type(2020, 1, 1), trading_date) + if len(ohlcv.bars) < 15: + return False + changes = [] + for i in range(1, min(15, len(ohlcv.bars))): + changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i-1].close)) + gains = [c for c in changes if c > 0] + losses = [-c for c in changes if c < 0] + avg_gain = sum(gains) / 14 if gains else 0.001 + avg_loss = sum(losses) / 14 if losses else 0.001 + rs = avg_gain / avg_loss if avg_loss else 100 + rsi = 100 - (100 / (1 + rs)) + return rsi > 70 + + def hold_buy(t, trading_date, ctx): + return ctx.get("day_index", 0) == 5 + + def hold_sell(t, trading_date, ctx): + return False + + strategies = { + "sma": (sma_buy, sma_sell), + "rsi": (rsi_buy, rsi_sell), + "hold": (hold_buy, hold_sell), + } + + if strategy not in strategies: + console.print(f"[red]Unknown strategy: {strategy}. Use: sma, rsi, or hold[/red]") + return + + buy_fn, sell_fn = strategies[strategy] + + config = BacktestConfig( + name=f"{strategy.upper()} Backtest - {ticker.upper()}", + tickers=[ticker.upper()], + start_date=start, + end_date=end, + portfolio_config=PortfolioConfig( + initial_cash=Decimal(str(initial_cash)), + commission_per_trade=Decimal("1"), + slippage_percent=Decimal("0.05"), + ), + warmup_period=5, + ) + + with loading("Running backtest...", show_elapsed=True): + engine = SimpleBacktestEngine(config, buy_signal=buy_fn, sell_signal=sell_fn) + result = engine.run() + + console.print() + + if result.status == "failed": + console.print(f"[red]Backtest failed: {result.error_message}[/red]") + return + + metrics = result.metrics + trade_log = result.trade_log + + performance_table = Table(title="Performance Metrics", box=box.ROUNDED) + performance_table.add_column("Metric", style="cyan") + performance_table.add_column("Value", style="green") + + performance_table.add_row("Total Return", f"${float(metrics.total_return):,.2f}") + performance_table.add_row("Total Return %", f"{float(metrics.total_return_percent):.2f}%") + performance_table.add_row("Annualized Return", f"{float(metrics.annualized_return):.2f}%") + performance_table.add_row("Sharpe Ratio", f"{float(metrics.sharpe_ratio):.2f}" if metrics.sharpe_ratio else "N/A") + performance_table.add_row("Sortino Ratio", f"{float(metrics.sortino_ratio):.2f}" if metrics.sortino_ratio else "N/A") + performance_table.add_row("Max Drawdown", f"{float(metrics.max_drawdown_percent):.2f}%") + performance_table.add_row("Volatility (Ann.)", f"{float(metrics.annualized_volatility):.2f}%") + + console.print(performance_table) + console.print() + + trading_table = Table(title="Trading Statistics", box=box.ROUNDED) + trading_table.add_column("Metric", style="cyan") + trading_table.add_column("Value", style="green") + + trading_table.add_row("Total Trades", str(trade_log.total_trades)) + trading_table.add_row("Winning Trades", str(trade_log.winning_trades)) + trading_table.add_row("Losing Trades", str(trade_log.losing_trades)) + trading_table.add_row("Win Rate", f"{float(trade_log.win_rate):.1f}%" if trade_log.win_rate else "N/A") + trading_table.add_row("Profit Factor", f"{float(trade_log.profit_factor):.2f}" if trade_log.profit_factor else "N/A") + trading_table.add_row("Avg Win", f"${float(trade_log.avg_win):,.2f}" if trade_log.avg_win else "N/A") + trading_table.add_row("Avg Loss", f"${float(trade_log.avg_loss):,.2f}" if trade_log.avg_loss else "N/A") + + console.print(trading_table) + console.print() + + summary_table = Table(title="Portfolio Summary", box=box.ROUNDED) + summary_table.add_column("Metric", style="cyan") + summary_table.add_column("Value", style="green") + + summary_table.add_row("Start Equity", f"${float(metrics.start_equity):,.2f}") + summary_table.add_row("End Equity", f"${float(metrics.end_equity):,.2f}") + summary_table.add_row("Trading Days", str(metrics.trading_days)) + summary_table.add_row("Duration", f"{result.duration_seconds:.1f} seconds") + + console.print(summary_table) + console.print() + + console.print(f"[green]Backtest completed successfully![/green]") + + if __name__ == "__main__": choice = show_main_menu() if choice == "analyze": diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/test_backtest.py b/tests/models/test_backtest.py new file mode 100644 index 00000000..809fe281 --- /dev/null +++ b/tests/models/test_backtest.py @@ -0,0 +1,318 @@ +from datetime import date, datetime +from decimal import Decimal + +import pytest + +from tradingagents.models.backtest import ( + BacktestConfig, + BacktestResult, + BacktestMetrics, + EquityCurvePoint, + TradeLog, +) +from tradingagents.models.portfolio import PortfolioConfig +from tradingagents.models.trading import Trade, OrderSide + + +class TestBacktestConfig: + def test_basic_config(self): + config = BacktestConfig( + tickers=["AAPL"], + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + ) + assert config.tickers == ["AAPL"] + assert config.interval == "1d" + assert config.benchmark_ticker == "SPY" + + def test_multi_ticker_config(self): + config = BacktestConfig( + name="Multi-Stock Test", + tickers=["aapl", "googl", "msft"], + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + ) + assert config.tickers == ["AAPL", "GOOGL", "MSFT"] + + def test_invalid_date_range(self): + with pytest.raises(ValueError): + BacktestConfig( + tickers=["AAPL"], + start_date=date(2024, 6, 30), + end_date=date(2024, 1, 1), + ) + + def test_same_start_end_date(self): + with pytest.raises(ValueError): + BacktestConfig( + tickers=["AAPL"], + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 1), + ) + + def test_empty_tickers(self): + with pytest.raises(ValueError): + BacktestConfig( + tickers=[], + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + ) + + def test_trading_days_estimate(self): + config = BacktestConfig( + tickers=["AAPL"], + start_date=date(2024, 1, 1), + end_date=date(2024, 12, 31), + ) + assert config.trading_days_estimate > 200 + assert config.trading_days_estimate < 260 + + def test_custom_portfolio_config(self): + portfolio_config = PortfolioConfig( + initial_cash=Decimal("50000"), + commission_per_trade=Decimal("5"), + ) + config = BacktestConfig( + tickers=["AAPL"], + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + portfolio_config=portfolio_config, + ) + assert config.portfolio_config.initial_cash == Decimal("50000") + + +class TestTradeLog: + def test_empty_trade_log(self): + log = TradeLog() + assert log.total_trades == 0 + assert log.win_rate is None + assert log.profit_factor is None + + def test_add_winning_trade(self): + log = TradeLog() + trade = Trade( + ticker="AAPL", + side=OrderSide.BUY, + entry_price=Decimal("100"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15), + exit_price=Decimal("110"), + exit_quantity=100, + exit_time=datetime(2024, 1, 20), + ) + log.add_trade(trade) + assert log.total_trades == 1 + assert log.winning_trades == 1 + assert log.win_rate == Decimal("100") + + def test_add_losing_trade(self): + log = TradeLog() + trade = Trade( + ticker="AAPL", + side=OrderSide.BUY, + entry_price=Decimal("100"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15), + exit_price=Decimal("90"), + exit_quantity=100, + exit_time=datetime(2024, 1, 20), + ) + log.add_trade(trade) + assert log.total_trades == 1 + assert log.losing_trades == 1 + assert log.win_rate == Decimal("0") + + def test_mixed_trades(self): + log = TradeLog() + + win_trade = Trade( + ticker="AAPL", + side=OrderSide.BUY, + entry_price=Decimal("100"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15), + exit_price=Decimal("120"), + exit_quantity=100, + exit_time=datetime(2024, 1, 20), + ) + log.add_trade(win_trade) + + loss_trade = Trade( + ticker="GOOGL", + side=OrderSide.BUY, + entry_price=Decimal("100"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15), + exit_price=Decimal("90"), + exit_quantity=100, + exit_time=datetime(2024, 1, 20), + ) + log.add_trade(loss_trade) + + assert log.total_trades == 2 + assert log.winning_trades == 1 + assert log.losing_trades == 1 + assert log.win_rate == Decimal("50") + + def test_gross_profit_loss(self): + log = TradeLog() + + win_trade = Trade( + ticker="AAPL", + side=OrderSide.BUY, + entry_price=Decimal("100"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15), + exit_price=Decimal("120"), + exit_quantity=100, + exit_time=datetime(2024, 1, 20), + ) + log.add_trade(win_trade) + + loss_trade = Trade( + ticker="GOOGL", + side=OrderSide.BUY, + entry_price=Decimal("100"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15), + exit_price=Decimal("90"), + exit_quantity=100, + exit_time=datetime(2024, 1, 20), + ) + log.add_trade(loss_trade) + + assert log.gross_profit == Decimal("2000") + assert log.gross_loss == Decimal("1000") + assert log.profit_factor == Decimal("2") + + def test_avg_win_loss(self): + log = TradeLog() + + for i in range(3): + trade = Trade( + ticker="AAPL", + side=OrderSide.BUY, + entry_price=Decimal("100"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15), + exit_price=Decimal("110"), + exit_quantity=100, + exit_time=datetime(2024, 1, 20), + ) + log.add_trade(trade) + + assert log.avg_win == Decimal("1000") + + +class TestBacktestMetrics: + def test_basic_metrics(self): + metrics = BacktestMetrics( + total_return=Decimal("10000"), + total_return_percent=Decimal("10"), + annualized_return=Decimal("15"), + max_drawdown=Decimal("5000"), + max_drawdown_percent=Decimal("5"), + sharpe_ratio=Decimal("1.5"), + total_trades=50, + win_rate=Decimal("60"), + start_equity=Decimal("100000"), + end_equity=Decimal("110000"), + ) + assert metrics.total_return_percent == Decimal("10") + assert metrics.sharpe_ratio == Decimal("1.5") + + def test_to_summary_dict(self): + metrics = BacktestMetrics( + total_return=Decimal("10000"), + total_return_percent=Decimal("10"), + annualized_return=Decimal("15"), + max_drawdown=Decimal("5000"), + max_drawdown_percent=Decimal("5"), + sharpe_ratio=Decimal("1.5"), + sortino_ratio=Decimal("2.0"), + volatility=Decimal("10"), + annualized_volatility=Decimal("15"), + total_trades=50, + win_rate=Decimal("60"), + profit_factor=Decimal("1.8"), + total_commission=Decimal("500"), + total_slippage=Decimal("200"), + start_equity=Decimal("100000"), + end_equity=Decimal("110000"), + ) + summary = metrics.to_summary_dict() + + assert "Performance" in summary + assert "Risk" in summary + assert "Trading" in summary + assert "Costs" in summary + assert summary["Performance"]["Sharpe Ratio"] == "1.50" + + +class TestEquityCurvePoint: + def test_equity_point(self): + point = EquityCurvePoint( + timestamp=datetime(2024, 1, 15), + equity=Decimal("105000"), + cash=Decimal("50000"), + positions_value=Decimal("55000"), + drawdown=Decimal("2000"), + drawdown_percent=Decimal("1.9"), + ) + assert point.equity == Decimal("105000") + assert point.drawdown_percent == Decimal("1.9") + + +class TestBacktestResult: + def test_backtest_result(self): + config = BacktestConfig( + tickers=["AAPL"], + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + ) + metrics = BacktestMetrics( + total_return=Decimal("10000"), + total_return_percent=Decimal("10"), + start_equity=Decimal("100000"), + end_equity=Decimal("110000"), + ) + trade_log = TradeLog(total_trades=10, winning_trades=6, losing_trades=4) + + result = BacktestResult( + config=config, + metrics=metrics, + trade_log=trade_log, + started_at=datetime(2024, 7, 1, 10, 0, 0), + completed_at=datetime(2024, 7, 1, 10, 5, 30), + ) + + assert result.duration_seconds == 330.0 + assert result.status == "completed" + + def test_to_dict(self): + config = BacktestConfig( + name="Test Backtest", + tickers=["AAPL"], + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + ) + metrics = BacktestMetrics( + total_return=Decimal("10000"), + total_return_percent=Decimal("10"), + start_equity=Decimal("100000"), + end_equity=Decimal("110000"), + ) + trade_log = TradeLog(total_trades=10, winning_trades=6, losing_trades=4) + + result = BacktestResult( + config=config, + metrics=metrics, + trade_log=trade_log, + started_at=datetime(2024, 7, 1, 10, 0, 0), + completed_at=datetime(2024, 7, 1, 10, 5, 0), + ) + + result_dict = result.to_dict() + assert result_dict["config"]["name"] == "Test Backtest" + assert result_dict["trade_summary"]["total_trades"] == 10 + assert result_dict["trade_summary"]["win_rate"] == 60.0 diff --git a/tests/models/test_market_data.py b/tests/models/test_market_data.py new file mode 100644 index 00000000..be373682 --- /dev/null +++ b/tests/models/test_market_data.py @@ -0,0 +1,216 @@ +from datetime import datetime, date +from decimal import Decimal + +import pytest + +from tradingagents.models.market_data import ( + OHLCVBar, + OHLCV, + TechnicalIndicators, + MarketSnapshot, + HistoricalDataRequest, + HistoricalDataResponse, +) + + +class TestOHLCVBar: + def test_valid_bar(self): + bar = OHLCVBar( + timestamp=datetime(2024, 1, 15, 9, 30), + open=Decimal("100.00"), + high=Decimal("105.00"), + low=Decimal("99.00"), + close=Decimal("103.50"), + volume=1000000, + ) + assert bar.open == Decimal("100.00") + assert bar.high == Decimal("105.00") + assert bar.volume == 1000000 + + def test_bar_with_adjusted_close(self): + bar = OHLCVBar( + timestamp=datetime(2024, 1, 15), + open=Decimal("100"), + high=Decimal("105"), + low=Decimal("99"), + close=Decimal("103"), + volume=1000000, + adjusted_close=Decimal("102.50"), + ) + assert bar.adjusted_close == Decimal("102.50") + + def test_invalid_negative_price(self): + with pytest.raises(ValueError): + OHLCVBar( + timestamp=datetime(2024, 1, 15), + open=Decimal("-100"), + high=Decimal("105"), + low=Decimal("99"), + close=Decimal("103"), + volume=1000000, + ) + + def test_invalid_negative_volume(self): + with pytest.raises(ValueError): + OHLCVBar( + timestamp=datetime(2024, 1, 15), + open=Decimal("100"), + high=Decimal("105"), + low=Decimal("99"), + close=Decimal("103"), + volume=-1000, + ) + + +class TestOHLCV: + @pytest.fixture + def sample_bars(self): + return [ + OHLCVBar( + timestamp=datetime(2024, 1, 15), + open=Decimal("100"), + high=Decimal("105"), + low=Decimal("99"), + close=Decimal("103"), + volume=1000000, + ), + OHLCVBar( + timestamp=datetime(2024, 1, 16), + open=Decimal("103"), + high=Decimal("108"), + low=Decimal("102"), + close=Decimal("107"), + volume=1200000, + ), + OHLCVBar( + timestamp=datetime(2024, 1, 17), + open=Decimal("107"), + high=Decimal("110"), + low=Decimal("105"), + close=Decimal("109"), + volume=900000, + ), + ] + + def test_ohlcv_creation(self, sample_bars): + ohlcv = OHLCV(ticker="AAPL", bars=sample_bars) + assert ohlcv.ticker == "AAPL" + assert len(ohlcv.bars) == 3 + assert ohlcv.interval == "1d" + assert ohlcv.currency == "USD" + + def test_start_end_dates(self, sample_bars): + ohlcv = OHLCV(ticker="AAPL", bars=sample_bars) + assert ohlcv.start_date == datetime(2024, 1, 15) + assert ohlcv.end_date == datetime(2024, 1, 17) + + def test_empty_ohlcv(self): + ohlcv = OHLCV(ticker="AAPL", bars=[]) + assert ohlcv.start_date is None + assert ohlcv.end_date is None + + def test_get_bar(self, sample_bars): + ohlcv = OHLCV(ticker="AAPL", bars=sample_bars) + bar = ohlcv.get_bar(datetime(2024, 1, 16)) + assert bar is not None + assert bar.close == Decimal("107") + + def test_get_bar_not_found(self, sample_bars): + ohlcv = OHLCV(ticker="AAPL", bars=sample_bars) + bar = ohlcv.get_bar(datetime(2024, 1, 20)) + assert bar is None + + def test_slice(self, sample_bars): + ohlcv = OHLCV(ticker="AAPL", bars=sample_bars) + sliced = ohlcv.slice(datetime(2024, 1, 15), datetime(2024, 1, 16)) + assert len(sliced.bars) == 2 + assert sliced.ticker == "AAPL" + + def test_invalid_ticker(self): + with pytest.raises(ValueError): + OHLCV(ticker="", bars=[]) + + +class TestTechnicalIndicators: + def test_valid_indicators(self): + indicators = TechnicalIndicators( + timestamp=datetime(2024, 1, 15), + ticker="AAPL", + sma_50=Decimal("150.00"), + rsi_14=Decimal("65.5"), + macd=Decimal("2.5"), + ) + assert indicators.sma_50 == Decimal("150.00") + assert indicators.rsi_14 == Decimal("65.5") + + def test_rsi_bounds(self): + with pytest.raises(ValueError): + TechnicalIndicators( + timestamp=datetime(2024, 1, 15), + ticker="AAPL", + rsi_14=Decimal("150"), + ) + + def test_mfi_bounds(self): + with pytest.raises(ValueError): + TechnicalIndicators( + timestamp=datetime(2024, 1, 15), + ticker="AAPL", + mfi_14=Decimal("-10"), + ) + + +class TestMarketSnapshot: + def test_snapshot_change_calculation(self): + bar = OHLCVBar( + timestamp=datetime(2024, 1, 15), + open=Decimal("100"), + high=Decimal("105"), + low=Decimal("99"), + close=Decimal("103"), + volume=1000000, + ) + snapshot = MarketSnapshot( + ticker="AAPL", + timestamp=datetime(2024, 1, 15), + bar=bar, + prev_close=Decimal("100"), + ) + assert snapshot.change == Decimal("3") + assert snapshot.change_percent == Decimal("3") + + def test_snapshot_no_prev_close(self): + bar = OHLCVBar( + timestamp=datetime(2024, 1, 15), + open=Decimal("100"), + high=Decimal("105"), + low=Decimal("99"), + close=Decimal("103"), + volume=1000000, + ) + snapshot = MarketSnapshot( + ticker="AAPL", + timestamp=datetime(2024, 1, 15), + bar=bar, + ) + assert snapshot.change is None + assert snapshot.change_percent is None + + +class TestHistoricalDataRequest: + def test_valid_request(self): + request = HistoricalDataRequest( + ticker="AAPL", + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + ) + assert request.ticker == "AAPL" + assert request.include_indicators is True + + def test_invalid_date_range(self): + with pytest.raises(ValueError): + HistoricalDataRequest( + ticker="AAPL", + start_date=date(2024, 6, 30), + end_date=date(2024, 1, 1), + ) diff --git a/tests/models/test_portfolio.py b/tests/models/test_portfolio.py new file mode 100644 index 00000000..e73ecf33 --- /dev/null +++ b/tests/models/test_portfolio.py @@ -0,0 +1,207 @@ +from datetime import datetime +from decimal import Decimal +from uuid import uuid4 + +import pytest + +from tradingagents.models.portfolio import ( + PortfolioConfig, + PortfolioSnapshot, + CashTransaction, + TransactionType, +) +from tradingagents.models.trading import OrderSide, Fill, Position + + +class TestPortfolioConfig: + def test_default_config(self): + config = PortfolioConfig() + assert config.initial_cash == Decimal("100000") + assert config.commission_per_share == Decimal("0") + assert config.slippage_percent == Decimal("0") + + def test_custom_config(self): + config = PortfolioConfig( + initial_cash=Decimal("50000"), + commission_per_trade=Decimal("5.00"), + slippage_percent=Decimal("0.1"), + ) + assert config.initial_cash == Decimal("50000") + assert config.commission_per_trade == Decimal("5.00") + + def test_calculate_commission_flat(self): + config = PortfolioConfig(commission_per_trade=Decimal("5.00")) + commission = config.calculate_commission(100, Decimal("150.00")) + assert commission == Decimal("5.00") + + def test_calculate_commission_per_share(self): + config = PortfolioConfig(commission_per_share=Decimal("0.01")) + commission = config.calculate_commission(100, Decimal("150.00")) + assert commission == Decimal("1.00") + + def test_calculate_commission_percent(self): + config = PortfolioConfig(commission_percent=Decimal("0.1")) + commission = config.calculate_commission(100, Decimal("100.00")) + assert commission == Decimal("10.00") + + def test_calculate_commission_minimum(self): + config = PortfolioConfig( + commission_per_trade=Decimal("1.00"), + min_commission=Decimal("5.00"), + ) + commission = config.calculate_commission(10, Decimal("10.00")) + assert commission == Decimal("5.00") + + def test_calculate_commission_maximum(self): + config = PortfolioConfig( + commission_percent=Decimal("1"), + max_commission=Decimal("50.00"), + ) + commission = config.calculate_commission(1000, Decimal("100.00")) + assert commission == Decimal("50.00") + + def test_calculate_slippage_buy(self): + config = PortfolioConfig(slippage_percent=Decimal("0.1")) + price = config.calculate_slippage(Decimal("100.00"), OrderSide.BUY) + assert price == Decimal("100.10") + + def test_calculate_slippage_sell(self): + config = PortfolioConfig(slippage_percent=Decimal("0.1")) + price = config.calculate_slippage(Decimal("100.00"), OrderSide.SELL) + assert price == Decimal("99.90") + + +class TestPortfolioSnapshot: + def test_new_portfolio(self): + portfolio = PortfolioSnapshot(cash=Decimal("100000")) + assert portfolio.cash == Decimal("100000") + assert portfolio.position_count == 0 + assert len(portfolio.positions) == 0 + + def test_get_position_creates_new(self): + portfolio = PortfolioSnapshot(cash=Decimal("100000")) + position = portfolio.get_position("AAPL") + assert position.ticker == "AAPL" + assert position.quantity == 0 + assert "AAPL" in portfolio.positions + + def test_positions_value(self): + portfolio = PortfolioSnapshot( + cash=Decimal("50000"), + positions={ + "AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")), + "GOOGL": Position(ticker="GOOGL", quantity=50, avg_cost=Decimal("100")), + }, + ) + prices = {"AAPL": Decimal("160"), "GOOGL": Decimal("110")} + assert portfolio.positions_value(prices) == Decimal("21500") + + def test_total_equity(self): + portfolio = PortfolioSnapshot( + cash=Decimal("50000"), + positions={ + "AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")), + }, + ) + prices = {"AAPL": Decimal("160")} + assert portfolio.total_equity(prices) == Decimal("66000") + + def test_total_unrealized_pnl(self): + portfolio = PortfolioSnapshot( + cash=Decimal("50000"), + positions={ + "AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")), + }, + ) + prices = {"AAPL": Decimal("160")} + assert portfolio.total_unrealized_pnl(prices) == Decimal("1000") + + def test_apply_buy_fill(self): + portfolio = PortfolioSnapshot(cash=Decimal("100000")) + fill = Fill( + order_id=uuid4(), + ticker="AAPL", + side=OrderSide.BUY, + quantity=100, + price=Decimal("150.00"), + commission=Decimal("5.00"), + ) + portfolio.apply_fill(fill) + assert portfolio.cash == Decimal("84995.00") + assert portfolio.positions["AAPL"].quantity == 100 + assert portfolio.total_commission_paid == Decimal("5.00") + + def test_apply_sell_fill_with_profit(self): + portfolio = PortfolioSnapshot( + cash=Decimal("50000"), + positions={ + "AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")), + }, + ) + fill = Fill( + order_id=uuid4(), + ticker="AAPL", + side=OrderSide.SELL, + quantity=100, + price=Decimal("160.00"), + commission=Decimal("5.00"), + ) + portfolio.apply_fill(fill) + assert portfolio.cash == Decimal("65995.00") + assert portfolio.positions["AAPL"].quantity == 0 + assert portfolio.total_realized_pnl == Decimal("1000.00") + + def test_add_deposit(self): + portfolio = PortfolioSnapshot(cash=Decimal("100000")) + transaction = CashTransaction( + transaction_type=TransactionType.DEPOSIT, + amount=Decimal("10000"), + ) + portfolio.add_cash_transaction(transaction) + assert portfolio.cash == Decimal("110000") + assert len(portfolio.cash_transactions) == 1 + + def test_add_withdrawal(self): + portfolio = PortfolioSnapshot(cash=Decimal("100000")) + transaction = CashTransaction( + transaction_type=TransactionType.WITHDRAWAL, + amount=Decimal("10000"), + ) + portfolio.add_cash_transaction(transaction) + assert portfolio.cash == Decimal("90000") + + def test_can_afford(self): + portfolio = PortfolioSnapshot(cash=Decimal("10000")) + config = PortfolioConfig(commission_per_trade=Decimal("5")) + + assert portfolio.can_afford("AAPL", 10, Decimal("100"), config) + assert not portfolio.can_afford("AAPL", 100, Decimal("100"), config) + + def test_max_shares_affordable(self): + portfolio = PortfolioSnapshot(cash=Decimal("10000")) + config = PortfolioConfig(commission_per_trade=Decimal("0")) + + max_shares = portfolio.max_shares_affordable("AAPL", Decimal("100"), config) + assert max_shares == 100 + + def test_max_shares_affordable_with_commission(self): + portfolio = PortfolioSnapshot(cash=Decimal("10050")) + config = PortfolioConfig(commission_per_trade=Decimal("50")) + + max_shares = portfolio.max_shares_affordable("AAPL", Decimal("100"), config) + assert max_shares == 100 + + def test_to_dict(self): + portfolio = PortfolioSnapshot( + cash=Decimal("50000"), + positions={ + "AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")), + }, + ) + prices = {"AAPL": Decimal("160")} + result = portfolio.to_dict(prices) + + assert result["cash"] == 50000.0 + assert result["positions_value"] == 16000.0 + assert result["total_equity"] == 66000.0 + assert result["position_count"] == 1 diff --git a/tests/models/test_trading.py b/tests/models/test_trading.py new file mode 100644 index 00000000..57a52692 --- /dev/null +++ b/tests/models/test_trading.py @@ -0,0 +1,297 @@ +from datetime import datetime +from decimal import Decimal +from uuid import uuid4 + +import pytest + +from tradingagents.models.trading import ( + OrderSide, + OrderType, + OrderStatus, + PositionSide, + Order, + Fill, + Position, + Trade, +) + + +class TestOrder: + def test_market_order_creation(self): + order = Order( + ticker="AAPL", + side=OrderSide.BUY, + quantity=100, + ) + assert order.ticker == "AAPL" + assert order.side == OrderSide.BUY + assert order.order_type == OrderType.MARKET + assert order.quantity == 100 + assert order.status == OrderStatus.PENDING + assert order.remaining_quantity == 100 + assert not order.is_complete + + def test_limit_order_creation(self): + order = Order( + ticker="AAPL", + side=OrderSide.SELL, + order_type=OrderType.LIMIT, + quantity=50, + limit_price=Decimal("150.00"), + ) + assert order.order_type == OrderType.LIMIT + assert order.limit_price == Decimal("150.00") + + def test_order_partial_fill(self): + order = Order( + ticker="AAPL", + side=OrderSide.BUY, + quantity=100, + status=OrderStatus.PARTIAL, + filled_quantity=30, + ) + assert order.remaining_quantity == 70 + assert not order.is_complete + + def test_order_complete_states(self): + for status in [OrderStatus.FILLED, OrderStatus.CANCELLED, OrderStatus.REJECTED]: + order = Order( + ticker="AAPL", + side=OrderSide.BUY, + quantity=100, + status=status, + ) + assert order.is_complete + + def test_invalid_quantity(self): + with pytest.raises(ValueError): + Order(ticker="AAPL", side=OrderSide.BUY, quantity=0) + + def test_invalid_limit_price(self): + with pytest.raises(ValueError): + Order( + ticker="AAPL", + side=OrderSide.BUY, + order_type=OrderType.LIMIT, + quantity=100, + limit_price=Decimal("-10"), + ) + + +class TestFill: + def test_buy_fill(self): + order_id = uuid4() + fill = Fill( + order_id=order_id, + ticker="AAPL", + side=OrderSide.BUY, + quantity=100, + price=Decimal("150.00"), + commission=Decimal("1.00"), + ) + assert fill.total_value == Decimal("15000.00") + assert fill.total_cost == Decimal("15001.00") + + def test_sell_fill(self): + order_id = uuid4() + fill = Fill( + order_id=order_id, + ticker="AAPL", + side=OrderSide.SELL, + quantity=100, + price=Decimal("150.00"), + commission=Decimal("1.00"), + ) + assert fill.total_value == Decimal("15000.00") + assert fill.total_cost == Decimal("14999.00") + + +class TestPosition: + def test_new_position(self): + position = Position(ticker="AAPL") + assert position.quantity == 0 + assert position.side == PositionSide.FLAT + assert position.cost_basis == Decimal("0") + + def test_long_position(self): + position = Position( + ticker="AAPL", + quantity=100, + avg_cost=Decimal("150.00"), + ) + assert position.side == PositionSide.LONG + assert position.cost_basis == Decimal("15000.00") + + def test_short_position(self): + position = Position( + ticker="AAPL", + quantity=-100, + avg_cost=Decimal("150.00"), + ) + assert position.side == PositionSide.SHORT + assert position.cost_basis == Decimal("15000.00") + + def test_unrealized_pnl_long(self): + position = Position( + ticker="AAPL", + quantity=100, + avg_cost=Decimal("150.00"), + ) + pnl = position.unrealized_pnl(Decimal("160.00")) + assert pnl == Decimal("1000.00") + + def test_unrealized_pnl_short(self): + position = Position( + ticker="AAPL", + quantity=-100, + avg_cost=Decimal("150.00"), + ) + pnl = position.unrealized_pnl(Decimal("140.00")) + assert pnl == Decimal("1000.00") + + def test_update_from_buy_fill_new_position(self): + position = Position(ticker="AAPL") + fill = Fill( + order_id=uuid4(), + ticker="AAPL", + side=OrderSide.BUY, + quantity=100, + price=Decimal("150.00"), + ) + position.update_from_fill(fill) + assert position.quantity == 100 + assert position.avg_cost == Decimal("150.00") + + def test_update_from_buy_fill_add_to_position(self): + position = Position( + ticker="AAPL", + quantity=100, + avg_cost=Decimal("150.00"), + ) + fill = Fill( + order_id=uuid4(), + ticker="AAPL", + side=OrderSide.BUY, + quantity=100, + price=Decimal("160.00"), + ) + position.update_from_fill(fill) + assert position.quantity == 200 + assert position.avg_cost == Decimal("155.00") + + def test_update_from_sell_fill_close_position(self): + position = Position( + ticker="AAPL", + quantity=100, + avg_cost=Decimal("150.00"), + ) + fill = Fill( + order_id=uuid4(), + ticker="AAPL", + side=OrderSide.SELL, + quantity=100, + price=Decimal("160.00"), + ) + position.update_from_fill(fill) + assert position.quantity == 0 + assert position.realized_pnl == Decimal("1000.00") + + def test_update_from_sell_fill_partial_close(self): + position = Position( + ticker="AAPL", + quantity=100, + avg_cost=Decimal("150.00"), + ) + fill = Fill( + order_id=uuid4(), + ticker="AAPL", + side=OrderSide.SELL, + quantity=50, + price=Decimal("160.00"), + ) + position.update_from_fill(fill) + assert position.quantity == 50 + assert position.realized_pnl == Decimal("500.00") + assert position.avg_cost == Decimal("150.00") + + +class TestTrade: + def test_open_trade(self): + trade = Trade( + ticker="AAPL", + side=OrderSide.BUY, + entry_price=Decimal("150.00"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15, 9, 30), + ) + assert not trade.is_closed + assert trade.pnl is None + assert trade.holding_period is None + + def test_closed_trade_profit(self): + trade = Trade( + ticker="AAPL", + side=OrderSide.BUY, + entry_price=Decimal("150.00"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15), + exit_price=Decimal("160.00"), + exit_quantity=100, + exit_time=datetime(2024, 1, 25), + ) + assert trade.is_closed + assert trade.pnl == Decimal("1000.00") + assert trade.holding_period == 10 + + def test_closed_trade_loss(self): + trade = Trade( + ticker="AAPL", + side=OrderSide.BUY, + entry_price=Decimal("150.00"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15), + exit_price=Decimal("140.00"), + exit_quantity=100, + exit_time=datetime(2024, 1, 20), + ) + assert trade.pnl == Decimal("-1000.00") + + def test_trade_with_commission(self): + trade = Trade( + ticker="AAPL", + side=OrderSide.BUY, + entry_price=Decimal("150.00"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15), + exit_price=Decimal("160.00"), + exit_quantity=100, + exit_time=datetime(2024, 1, 20), + commission=Decimal("10.00"), + ) + assert trade.pnl == Decimal("990.00") + + def test_short_trade_profit(self): + trade = Trade( + ticker="AAPL", + side=OrderSide.SELL, + entry_price=Decimal("150.00"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15), + exit_price=Decimal("140.00"), + exit_quantity=100, + exit_time=datetime(2024, 1, 20), + ) + assert trade.pnl == Decimal("1000.00") + + def test_pnl_percent(self): + trade = Trade( + ticker="AAPL", + side=OrderSide.BUY, + entry_price=Decimal("100.00"), + entry_quantity=100, + entry_time=datetime(2024, 1, 15), + exit_price=Decimal("110.00"), + exit_quantity=100, + exit_time=datetime(2024, 1, 20), + ) + assert trade.pnl_percent == Decimal("10") diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index d84d9eb1..455334e1 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -10,7 +10,7 @@ from .analysts.social_media_analyst import create_social_media_analyst from .researchers.bear_researcher import create_bear_researcher from .researchers.bull_researcher import create_bull_researcher -from .risk_mgmt.aggresive_debator import create_risky_debator +from .risk_mgmt.aggressive_debator import create_risky_debator from .risk_mgmt.conservative_debator import create_safe_debator from .risk_mgmt.neutral_debator import create_neutral_debator diff --git a/tradingagents/agents/risk_mgmt/aggresive_debator.py b/tradingagents/agents/risk_mgmt/aggressive_debator.py similarity index 100% rename from tradingagents/agents/risk_mgmt/aggresive_debator.py rename to tradingagents/agents/risk_mgmt/aggressive_debator.py diff --git a/tradingagents/backtesting/__init__.py b/tradingagents/backtesting/__init__.py new file mode 100644 index 00000000..bef3308c --- /dev/null +++ b/tradingagents/backtesting/__init__.py @@ -0,0 +1,13 @@ +from .data_loader import DataLoader +from .engine import BacktestEngine, SimpleBacktestEngine +from .metrics import MetricsCalculator +from .agent_integration import AgentBacktestEngine, run_agent_backtest + +__all__ = [ + "DataLoader", + "BacktestEngine", + "SimpleBacktestEngine", + "MetricsCalculator", + "AgentBacktestEngine", + "run_agent_backtest", +] diff --git a/tradingagents/backtesting/agent_integration.py b/tradingagents/backtesting/agent_integration.py new file mode 100644 index 00000000..05cab9ae --- /dev/null +++ b/tradingagents/backtesting/agent_integration.py @@ -0,0 +1,249 @@ +import logging +import re +from datetime import date, datetime +from decimal import Decimal +from typing import Optional, Dict, Any + +from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.models.backtest import BacktestConfig, BacktestResult +from tradingagents.models.decisions import ( + SignalType, + TradingDecision, + AnalystReport, + AnalystType, +) +from tradingagents.models.portfolio import PortfolioSnapshot + +from .engine import BacktestEngine +from .data_loader import DataLoader + +logger = logging.getLogger(__name__) + + +class AgentBacktestEngine(BacktestEngine): + def __init__( + self, + config: BacktestConfig, + agent_config: Optional[Dict[str, Any]] = None, + ): + super().__init__(config) + self.agent_config = agent_config or config.agent_config + self.trading_graph: Optional[TradingAgentsGraph] = None + self._decision_cache: Dict[str, TradingDecision] = {} + + def _initialize(self): + super()._initialize() + + graph_config = { + **self.agent_config, + } + + self.trading_graph = TradingAgentsGraph( + selected_analysts=self.agent_config.get( + "selected_analysts", + ["market", "social", "news", "fundamentals"], + ), + debug=self.agent_config.get("debug", False), + config=graph_config if graph_config else None, + ) + + def _get_decision( + self, + ticker: str, + trading_date: date, + day_index: int, + ) -> Optional[TradingDecision]: + cache_key = f"{ticker}_{trading_date}" + if cache_key in self._decision_cache: + return self._decision_cache[cache_key] + + try: + final_state, signal_info = self.trading_graph.propagate( + ticker, trading_date + ) + + decision = self._parse_agent_decision( + ticker, trading_date, final_state, signal_info + ) + + self._decision_cache[cache_key] = decision + return decision + + except (ValueError, KeyError, RuntimeError, ConnectionError, TimeoutError) as e: + logger.error( + "Agent decision failed for %s on %s: %s", + ticker, trading_date, e + ) + return None + + def _parse_agent_decision( + self, + ticker: str, + trading_date: date, + final_state: Dict[str, Any], + signal_info: Dict[str, Any], + ) -> TradingDecision: + signal = self._extract_signal(signal_info) + confidence = self._extract_confidence(signal_info) + + analyst_reports = [] + + if final_state.get("market_report"): + analyst_reports.append( + AnalystReport( + analyst_type=AnalystType.MARKET, + ticker=ticker, + report_date=datetime.combine(trading_date, datetime.min.time()), + summary=final_state["market_report"][:500], + raw_content=final_state["market_report"], + ) + ) + + if final_state.get("sentiment_report"): + analyst_reports.append( + AnalystReport( + analyst_type=AnalystType.SENTIMENT, + ticker=ticker, + report_date=datetime.combine(trading_date, datetime.min.time()), + summary=final_state["sentiment_report"][:500], + raw_content=final_state["sentiment_report"], + ) + ) + + if final_state.get("news_report"): + analyst_reports.append( + AnalystReport( + analyst_type=AnalystType.NEWS, + ticker=ticker, + report_date=datetime.combine(trading_date, datetime.min.time()), + summary=final_state["news_report"][:500], + raw_content=final_state["news_report"], + ) + ) + + if final_state.get("fundamentals_report"): + analyst_reports.append( + AnalystReport( + analyst_type=AnalystType.FUNDAMENTALS, + ticker=ticker, + report_date=datetime.combine(trading_date, datetime.min.time()), + summary=final_state["fundamentals_report"][:500], + raw_content=final_state["fundamentals_report"], + ) + ) + + debate_state = final_state.get("investment_debate_state", {}) + bull_argument = None + bear_argument = None + + if debate_state.get("bull_history"): + bull_argument = debate_state["bull_history"][-1] if debate_state["bull_history"] else None + if debate_state.get("bear_history"): + bear_argument = debate_state["bear_history"][-1] if debate_state["bear_history"] else None + + risk_state = final_state.get("risk_debate_state", {}) + risk_approved = self._extract_risk_approval(risk_state) + + final_decision_text = final_state.get("final_trade_decision", "") + recommended_action = self._extract_action(signal_info, final_decision_text) + + return TradingDecision( + ticker=ticker, + timestamp=datetime.now(), + decision_date=datetime.combine(trading_date, datetime.min.time()), + signal=signal, + confidence=confidence, + recommended_action=recommended_action, + analyst_reports=analyst_reports, + bull_argument=bull_argument, + bear_argument=bear_argument, + debate_rounds=debate_state.get("count", 0), + risk_manager_approved=risk_approved, + final_decision=recommended_action, + rationale=final_decision_text[:1000] if final_decision_text else "", + ) + + def _extract_signal(self, signal_info: Dict[str, Any]) -> SignalType: + action = signal_info.get("action", "").upper() + direction = signal_info.get("direction", "").upper() + + if action == "BUY" or direction == "BULLISH": + confidence = signal_info.get("confidence", 0.5) + if confidence > 0.8: + return SignalType.STRONG_BUY + return SignalType.BUY + + elif action == "SELL" or direction == "BEARISH": + confidence = signal_info.get("confidence", 0.5) + if confidence > 0.8: + return SignalType.STRONG_SELL + return SignalType.SELL + + return SignalType.HOLD + + def _extract_confidence(self, signal_info: Dict[str, Any]) -> Decimal: + confidence = signal_info.get("confidence", 0.5) + if isinstance(confidence, str): + try: + confidence = float(confidence.replace("%", "")) / 100 + except ValueError: + confidence = 0.5 + + return Decimal(str(min(max(float(confidence), 0.0), 1.0))) + + def _extract_action( + self, + signal_info: Dict[str, Any], + final_decision_text: str, + ) -> str: + action = signal_info.get("action", "") + if action: + return action.upper() + + text_upper = final_decision_text.upper() + if "BUY" in text_upper and "DON'T BUY" not in text_upper: + return "BUY" + elif "SELL" in text_upper: + return "SELL" + + return "HOLD" + + def _extract_risk_approval(self, risk_state: Dict[str, Any]) -> Optional[bool]: + judge_decision = risk_state.get("judge_decision", "") + if not judge_decision: + return None + + text_upper = judge_decision.upper() + if "APPROVE" in text_upper or "ACCEPT" in text_upper: + return True + elif "REJECT" in text_upper or "DENY" in text_upper: + return False + + return None + + +def run_agent_backtest( + tickers: list[str], + start_date: date, + end_date: date, + initial_cash: Decimal = Decimal("100000"), + agent_config: Optional[Dict[str, Any]] = None, +) -> BacktestResult: + from tradingagents.models.portfolio import PortfolioConfig + + config = BacktestConfig( + name=f"Agent Backtest - {', '.join(tickers)}", + tickers=tickers, + start_date=start_date, + end_date=end_date, + portfolio_config=PortfolioConfig( + initial_cash=initial_cash, + commission_per_trade=Decimal("1"), + slippage_percent=Decimal("0.05"), + ), + warmup_period=5, + agent_config=agent_config or {}, + ) + + engine = AgentBacktestEngine(config, agent_config) + return engine.run() diff --git a/tradingagents/backtesting/data_loader.py b/tradingagents/backtesting/data_loader.py new file mode 100644 index 00000000..6bc79027 --- /dev/null +++ b/tradingagents/backtesting/data_loader.py @@ -0,0 +1,244 @@ +import logging +from datetime import date, datetime, timedelta +from decimal import Decimal +from typing import Optional + +import pandas as pd +import yfinance as yf +from stockstats import wrap + +from tradingagents.models.market_data import ( + OHLCV, + OHLCVBar, + TechnicalIndicators, + HistoricalDataRequest, + HistoricalDataResponse, +) + +logger = logging.getLogger(__name__) + + +class DataLoader: + def __init__(self, cache_dir: Optional[str] = None): + self.cache_dir = cache_dir + self._cache: dict[str, pd.DataFrame] = {} + + def load_ohlcv( + self, + ticker: str, + start_date: date, + end_date: date, + interval: str = "1d", + ) -> OHLCV: + ticker = ticker.upper() + cache_key = f"{ticker}_{start_date}_{end_date}_{interval}" + + if cache_key in self._cache: + df = self._cache[cache_key] + else: + df = self._fetch_from_yfinance(ticker, start_date, end_date, interval) + self._cache[cache_key] = df + + bars = self._dataframe_to_bars(df) + return OHLCV(ticker=ticker, bars=bars, interval=interval) + + def load_historical_data( + self, + request: HistoricalDataRequest, + ) -> HistoricalDataResponse: + ohlcv = self.load_ohlcv( + request.ticker, + request.start_date, + request.end_date, + request.interval, + ) + + indicators = [] + if request.include_indicators and ohlcv.bars: + indicators = self._calculate_indicators( + request.ticker, + request.start_date, + request.end_date, + ) + + return HistoricalDataResponse( + request=request, + ohlcv=ohlcv, + indicators=indicators, + source="yfinance", + ) + + def _fetch_from_yfinance( + self, + ticker: str, + start_date: date, + end_date: date, + interval: str, + ) -> pd.DataFrame: + start_str = start_date.strftime("%Y-%m-%d") + end_str = (end_date + timedelta(days=1)).strftime("%Y-%m-%d") + + df = yf.download( + ticker, + start=start_str, + end=end_str, + interval=interval, + multi_level_index=False, + progress=False, + auto_adjust=False, + ) + + if df.empty: + logger.warning("No data returned for %s from %s to %s", ticker, start_date, end_date) + return pd.DataFrame() + + df = df.reset_index() + return df + + def _dataframe_to_bars(self, df: pd.DataFrame) -> list[OHLCVBar]: + if df.empty: + return [] + + bars = [] + for _, row in df.iterrows(): + timestamp = row.get("Date") or row.get("Datetime") + if isinstance(timestamp, str): + timestamp = pd.to_datetime(timestamp) + if hasattr(timestamp, "to_pydatetime"): + timestamp = timestamp.to_pydatetime() + if timestamp.tzinfo is not None: + timestamp = timestamp.replace(tzinfo=None) + + bar = OHLCVBar( + timestamp=timestamp, + open=Decimal(str(round(row["Open"], 4))), + high=Decimal(str(round(row["High"], 4))), + low=Decimal(str(round(row["Low"], 4))), + close=Decimal(str(round(row["Close"], 4))), + volume=int(row["Volume"]), + adjusted_close=Decimal(str(round(row["Adj Close"], 4))) if "Adj Close" in row else None, + ) + bars.append(bar) + + return bars + + def _calculate_indicators( + self, + ticker: str, + start_date: date, + end_date: date, + ) -> list[TechnicalIndicators]: + lookback_start = start_date - timedelta(days=250) + cache_key = f"{ticker}_{lookback_start}_{end_date}_1d" + + if cache_key in self._cache: + df = self._cache[cache_key] + else: + df = self._fetch_from_yfinance(ticker, lookback_start, end_date, "1d") + self._cache[cache_key] = df + + if df.empty: + return [] + + stock = wrap(df.copy()) + + stock["close_20_sma"] + stock["close_50_sma"] + stock["close_200_sma"] + stock["close_10_ema"] + stock["close_20_ema"] + stock["rsi_14"] + stock["macd"] + stock["macds"] + stock["macdh"] + stock["boll"] + stock["boll_ub"] + stock["boll_lb"] + stock["atr_14"] + stock["mfi_14"] + + indicators = [] + for _, row in stock.iterrows(): + timestamp = row.get("Date") or row.get("Datetime") + if isinstance(timestamp, str): + timestamp = pd.to_datetime(timestamp) + if hasattr(timestamp, "to_pydatetime"): + timestamp = timestamp.to_pydatetime() + if timestamp.tzinfo is not None: + timestamp = timestamp.replace(tzinfo=None) + + if timestamp.date() < start_date or timestamp.date() > end_date: + continue + + ind = TechnicalIndicators( + timestamp=timestamp, + ticker=ticker, + sma_20=self._safe_decimal(row.get("close_20_sma")), + sma_50=self._safe_decimal(row.get("close_50_sma")), + sma_200=self._safe_decimal(row.get("close_200_sma")), + ema_10=self._safe_decimal(row.get("close_10_ema")), + ema_20=self._safe_decimal(row.get("close_20_ema")), + rsi_14=self._safe_decimal(row.get("rsi_14")), + macd=self._safe_decimal(row.get("macd")), + macd_signal=self._safe_decimal(row.get("macds")), + macd_histogram=self._safe_decimal(row.get("macdh")), + bollinger_middle=self._safe_decimal(row.get("boll")), + bollinger_upper=self._safe_decimal(row.get("boll_ub")), + bollinger_lower=self._safe_decimal(row.get("boll_lb")), + atr_14=self._safe_decimal(row.get("atr_14")), + mfi_14=self._safe_decimal(row.get("mfi_14")), + ) + indicators.append(ind) + + return indicators + + @staticmethod + def _safe_decimal(value) -> Optional[Decimal]: + if value is None or pd.isna(value): + return None + return Decimal(str(round(float(value), 4))) + + def get_price_on_date( + self, + ticker: str, + target_date: date, + ohlcv: Optional[OHLCV] = None, + ) -> Optional[Decimal]: + if ohlcv is None: + ohlcv = self.load_ohlcv(ticker, target_date - timedelta(days=5), target_date) + + target_datetime = datetime.combine(target_date, datetime.min.time()) + bar = ohlcv.get_bar(target_datetime) + + if bar: + return bar.close + + for b in reversed(ohlcv.bars): + if b.timestamp.date() <= target_date: + return b.close + + return None + + def get_prices_dict( + self, + tickers: list[str], + target_date: date, + ) -> dict[str, Decimal]: + prices = {} + for ticker in tickers: + price = self.get_price_on_date(ticker, target_date) + if price is not None: + prices[ticker] = price + return prices + + def get_trading_days( + self, + ticker: str, + start_date: date, + end_date: date, + ) -> list[date]: + ohlcv = self.load_ohlcv(ticker, start_date, end_date) + return [bar.timestamp.date() for bar in ohlcv.bars] + + def clear_cache(self): + self._cache.clear() diff --git a/tradingagents/backtesting/engine.py b/tradingagents/backtesting/engine.py new file mode 100644 index 00000000..ae9e75d8 --- /dev/null +++ b/tradingagents/backtesting/engine.py @@ -0,0 +1,366 @@ +import logging +from datetime import date, datetime, timedelta +from decimal import Decimal +from typing import Optional, Callable +from uuid import uuid4 + +from tradingagents.models.backtest import ( + BacktestConfig, + BacktestResult, + EquityCurvePoint, + TradeLog, +) +from tradingagents.models.decisions import SignalType, TradingDecision +from tradingagents.models.portfolio import PortfolioSnapshot +from tradingagents.models.trading import Order, OrderSide, OrderStatus, Fill, Trade + +from .data_loader import DataLoader +from .metrics import MetricsCalculator + +logger = logging.getLogger(__name__) + + +class BacktestEngine: + def __init__( + self, + config: BacktestConfig, + decision_callback: Optional[Callable[[str, date, dict], TradingDecision]] = None, + ): + self.config = config + self.decision_callback = decision_callback + self.data_loader = DataLoader() + self.metrics_calculator = MetricsCalculator(config.risk_free_rate) + + self.portfolio: Optional[PortfolioSnapshot] = None + self.trade_log: Optional[TradeLog] = None + self.equity_curve: list[EquityCurvePoint] = [] + self.daily_returns: list[Decimal] = [] + self.decisions: list[TradingDecision] = [] + self.open_trades: dict[str, Trade] = {} + + def run(self) -> BacktestResult: + started_at = datetime.now() + + try: + self._initialize() + self._preload_data() + trading_days = self._get_trading_days() + + for i, trading_date in enumerate(trading_days): + if i < self.config.warmup_period: + continue + + self._process_day(trading_date, i) + + self._close_all_positions(trading_days[-1] if trading_days else self.config.end_date) + + metrics = self.metrics_calculator.calculate_metrics( + self.equity_curve, + self.trade_log, + ) + + completed_at = datetime.now() + + return BacktestResult( + config=self.config, + metrics=metrics, + trade_log=self.trade_log, + equity_curve=self.equity_curve, + daily_returns=self.daily_returns, + started_at=started_at, + completed_at=completed_at, + status="completed", + ) + + except (ValueError, KeyError, RuntimeError, FileNotFoundError, OSError) as e: + logger.exception("Backtest failed: %s", e) + completed_at = datetime.now() + + return BacktestResult( + config=self.config, + metrics=self._empty_metrics(), + trade_log=self.trade_log or TradeLog(), + equity_curve=self.equity_curve, + daily_returns=self.daily_returns, + started_at=started_at, + completed_at=completed_at, + status="failed", + error_message=str(e), + ) + + def _initialize(self): + self.portfolio = PortfolioSnapshot( + cash=self.config.portfolio_config.initial_cash, + ) + self.trade_log = TradeLog() + self.equity_curve = [] + self.daily_returns = [] + self.decisions = [] + self.open_trades = {} + + def _preload_data(self): + logger.info("Preloading data for %s tickers", len(self.config.tickers)) + for ticker in self.config.tickers: + self.data_loader.load_ohlcv( + ticker, + self.config.start_date - timedelta(days=self.config.warmup_period + 10), + self.config.end_date, + ) + + def _get_trading_days(self) -> list[date]: + primary_ticker = self.config.tickers[0] + return self.data_loader.get_trading_days( + primary_ticker, + self.config.start_date, + self.config.end_date, + ) + + def _process_day(self, trading_date: date, day_index: int): + prices = self.data_loader.get_prices_dict(self.config.tickers, trading_date) + + if not prices: + logger.debug("No prices available for %s", trading_date) + return + + for ticker in self.config.tickers: + if ticker not in prices: + continue + + decision = self._get_decision(ticker, trading_date, day_index) + if decision: + self.decisions.append(decision) + self._execute_decision(decision, prices[ticker], trading_date) + + self._record_equity(trading_date, prices) + + def _get_decision( + self, + ticker: str, + trading_date: date, + day_index: int, + ) -> Optional[TradingDecision]: + if self.decision_callback: + context = { + "day_index": day_index, + "portfolio": self.portfolio, + "open_trade": self.open_trades.get(ticker), + } + return self.decision_callback(ticker, trading_date, context) + + return self._simple_strategy(ticker, trading_date) + + def _simple_strategy( + self, + ticker: str, + trading_date: date, + ) -> Optional[TradingDecision]: + return None + + def _execute_decision( + self, + decision: TradingDecision, + price: Decimal, + trading_date: date, + ): + ticker = decision.ticker + config = self.config.portfolio_config + position = self.portfolio.get_position(ticker) + + if decision.is_buy and position.quantity == 0: + execution_price = config.calculate_slippage(price, OrderSide.BUY) + + if decision.recommended_quantity: + quantity = decision.recommended_quantity + else: + max_position_value = self.portfolio.cash * (config.max_position_size_percent / 100) + quantity = int(max_position_value / execution_price) + + if quantity <= 0: + return + + if not self.portfolio.can_afford(ticker, quantity, execution_price, config): + quantity = self.portfolio.max_shares_affordable(ticker, execution_price, config) + + if quantity <= 0: + return + + commission = config.calculate_commission(quantity, execution_price) + + order = Order( + ticker=ticker, + side=OrderSide.BUY, + quantity=quantity, + status=OrderStatus.FILLED, + filled_quantity=quantity, + filled_avg_price=execution_price, + filled_at=datetime.combine(trading_date, datetime.min.time()), + commission=commission, + ) + + fill = Fill( + order_id=order.id, + ticker=ticker, + side=OrderSide.BUY, + quantity=quantity, + price=execution_price, + commission=commission, + timestamp=datetime.combine(trading_date, datetime.min.time()), + ) + + self.portfolio.apply_fill(fill) + + trade = Trade( + ticker=ticker, + side=OrderSide.BUY, + entry_price=execution_price, + entry_quantity=quantity, + entry_time=datetime.combine(trading_date, datetime.min.time()), + entry_order_id=order.id, + ) + self.open_trades[ticker] = trade + + logger.debug( + "BUY %s: %d shares @ $%.2f on %s", + ticker, quantity, execution_price, trading_date + ) + + elif decision.is_sell and position.quantity > 0: + execution_price = config.calculate_slippage(price, OrderSide.SELL) + quantity = position.quantity + commission = config.calculate_commission(quantity, execution_price) + + order = Order( + ticker=ticker, + side=OrderSide.SELL, + quantity=quantity, + status=OrderStatus.FILLED, + filled_quantity=quantity, + filled_avg_price=execution_price, + filled_at=datetime.combine(trading_date, datetime.min.time()), + commission=commission, + ) + + fill = Fill( + order_id=order.id, + ticker=ticker, + side=OrderSide.SELL, + quantity=quantity, + price=execution_price, + commission=commission, + timestamp=datetime.combine(trading_date, datetime.min.time()), + ) + + self.portfolio.apply_fill(fill) + + if ticker in self.open_trades: + trade = self.open_trades[ticker] + trade.exit_price = execution_price + trade.exit_quantity = quantity + trade.exit_time = datetime.combine(trading_date, datetime.min.time()) + trade.exit_order_id = order.id + trade.commission = ( + config.calculate_commission(trade.entry_quantity, trade.entry_price) + + commission + ) + self.trade_log.add_trade(trade) + del self.open_trades[ticker] + + logger.debug( + "SELL %s: %d shares @ $%.2f on %s", + ticker, quantity, execution_price, trading_date + ) + + def _record_equity(self, trading_date: date, prices: dict[str, Decimal]): + equity = self.portfolio.total_equity(prices) + positions_value = self.portfolio.positions_value(prices) + + point = EquityCurvePoint( + timestamp=datetime.combine(trading_date, datetime.min.time()), + equity=equity, + cash=self.portfolio.cash, + positions_value=positions_value, + ) + self.equity_curve.append(point) + + if len(self.equity_curve) > 1: + prev_equity = self.equity_curve[-2].equity + if prev_equity > 0: + daily_return = (equity - prev_equity) / prev_equity + self.daily_returns.append(daily_return) + + def _close_all_positions(self, final_date: date): + prices = self.data_loader.get_prices_dict(self.config.tickers, final_date) + + for ticker, trade in list(self.open_trades.items()): + if ticker in prices: + decision = TradingDecision( + ticker=ticker, + timestamp=datetime.now(), + decision_date=datetime.combine(final_date, datetime.min.time()), + signal=SignalType.SELL, + confidence=Decimal("1.0"), + recommended_action="SELL", + final_decision="SELL - End of backtest", + rationale="Closing position at end of backtest period", + ) + self._execute_decision(decision, prices[ticker], final_date) + + def _empty_metrics(self): + from tradingagents.models.backtest import BacktestMetrics + return BacktestMetrics( + start_equity=self.config.portfolio_config.initial_cash, + end_equity=self.portfolio.cash if self.portfolio else self.config.portfolio_config.initial_cash, + ) + + +class SimpleBacktestEngine(BacktestEngine): + def __init__( + self, + config: BacktestConfig, + buy_signal: Callable[[str, date, dict], bool] = None, + sell_signal: Callable[[str, date, dict], bool] = None, + ): + super().__init__(config) + self.buy_signal = buy_signal + self.sell_signal = sell_signal + + def _get_decision( + self, + ticker: str, + trading_date: date, + day_index: int, + ) -> Optional[TradingDecision]: + context = { + "day_index": day_index, + "portfolio": self.portfolio, + "data_loader": self.data_loader, + "open_trade": self.open_trades.get(ticker), + } + + position = self.portfolio.get_position(ticker) + + if position.quantity == 0 and self.buy_signal and self.buy_signal(ticker, trading_date, context): + return TradingDecision( + ticker=ticker, + timestamp=datetime.now(), + decision_date=datetime.combine(trading_date, datetime.min.time()), + signal=SignalType.BUY, + confidence=Decimal("0.7"), + recommended_action="BUY", + final_decision="BUY", + rationale="Buy signal triggered", + ) + + if position.quantity > 0 and self.sell_signal and self.sell_signal(ticker, trading_date, context): + return TradingDecision( + ticker=ticker, + timestamp=datetime.now(), + decision_date=datetime.combine(trading_date, datetime.min.time()), + signal=SignalType.SELL, + confidence=Decimal("0.7"), + recommended_action="SELL", + final_decision="SELL", + rationale="Sell signal triggered", + ) + + return None diff --git a/tradingagents/backtesting/metrics.py b/tradingagents/backtesting/metrics.py new file mode 100644 index 00000000..05ba5181 --- /dev/null +++ b/tradingagents/backtesting/metrics.py @@ -0,0 +1,281 @@ +import math +from decimal import Decimal +from typing import Optional + +from tradingagents.models.backtest import BacktestMetrics, EquityCurvePoint, TradeLog + + +class MetricsCalculator: + TRADING_DAYS_PER_YEAR = 252 + + def __init__(self, risk_free_rate: Decimal = Decimal("0.05")): + self.risk_free_rate = risk_free_rate + + def calculate_metrics( + self, + equity_curve: list[EquityCurvePoint], + trade_log: TradeLog, + benchmark_curve: Optional[list[EquityCurvePoint]] = None, + ) -> BacktestMetrics: + if not equity_curve: + raise ValueError("Equity curve cannot be empty") + + start_equity = equity_curve[0].equity + end_equity = equity_curve[-1].equity + trading_days = len(equity_curve) + + total_return = end_equity - start_equity + total_return_percent = (total_return / start_equity) * 100 + + years = Decimal(trading_days) / Decimal(self.TRADING_DAYS_PER_YEAR) + if years > 0: + annualized_return = ((end_equity / start_equity) ** (1 / years) - 1) * 100 + else: + annualized_return = Decimal("0") + + daily_returns = self._calculate_daily_returns(equity_curve) + volatility = self._calculate_volatility(daily_returns) + annualized_volatility = volatility * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR)) + + downside_returns = [r for r in daily_returns if r < 0] + downside_volatility = self._calculate_volatility(downside_returns) if downside_returns else Decimal("0") + annualized_downside_vol = downside_volatility * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR)) + + max_dd, max_dd_pct, max_dd_duration, avg_dd = self._calculate_drawdown_metrics(equity_curve) + + sharpe = self._calculate_sharpe_ratio(annualized_return, annualized_volatility) + sortino = self._calculate_sortino_ratio(annualized_return, annualized_downside_vol) + calmar = self._calculate_calmar_ratio(annualized_return, max_dd_pct) + + benchmark_return = None + benchmark_return_percent = None + alpha = None + beta = None + information_ratio = None + + if benchmark_curve and len(benchmark_curve) == len(equity_curve): + benchmark_return = benchmark_curve[-1].equity - benchmark_curve[0].equity + benchmark_return_percent = (benchmark_return / benchmark_curve[0].equity) * 100 + + benchmark_daily = self._calculate_daily_returns(benchmark_curve) + alpha, beta = self._calculate_alpha_beta(daily_returns, benchmark_daily) + information_ratio = self._calculate_information_ratio( + daily_returns, benchmark_daily + ) + + all_pnls = [t.pnl for t in trade_log.trades if t.is_closed and t.pnl is not None] + avg_trade_pnl = sum(all_pnls) / len(all_pnls) if all_pnls else None + largest_win = max((p for p in all_pnls if p > 0), default=None) + largest_loss = min((p for p in all_pnls if p < 0), default=None) + + return BacktestMetrics( + total_return=total_return, + total_return_percent=total_return_percent, + annualized_return=annualized_return, + benchmark_return=benchmark_return, + benchmark_return_percent=benchmark_return_percent, + alpha=alpha, + beta=beta, + volatility=volatility * 100, + annualized_volatility=annualized_volatility * 100, + downside_volatility=annualized_downside_vol * 100, + sharpe_ratio=sharpe, + sortino_ratio=sortino, + calmar_ratio=calmar, + information_ratio=information_ratio, + max_drawdown=max_dd, + max_drawdown_percent=max_dd_pct, + max_drawdown_duration=max_dd_duration, + avg_drawdown=avg_dd, + total_trades=trade_log.total_trades, + win_rate=trade_log.win_rate, + profit_factor=trade_log.profit_factor, + avg_trade_pnl=avg_trade_pnl, + avg_win=trade_log.avg_win, + avg_loss=trade_log.avg_loss, + largest_win=largest_win, + largest_loss=largest_loss, + avg_holding_period_days=trade_log.avg_holding_period, + trading_days=trading_days, + start_equity=start_equity, + end_equity=end_equity, + ) + + def _calculate_daily_returns( + self, + equity_curve: list[EquityCurvePoint], + ) -> list[Decimal]: + returns = [] + for i in range(1, len(equity_curve)): + prev_equity = equity_curve[i - 1].equity + curr_equity = equity_curve[i].equity + if prev_equity > 0: + daily_return = (curr_equity - prev_equity) / prev_equity + returns.append(daily_return) + return returns + + def _calculate_volatility(self, returns: list[Decimal]) -> Decimal: + if len(returns) < 2: + return Decimal("0") + + mean = sum(returns) / len(returns) + variance = sum((r - mean) ** 2 for r in returns) / (len(returns) - 1) + return Decimal(str(math.sqrt(float(variance)))) + + def _calculate_drawdown_metrics( + self, + equity_curve: list[EquityCurvePoint], + ) -> tuple[Decimal, Decimal, Optional[int], Decimal]: + if not equity_curve: + return Decimal("0"), Decimal("0"), None, Decimal("0") + + peak = equity_curve[0].equity + max_drawdown = Decimal("0") + max_drawdown_percent = Decimal("0") + drawdown_start = 0 + max_drawdown_duration = 0 + current_drawdown_start = 0 + in_drawdown = False + drawdowns = [] + + for i, point in enumerate(equity_curve): + equity = point.equity + + if equity > peak: + if in_drawdown: + duration = i - current_drawdown_start + max_drawdown_duration = max(max_drawdown_duration, duration) + in_drawdown = False + peak = equity + current_drawdown_start = i + else: + if not in_drawdown: + in_drawdown = True + current_drawdown_start = i + + drawdown = peak - equity + drawdown_pct = (drawdown / peak) * 100 if peak > 0 else Decimal("0") + drawdowns.append(drawdown_pct) + + if drawdown > max_drawdown: + max_drawdown = drawdown + max_drawdown_percent = drawdown_pct + + point.drawdown = drawdown + point.drawdown_percent = drawdown_pct + + if in_drawdown: + duration = len(equity_curve) - current_drawdown_start + max_drawdown_duration = max(max_drawdown_duration, duration) + + avg_drawdown = sum(drawdowns) / len(drawdowns) if drawdowns else Decimal("0") + + return max_drawdown, max_drawdown_percent, max_drawdown_duration or None, avg_drawdown + + def _calculate_sharpe_ratio( + self, + annualized_return: Decimal, + annualized_volatility: Decimal, + ) -> Optional[Decimal]: + if annualized_volatility == 0: + return None + + excess_return = annualized_return - (self.risk_free_rate * 100) + return excess_return / annualized_volatility + + def _calculate_sortino_ratio( + self, + annualized_return: Decimal, + annualized_downside_vol: Decimal, + ) -> Optional[Decimal]: + if annualized_downside_vol == 0: + return None + + excess_return = annualized_return - (self.risk_free_rate * 100) + return excess_return / annualized_downside_vol + + def _calculate_calmar_ratio( + self, + annualized_return: Decimal, + max_drawdown_percent: Decimal, + ) -> Optional[Decimal]: + if max_drawdown_percent == 0: + return None + + return annualized_return / max_drawdown_percent + + def _calculate_alpha_beta( + self, + returns: list[Decimal], + benchmark_returns: list[Decimal], + ) -> tuple[Optional[Decimal], Optional[Decimal]]: + if len(returns) != len(benchmark_returns) or len(returns) < 2: + return None, None + + n = len(returns) + sum_x = sum(benchmark_returns) + sum_y = sum(returns) + sum_xy = sum(r * b for r, b in zip(returns, benchmark_returns)) + sum_xx = sum(b * b for b in benchmark_returns) + + denominator = n * sum_xx - sum_x * sum_x + if denominator == 0: + return None, None + + beta = (n * sum_xy - sum_x * sum_y) / denominator + alpha = (sum_y - beta * sum_x) / n + + alpha_annualized = alpha * self.TRADING_DAYS_PER_YEAR + + return alpha_annualized, beta + + def _calculate_information_ratio( + self, + returns: list[Decimal], + benchmark_returns: list[Decimal], + ) -> Optional[Decimal]: + if len(returns) != len(benchmark_returns) or len(returns) < 2: + return None + + excess_returns = [r - b for r, b in zip(returns, benchmark_returns)] + mean_excess = sum(excess_returns) / len(excess_returns) + tracking_error = self._calculate_volatility(excess_returns) + + if tracking_error == 0: + return None + + annualized_tracking_error = tracking_error * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR)) + annualized_excess = mean_excess * self.TRADING_DAYS_PER_YEAR + + return annualized_excess / annualized_tracking_error + + def calculate_rolling_metrics( + self, + equity_curve: list[EquityCurvePoint], + window: int = 20, + ) -> dict[str, list[Decimal]]: + if len(equity_curve) < window: + return {"rolling_sharpe": [], "rolling_volatility": []} + + rolling_sharpe = [] + rolling_volatility = [] + + daily_returns = self._calculate_daily_returns(equity_curve) + + for i in range(window - 1, len(daily_returns)): + window_returns = daily_returns[i - window + 1:i + 1] + vol = self._calculate_volatility(window_returns) + annualized_vol = vol * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR)) + + mean_return = sum(window_returns) / len(window_returns) + annualized_return = mean_return * self.TRADING_DAYS_PER_YEAR * 100 + + sharpe = self._calculate_sharpe_ratio(annualized_return, annualized_vol * 100) + + rolling_sharpe.append(sharpe if sharpe else Decimal("0")) + rolling_volatility.append(annualized_vol * 100) + + return { + "rolling_sharpe": rolling_sharpe, + "rolling_volatility": rolling_volatility, + } diff --git a/tradingagents/dataflows/alpha_vantage_common.py b/tradingagents/dataflows/alpha_vantage_common.py index 44585b01..9d4ac5d9 100644 --- a/tradingagents/dataflows/alpha_vantage_common.py +++ b/tradingagents/dataflows/alpha_vantage_common.py @@ -53,7 +53,7 @@ def _make_api_request(function_name: str, params: dict) -> dict | str: elif "entitlement" in api_params: api_params.pop("entitlement", None) - response = requests.get(API_BASE_URL, params=api_params) + response = requests.get(API_BASE_URL, params=api_params, timeout=30) response.raise_for_status() response_text = response.text @@ -88,6 +88,6 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> return filtered_df.to_csv(index=False) - except Exception as e: + except (pd.errors.ParserError, KeyError, ValueError) as e: logger.warning("Failed to filter CSV data by date range: %s", e) return csv_data diff --git a/tradingagents/dataflows/alpha_vantage_indicator.py b/tradingagents/dataflows/alpha_vantage_indicator.py index 913cc96c..12b146a4 100644 --- a/tradingagents/dataflows/alpha_vantage_indicator.py +++ b/tradingagents/dataflows/alpha_vantage_indicator.py @@ -1,4 +1,5 @@ import logging +import requests from .alpha_vantage_common import _make_api_request logger = logging.getLogger(__name__) @@ -193,6 +194,6 @@ def get_indicator( return result_str - except Exception as e: + except (ValueError, KeyError, IndexError, requests.RequestException) as e: logger.error("Error getting Alpha Vantage indicator data for %s: %s", indicator, e) return f"Error retrieving {indicator} data: {str(e)}" diff --git a/tradingagents/dataflows/alpha_vantage_news.py b/tradingagents/dataflows/alpha_vantage_news.py index ee941c3e..5dbf5e47 100644 --- a/tradingagents/dataflows/alpha_vantage_news.py +++ b/tradingagents/dataflows/alpha_vantage_news.py @@ -77,7 +77,8 @@ def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]: "content_snippet": item.get("summary", "")[:500], } articles.append(article) - except Exception: + except (KeyError, TypeError, AttributeError) as e: + logger.debug("Error parsing news article: %s", e) continue return articles diff --git a/tradingagents/dataflows/brave.py b/tradingagents/dataflows/brave.py index 811f3d72..5135b21e 100644 --- a/tradingagents/dataflows/brave.py +++ b/tradingagents/dataflows/brave.py @@ -120,7 +120,7 @@ def get_bulk_news_brave(lookback_hours: int) -> List[Dict[str, Any]]: except requests.exceptions.RequestException as e: logger.debug("Brave search request failed for '%s': %s", query, e) continue - except Exception as e: + except (KeyError, TypeError, ValueError) as e: logger.debug("Brave search failed for query '%s': %s", query, e) continue diff --git a/tradingagents/dataflows/google.py b/tradingagents/dataflows/google.py index 90cc14aa..a0663de9 100644 --- a/tradingagents/dataflows/google.py +++ b/tradingagents/dataflows/google.py @@ -1,10 +1,14 @@ +import logging import re +import requests from typing import Annotated, List, Dict, Any from datetime import datetime, timedelta from dateutil.relativedelta import relativedelta from dateutil import parser as dateutil_parser from .googlenews_utils import getNewsData +logger = logging.getLogger(__name__) + def _parse_google_news_date(date_str: str) -> datetime: if not date_str: @@ -108,7 +112,8 @@ def get_bulk_news_google(lookback_hours: int) -> List[Dict[str, Any]]: } all_articles.append(article) - except Exception: + except (TypeError, KeyError, AttributeError, requests.RequestException) as e: + logger.debug("Google News search failed for query '%s': %s", query, e) continue return all_articles diff --git a/tradingagents/dataflows/googlenews_utils.py b/tradingagents/dataflows/googlenews_utils.py index c108aa3c..1179fb3d 100644 --- a/tradingagents/dataflows/googlenews_utils.py +++ b/tradingagents/dataflows/googlenews_utils.py @@ -27,7 +27,7 @@ def is_rate_limited(response): ) def make_request(url, headers): time.sleep(random.uniform(2, 6)) - response = requests.get(url, headers=headers) + response = requests.get(url, headers=headers, timeout=30) return response @@ -81,7 +81,7 @@ def getNewsData(query, start_date, end_date): "source": source, } ) - except Exception as e: + except (TypeError, AttributeError, KeyError) as e: logger.debug("Error processing result: %s", e) continue @@ -91,7 +91,7 @@ def getNewsData(query, start_date, end_date): page += 1 - except Exception as e: + except (requests.RequestException, ConnectionError, TimeoutError) as e: logger.debug("Failed after multiple retries: %s", e) break diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 58576406..ea169c73 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -191,7 +191,8 @@ def _convert_to_news_articles(raw_articles: List[Dict[str, Any]]) -> List[NewsAr ticker_mentions=[], ) articles.append(article) - except Exception: + except (KeyError, TypeError, ValueError) as e: + logger.debug("Error converting article to NewsArticle: %s", e) continue return articles @@ -218,7 +219,7 @@ def _fetch_bulk_news_from_vendor(lookback_period: str) -> List[Dict[str, Any]]: except AlphaVantageRateLimitError as e: logger.warning("Alpha Vantage rate limit exceeded: %s", e) continue - except Exception as e: + except (RuntimeError, ConnectionError, TimeoutError, ValueError, OSError) as e: logger.error("Vendor '%s' failed: %s", vendor, e) continue @@ -316,7 +317,7 @@ def route_to_vendor(method: str, *args, **kwargs): logger.warning("Alpha Vantage rate limit exceeded, falling back to next available vendor") logger.debug("Rate limit details: %s", e) continue - except Exception as e: + except (RuntimeError, ConnectionError, TimeoutError, ValueError, KeyError, OSError) as e: logger.error("%s from vendor '%s' failed: %s", impl_func.__name__, vendor_name, e) continue diff --git a/tradingagents/dataflows/local.py b/tradingagents/dataflows/local.py index 1c126ee2..94022bd7 100644 --- a/tradingagents/dataflows/local.py +++ b/tradingagents/dataflows/local.py @@ -170,8 +170,8 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period= data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json" ) - data = open(data_path, "r") - data = json.load(data) + with open(data_path, "r") as f: + data = json.load(f) filtered_data = {} for key, value in data.items(): diff --git a/tradingagents/dataflows/openai.py b/tradingagents/dataflows/openai.py index d1cde9d5..e0ab7fac 100644 --- a/tradingagents/dataflows/openai.py +++ b/tradingagents/dataflows/openai.py @@ -54,7 +54,7 @@ def get_stock_news_openai(query, start_date, end_date): temperature=1, max_output_tokens=4096, top_p=1, - store=True, + store=False, ) return _extract_response_text(response) or "" @@ -89,7 +89,7 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5): temperature=1, max_output_tokens=4096, top_p=1, - store=True, + store=False, ) return _extract_response_text(response) or "" @@ -124,7 +124,7 @@ def get_fundamentals_openai(ticker, curr_date): temperature=1, max_output_tokens=4096, top_p=1, - store=True, + store=False, ) return _extract_response_text(response) or "" @@ -187,7 +187,7 @@ Return ONLY the JSON array, no additional text.""" temperature=0.5, max_output_tokens=8192, top_p=1, - store=True, + store=False, ) try: diff --git a/tradingagents/dataflows/tavily.py b/tradingagents/dataflows/tavily.py index 4651c0ad..0599b325 100644 --- a/tradingagents/dataflows/tavily.py +++ b/tradingagents/dataflows/tavily.py @@ -36,7 +36,7 @@ def _search_with_retry(client, query: str, search_depth: str, topic: str, time_r max_results=max_results, ) return response - except Exception as e: + except (RuntimeError, ConnectionError, TimeoutError, OSError) as e: error_str = str(e).lower() if "rate" in error_str or "limit" in error_str or "429" in error_str: wait_time = RETRY_BACKOFF * (attempt + 1) * 2 @@ -123,7 +123,7 @@ def get_bulk_news_tavily(lookback_hours: int) -> List[Dict[str, Any]]: } all_articles.append(article) - except Exception as e: + except (RuntimeError, ConnectionError, TimeoutError, OSError, ValueError) as e: logger.debug("Tavily search failed for query '%s': %s", query, e) continue diff --git a/tradingagents/dataflows/trending/sector_classifier.py b/tradingagents/dataflows/trending/sector_classifier.py index 999ec3aa..a8df7533 100644 --- a/tradingagents/dataflows/trending/sector_classifier.py +++ b/tradingagents/dataflows/trending/sector_classifier.py @@ -261,7 +261,7 @@ def classify_sector(ticker: str) -> str: _sector_cache[ticker_upper] = sector logger.info("Classified %s as %s via LLM", ticker, sector) return sector - except Exception as e: + except (KeyError, ValueError, RuntimeError, ConnectionError, TimeoutError) as e: logger.error("LLM sector classification failed for %s: %s", ticker, str(e)) _sector_cache[ticker_upper] = "other" return "other" diff --git a/tradingagents/dataflows/trending/stock_resolver.py b/tradingagents/dataflows/trending/stock_resolver.py index 329573f5..bdb5a6b7 100644 --- a/tradingagents/dataflows/trending/stock_resolver.py +++ b/tradingagents/dataflows/trending/stock_resolver.py @@ -462,7 +462,7 @@ def _search_yfinance_ticker(company_name: str) -> Optional[str]: info = search_result.info if info and "symbol" in info: return info["symbol"] - except Exception as e: + except (KeyError, ValueError, AttributeError, RuntimeError) as e: logger.debug("yfinance search failed for %s: %s", company_name, str(e)) try: @@ -471,7 +471,7 @@ def _search_yfinance_ticker(company_name: str) -> Optional[str]: for quote in search.quotes: if "symbol" in quote: return quote["symbol"] - except Exception as e: + except (KeyError, ValueError, AttributeError, RuntimeError) as e: logger.debug("yfinance Search failed for %s: %s", company_name, str(e)) return None @@ -495,7 +495,7 @@ def validate_us_ticker(ticker: str) -> bool: logger.warning("Validation failed for %s: exchange %s is not a US exchange", ticker, exchange) return False - except Exception as e: + except (KeyError, ValueError, AttributeError, RuntimeError) as e: logger.warning("Validation failed for %s: %s", ticker, str(e)) return False diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index 8bb4927d..874ccea2 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -149,7 +149,7 @@ def get_stock_stats_indicators_window( for date_str, value in date_values: ind_string += f"{date_str}: {value}\n" - except Exception as e: + except (KeyError, ValueError, FileNotFoundError) as e: logger.error("Error getting bulk stockstats data: %s", e) ind_string = "" curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") @@ -260,7 +260,7 @@ def get_stockstats_indicator( indicator, curr_date, ) - except Exception as e: + except (KeyError, ValueError, IndexError) as e: logger.error( "Error getting stockstats indicator data for indicator %s on %s: %s", indicator, curr_date, e @@ -293,7 +293,7 @@ def get_balance_sheet( return header + csv_string - except Exception as e: + except (ValueError, KeyError, AttributeError) as e: return f"Error retrieving balance sheet for {ticker}: {str(e)}" @@ -320,7 +320,7 @@ def get_cashflow( return header + csv_string - except Exception as e: + except (ValueError, KeyError, AttributeError) as e: return f"Error retrieving cash flow for {ticker}: {str(e)}" @@ -347,7 +347,7 @@ def get_income_statement( return header + csv_string - except Exception as e: + except (ValueError, KeyError, AttributeError) as e: return f"Error retrieving income statement for {ticker}: {str(e)}" @@ -368,5 +368,5 @@ def get_insider_transactions( return header + csv_string - except Exception as e: + except (ValueError, KeyError, AttributeError) as e: return f"Error retrieving insider transactions for {ticker}: {str(e)}" diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index baecdfe5..a444b1a2 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -3,7 +3,7 @@ import os DEFAULT_CONFIG = { "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), - "data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data", + "data_dir": os.getenv("TRADINGAGENTS_DATA_DIR", "./data"), "data_cache_dir": os.path.join( os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "dataflows/data_cache", diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index f6dd3e8f..b2c2a5dd 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -280,7 +280,7 @@ class TradingAgentsGraph: ) discovery_result["stocks"] = trending_stocks - except Exception as e: + except (ValueError, KeyError, RuntimeError, ConnectionError, TimeoutError) as e: discovery_result["error"] = str(e) discovery_thread = threading.Thread(target=run_discovery) diff --git a/tradingagents/models/__init__.py b/tradingagents/models/__init__.py new file mode 100644 index 00000000..4ce1cb4b --- /dev/null +++ b/tradingagents/models/__init__.py @@ -0,0 +1,69 @@ +from .market_data import ( + OHLCV, + OHLCVBar, + TechnicalIndicators, + MarketSnapshot, + HistoricalDataRequest, + HistoricalDataResponse, +) +from .trading import ( + OrderSide, + OrderType, + OrderStatus, + PositionSide, + Order, + Fill, + Position, + Trade, +) +from .portfolio import ( + PortfolioSnapshot, + PortfolioConfig, + CashTransaction, + TransactionType, +) +from .backtest import ( + BacktestConfig, + BacktestResult, + BacktestMetrics, + EquityCurvePoint, + TradeLog, +) +from .decisions import ( + SignalType, + TradingSignal, + TradingDecision, + RiskAssessment, + AnalystReport, +) + +__all__ = [ + "OHLCV", + "OHLCVBar", + "TechnicalIndicators", + "MarketSnapshot", + "HistoricalDataRequest", + "HistoricalDataResponse", + "OrderSide", + "OrderType", + "OrderStatus", + "PositionSide", + "Order", + "Fill", + "Position", + "Trade", + "PortfolioSnapshot", + "PortfolioConfig", + "CashTransaction", + "TransactionType", + "BacktestConfig", + "BacktestResult", + "BacktestMetrics", + "EquityCurvePoint", + "TradeLog", + "SignalType", + "TradingSignal", + "TradingDecision", + "RiskAssessment", + "AnalystReport", +] diff --git a/tradingagents/models/backtest.py b/tradingagents/models/backtest.py new file mode 100644 index 00000000..c18aaa13 --- /dev/null +++ b/tradingagents/models/backtest.py @@ -0,0 +1,242 @@ +from datetime import date, datetime, timedelta +from decimal import Decimal +from enum import Enum +from typing import Optional +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field, computed_field, field_validator + +from .portfolio import PortfolioConfig +from .trading import Trade + + +class BacktestConfig(BaseModel): + id: UUID = Field(default_factory=uuid4) + name: str = Field(default="Backtest") + description: Optional[str] = None + + tickers: list[str] = Field(min_length=1) + start_date: date + end_date: date + interval: str = Field(default="1d") + + portfolio_config: PortfolioConfig = Field(default_factory=PortfolioConfig) + + warmup_period: int = Field(default=20, ge=0) + rebalance_frequency: Optional[str] = Field(default=None) + + use_agent_pipeline: bool = Field(default=True) + agent_config: dict = Field(default_factory=dict) + + benchmark_ticker: Optional[str] = Field(default="SPY") + risk_free_rate: Decimal = Field(default=Decimal("0.05"), ge=0) + + created_at: datetime = Field(default_factory=datetime.now) + + @field_validator("end_date") + @classmethod + def end_after_start(cls, v: date, info) -> date: + if "start_date" in info.data and v <= info.data["start_date"]: + raise ValueError("end_date must be > start_date") + return v + + @field_validator("tickers") + @classmethod + def validate_tickers(cls, v: list[str]) -> list[str]: + return [t.upper().strip() for t in v] + + @computed_field + @property + def trading_days_estimate(self) -> int: + delta = self.end_date - self.start_date + return int(delta.days * 252 / 365) + + +class EquityCurvePoint(BaseModel): + timestamp: datetime + equity: Decimal + cash: Decimal + positions_value: Decimal + benchmark_value: Optional[Decimal] = None + drawdown: Decimal = Field(default=Decimal("0")) + drawdown_percent: Decimal = Field(default=Decimal("0")) + + +class TradeLog(BaseModel): + trades: list[Trade] = Field(default_factory=list) + total_trades: int = Field(default=0, ge=0) + winning_trades: int = Field(default=0, ge=0) + losing_trades: int = Field(default=0, ge=0) + break_even_trades: int = Field(default=0, ge=0) + + @computed_field + @property + def win_rate(self) -> Optional[Decimal]: + if self.total_trades == 0: + return None + return Decimal(self.winning_trades) / Decimal(self.total_trades) * 100 + + @computed_field + @property + def loss_rate(self) -> Optional[Decimal]: + if self.total_trades == 0: + return None + return Decimal(self.losing_trades) / Decimal(self.total_trades) * 100 + + def add_trade(self, trade: Trade) -> None: + self.trades.append(trade) + self.total_trades += 1 + if trade.is_closed and trade.pnl is not None: + if trade.pnl > 0: + self.winning_trades += 1 + elif trade.pnl < 0: + self.losing_trades += 1 + else: + self.break_even_trades += 1 + + @property + def gross_profit(self) -> Decimal: + return sum( + t.pnl for t in self.trades if t.is_closed and t.pnl and t.pnl > 0 + ) or Decimal("0") + + @property + def gross_loss(self) -> Decimal: + return abs( + sum(t.pnl for t in self.trades if t.is_closed and t.pnl and t.pnl < 0) + or Decimal("0") + ) + + @property + def profit_factor(self) -> Optional[Decimal]: + if self.gross_loss == 0: + return None + return self.gross_profit / self.gross_loss + + @property + def avg_win(self) -> Optional[Decimal]: + wins = [t.pnl for t in self.trades if t.is_closed and t.pnl and t.pnl > 0] + if not wins: + return None + return sum(wins) / len(wins) + + @property + def avg_loss(self) -> Optional[Decimal]: + losses = [t.pnl for t in self.trades if t.is_closed and t.pnl and t.pnl < 0] + if not losses: + return None + return sum(losses) / len(losses) + + @property + def avg_holding_period(self) -> Optional[float]: + periods = [t.holding_period for t in self.trades if t.holding_period is not None] + if not periods: + return None + return sum(periods) / len(periods) + + +class BacktestMetrics(BaseModel): + total_return: Decimal = Field(default=Decimal("0")) + total_return_percent: Decimal = Field(default=Decimal("0")) + annualized_return: Decimal = Field(default=Decimal("0")) + + benchmark_return: Optional[Decimal] = None + benchmark_return_percent: Optional[Decimal] = None + alpha: Optional[Decimal] = None + beta: Optional[Decimal] = None + + volatility: Decimal = Field(default=Decimal("0"), ge=0) + annualized_volatility: Decimal = Field(default=Decimal("0"), ge=0) + downside_volatility: Decimal = Field(default=Decimal("0"), ge=0) + + sharpe_ratio: Optional[Decimal] = None + sortino_ratio: Optional[Decimal] = None + calmar_ratio: Optional[Decimal] = None + information_ratio: Optional[Decimal] = None + + max_drawdown: Decimal = Field(default=Decimal("0"), ge=0) + max_drawdown_percent: Decimal = Field(default=Decimal("0"), ge=0, le=100) + max_drawdown_duration: Optional[int] = None + avg_drawdown: Decimal = Field(default=Decimal("0"), ge=0) + + total_trades: int = Field(default=0, ge=0) + win_rate: Optional[Decimal] = Field(default=None, ge=0, le=100) + profit_factor: Optional[Decimal] = None + avg_trade_pnl: Optional[Decimal] = None + avg_win: Optional[Decimal] = None + avg_loss: Optional[Decimal] = None + largest_win: Optional[Decimal] = None + largest_loss: Optional[Decimal] = None + avg_holding_period_days: Optional[float] = None + + total_commission: Decimal = Field(default=Decimal("0"), ge=0) + total_slippage: Decimal = Field(default=Decimal("0"), ge=0) + + trading_days: int = Field(default=0, ge=0) + start_equity: Decimal = Field(gt=0) + end_equity: Decimal = Field(gt=0) + + def to_summary_dict(self) -> dict: + return { + "Performance": { + "Total Return": f"{self.total_return_percent:.2f}%", + "Annualized Return": f"{self.annualized_return:.2f}%", + "Sharpe Ratio": f"{self.sharpe_ratio:.2f}" if self.sharpe_ratio else "N/A", + "Sortino Ratio": f"{self.sortino_ratio:.2f}" if self.sortino_ratio else "N/A", + "Max Drawdown": f"{self.max_drawdown_percent:.2f}%", + }, + "Risk": { + "Volatility (Ann.)": f"{self.annualized_volatility:.2f}%", + "Calmar Ratio": f"{self.calmar_ratio:.2f}" if self.calmar_ratio else "N/A", + "Beta": f"{self.beta:.2f}" if self.beta else "N/A", + }, + "Trading": { + "Total Trades": self.total_trades, + "Win Rate": f"{self.win_rate:.1f}%" if self.win_rate else "N/A", + "Profit Factor": f"{self.profit_factor:.2f}" if self.profit_factor else "N/A", + "Avg Holding Period": f"{self.avg_holding_period_days:.1f} days" if self.avg_holding_period_days else "N/A", + }, + "Costs": { + "Total Commission": f"${self.total_commission:.2f}", + "Total Slippage": f"${self.total_slippage:.2f}", + }, + } + + +class BacktestResult(BaseModel): + id: UUID = Field(default_factory=uuid4) + config: BacktestConfig + metrics: BacktestMetrics + trade_log: TradeLog + equity_curve: list[EquityCurvePoint] = Field(default_factory=list) + daily_returns: list[Decimal] = Field(default_factory=list) + started_at: datetime + completed_at: datetime + status: str = Field(default="completed") + error_message: Optional[str] = None + + @computed_field + @property + def duration_seconds(self) -> float: + return (self.completed_at - self.started_at).total_seconds() + + def to_dict(self) -> dict: + return { + "id": str(self.id), + "config": { + "name": self.config.name, + "tickers": self.config.tickers, + "start_date": self.config.start_date.isoformat(), + "end_date": self.config.end_date.isoformat(), + "initial_cash": float(self.config.portfolio_config.initial_cash), + }, + "metrics": self.metrics.to_summary_dict(), + "trade_summary": { + "total_trades": self.trade_log.total_trades, + "winning_trades": self.trade_log.winning_trades, + "losing_trades": self.trade_log.losing_trades, + "win_rate": float(self.trade_log.win_rate) if self.trade_log.win_rate else None, + }, + "duration_seconds": self.duration_seconds, + "status": self.status, + } diff --git a/tradingagents/models/decisions.py b/tradingagents/models/decisions.py new file mode 100644 index 00000000..22bf32f5 --- /dev/null +++ b/tradingagents/models/decisions.py @@ -0,0 +1,157 @@ +from datetime import datetime +from decimal import Decimal +from enum import Enum +from typing import Optional +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field + + +class SignalType(str, Enum): + STRONG_BUY = "strong_buy" + BUY = "buy" + HOLD = "hold" + SELL = "sell" + STRONG_SELL = "strong_sell" + + +class AnalystType(str, Enum): + MARKET = "market" + SENTIMENT = "sentiment" + NEWS = "news" + FUNDAMENTALS = "fundamentals" + + +class AnalystReport(BaseModel): + id: UUID = Field(default_factory=uuid4) + analyst_type: AnalystType + ticker: str + report_date: datetime + signal: Optional[SignalType] = None + confidence: Decimal = Field(default=Decimal("0.5"), ge=0, le=1) + summary: str + key_findings: list[str] = Field(default_factory=list) + raw_content: Optional[str] = None + data_sources: list[str] = Field(default_factory=list) + created_at: datetime = Field(default_factory=datetime.now) + + +class TradingSignal(BaseModel): + id: UUID = Field(default_factory=uuid4) + ticker: str + timestamp: datetime + signal: SignalType + strength: Decimal = Field(ge=0, le=1) + source: str + timeframe: str = Field(default="1d") + price_at_signal: Optional[Decimal] = None + target_price: Optional[Decimal] = None + stop_loss: Optional[Decimal] = None + expiry: Optional[datetime] = None + metadata: dict = Field(default_factory=dict) + + +class RiskAssessment(BaseModel): + id: UUID = Field(default_factory=uuid4) + ticker: str + timestamp: datetime + overall_risk_score: Decimal = Field(ge=0, le=1) + + market_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1) + liquidity_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1) + volatility_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1) + concentration_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1) + event_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1) + + max_position_size: Optional[Decimal] = None + recommended_stop_loss: Optional[Decimal] = None + var_95: Optional[Decimal] = None + expected_shortfall: Optional[Decimal] = None + + risk_factors: list[str] = Field(default_factory=list) + mitigations: list[str] = Field(default_factory=list) + notes: Optional[str] = None + + +class TradingDecision(BaseModel): + id: UUID = Field(default_factory=uuid4) + ticker: str + timestamp: datetime + decision_date: datetime + + signal: SignalType + confidence: Decimal = Field(ge=0, le=1) + + recommended_action: str + recommended_quantity: Optional[int] = None + recommended_price: Optional[Decimal] = None + stop_loss: Optional[Decimal] = None + take_profit: Optional[Decimal] = None + + analyst_reports: list[AnalystReport] = Field(default_factory=list) + signals: list[TradingSignal] = Field(default_factory=list) + risk_assessment: Optional[RiskAssessment] = None + + bull_argument: Optional[str] = None + bear_argument: Optional[str] = None + debate_rounds: int = Field(default=0, ge=0) + debate_winner: Optional[str] = None + + risk_manager_approved: Optional[bool] = None + risk_manager_notes: Optional[str] = None + + final_decision: str + rationale: str + + execution_price: Optional[Decimal] = None + executed_at: Optional[datetime] = None + execution_notes: Optional[str] = None + + created_at: datetime = Field(default_factory=datetime.now) + + @property + def is_buy(self) -> bool: + return self.signal in (SignalType.BUY, SignalType.STRONG_BUY) + + @property + def is_sell(self) -> bool: + return self.signal in (SignalType.SELL, SignalType.STRONG_SELL) + + @property + def is_hold(self) -> bool: + return self.signal == SignalType.HOLD + + def get_analyst_report(self, analyst_type: AnalystType) -> Optional[AnalystReport]: + for report in self.analyst_reports: + if report.analyst_type == analyst_type: + return report + return None + + def to_summary(self) -> dict: + return { + "ticker": self.ticker, + "date": self.decision_date.isoformat(), + "signal": self.signal.value, + "confidence": float(self.confidence), + "final_decision": self.final_decision, + "risk_approved": self.risk_manager_approved, + "debate_rounds": self.debate_rounds, + "analyst_consensus": self._calculate_consensus(), + } + + def _calculate_consensus(self) -> Optional[str]: + if not self.analyst_reports: + return None + + signals = [r.signal for r in self.analyst_reports if r.signal] + if not signals: + return None + + buy_count = sum(1 for s in signals if s in (SignalType.BUY, SignalType.STRONG_BUY)) + sell_count = sum(1 for s in signals if s in (SignalType.SELL, SignalType.STRONG_SELL)) + + if buy_count > sell_count: + return "bullish" + elif sell_count > buy_count: + return "bearish" + return "neutral" diff --git a/tradingagents/models/market_data.py b/tradingagents/models/market_data.py new file mode 100644 index 00000000..f0e293dc --- /dev/null +++ b/tradingagents/models/market_data.py @@ -0,0 +1,144 @@ +from datetime import date, datetime +from decimal import Decimal +from typing import Optional + +from pydantic import BaseModel, Field, field_validator + + +class OHLCVBar(BaseModel): + timestamp: datetime + open: Decimal = Field(gt=0) + high: Decimal = Field(gt=0) + low: Decimal = Field(gt=0) + close: Decimal = Field(gt=0) + volume: int = Field(ge=0) + adjusted_close: Optional[Decimal] = Field(default=None, gt=0) + + @field_validator("high") + @classmethod + def high_gte_low(cls, v: Decimal, info) -> Decimal: + if "low" in info.data and v < info.data["low"]: + raise ValueError("high must be >= low") + return v + + @field_validator("high") + @classmethod + def high_gte_open_close(cls, v: Decimal, info) -> Decimal: + if "open" in info.data and v < info.data["open"]: + raise ValueError("high must be >= open") + if "close" in info.data and v < info.data["close"]: + raise ValueError("high must be >= close") + return v + + @field_validator("low") + @classmethod + def low_lte_open_close(cls, v: Decimal, info) -> Decimal: + if "open" in info.data and v > info.data["open"]: + raise ValueError("low must be <= open") + if "close" in info.data and v > info.data["close"]: + raise ValueError("low must be <= close") + return v + + +class OHLCV(BaseModel): + ticker: str = Field(min_length=1, max_length=10) + bars: list[OHLCVBar] = Field(default_factory=list) + interval: str = Field(default="1d") + currency: str = Field(default="USD") + + @property + def start_date(self) -> Optional[datetime]: + return self.bars[0].timestamp if self.bars else None + + @property + def end_date(self) -> Optional[datetime]: + return self.bars[-1].timestamp if self.bars else None + + def get_bar(self, dt: datetime) -> Optional[OHLCVBar]: + for bar in self.bars: + if bar.timestamp.date() == dt.date(): + return bar + return None + + def slice(self, start: datetime, end: datetime) -> "OHLCV": + filtered = [b for b in self.bars if start <= b.timestamp <= end] + return OHLCV( + ticker=self.ticker, + bars=filtered, + interval=self.interval, + currency=self.currency, + ) + + +class TechnicalIndicators(BaseModel): + timestamp: datetime + ticker: str + + sma_20: Optional[Decimal] = None + sma_50: Optional[Decimal] = None + sma_200: Optional[Decimal] = None + + ema_10: Optional[Decimal] = None + ema_20: Optional[Decimal] = None + + rsi_14: Optional[Decimal] = Field(default=None, ge=0, le=100) + + macd: Optional[Decimal] = None + macd_signal: Optional[Decimal] = None + macd_histogram: Optional[Decimal] = None + + bollinger_upper: Optional[Decimal] = None + bollinger_middle: Optional[Decimal] = None + bollinger_lower: Optional[Decimal] = None + + atr_14: Optional[Decimal] = Field(default=None, ge=0) + + mfi_14: Optional[Decimal] = Field(default=None, ge=0, le=100) + + vwap: Optional[Decimal] = None + + obv: Optional[int] = None + + +class MarketSnapshot(BaseModel): + ticker: str + timestamp: datetime + bar: OHLCVBar + indicators: Optional[TechnicalIndicators] = None + prev_close: Optional[Decimal] = None + + @property + def change(self) -> Optional[Decimal]: + if self.prev_close: + return self.bar.close - self.prev_close + return None + + @property + def change_percent(self) -> Optional[Decimal]: + if self.prev_close and self.prev_close > 0: + return ((self.bar.close - self.prev_close) / self.prev_close) * 100 + return None + + +class HistoricalDataRequest(BaseModel): + ticker: str = Field(min_length=1, max_length=10) + start_date: date + end_date: date + interval: str = Field(default="1d") + include_indicators: bool = Field(default=True) + adjusted: bool = Field(default=True) + + @field_validator("end_date") + @classmethod + def end_after_start(cls, v: date, info) -> date: + if "start_date" in info.data and v < info.data["start_date"]: + raise ValueError("end_date must be >= start_date") + return v + + +class HistoricalDataResponse(BaseModel): + request: HistoricalDataRequest + ohlcv: OHLCV + indicators: list[TechnicalIndicators] = Field(default_factory=list) + fetched_at: datetime = Field(default_factory=datetime.now) + source: str = Field(default="unknown") diff --git a/tradingagents/models/portfolio.py b/tradingagents/models/portfolio.py new file mode 100644 index 00000000..aa471582 --- /dev/null +++ b/tradingagents/models/portfolio.py @@ -0,0 +1,172 @@ +from datetime import datetime +from decimal import Decimal +from enum import Enum +from typing import Optional +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field, computed_field + +from .trading import Position, Fill, OrderSide + + +class TransactionType(str, Enum): + DEPOSIT = "deposit" + WITHDRAWAL = "withdrawal" + DIVIDEND = "dividend" + INTEREST = "interest" + FEE = "fee" + TRANSFER_IN = "transfer_in" + TRANSFER_OUT = "transfer_out" + + +class CashTransaction(BaseModel): + id: UUID = Field(default_factory=uuid4) + transaction_type: TransactionType + amount: Decimal + timestamp: datetime = Field(default_factory=datetime.now) + description: Optional[str] = None + reference_id: Optional[UUID] = None + + +class PortfolioConfig(BaseModel): + initial_cash: Decimal = Field(default=Decimal("100000"), gt=0) + commission_per_share: Decimal = Field(default=Decimal("0"), ge=0) + commission_per_trade: Decimal = Field(default=Decimal("0"), ge=0) + commission_percent: Decimal = Field(default=Decimal("0"), ge=0, le=100) + min_commission: Decimal = Field(default=Decimal("0"), ge=0) + max_commission: Optional[Decimal] = Field(default=None, ge=0) + slippage_percent: Decimal = Field(default=Decimal("0"), ge=0, le=100) + margin_enabled: bool = Field(default=False) + margin_rate: Decimal = Field(default=Decimal("0.05"), ge=0) + max_position_size_percent: Decimal = Field(default=Decimal("100"), gt=0, le=100) + allow_fractional_shares: bool = Field(default=False) + + def calculate_commission(self, quantity: int, price: Decimal) -> Decimal: + trade_value = quantity * price + commission = Decimal("0") + + commission += self.commission_per_trade + commission += self.commission_per_share * quantity + commission += trade_value * (self.commission_percent / 100) + + if commission < self.min_commission: + commission = self.min_commission + if self.max_commission and commission > self.max_commission: + commission = self.max_commission + + return commission + + def calculate_slippage(self, price: Decimal, side: OrderSide) -> Decimal: + slippage = price * (self.slippage_percent / 100) + if side == OrderSide.BUY: + return price + slippage + return price - slippage + + +class PortfolioSnapshot(BaseModel): + timestamp: datetime = Field(default_factory=datetime.now) + cash: Decimal = Field(default=Decimal("0")) + positions: dict[str, Position] = Field(default_factory=dict) + pending_orders: int = Field(default=0, ge=0) + total_commission_paid: Decimal = Field(default=Decimal("0"), ge=0) + total_realized_pnl: Decimal = Field(default=Decimal("0")) + cash_transactions: list[CashTransaction] = Field(default_factory=list) + + @computed_field + @property + def position_count(self) -> int: + return len([p for p in self.positions.values() if p.quantity != 0]) + + def positions_value(self, prices: dict[str, Decimal]) -> Decimal: + total = Decimal("0") + for ticker, position in self.positions.items(): + if ticker in prices and position.quantity != 0: + total += position.market_value(prices[ticker]) + return total + + def total_equity(self, prices: dict[str, Decimal]) -> Decimal: + return self.cash + self.positions_value(prices) + + def total_unrealized_pnl(self, prices: dict[str, Decimal]) -> Decimal: + total = Decimal("0") + for ticker, position in self.positions.items(): + if ticker in prices: + total += position.unrealized_pnl(prices[ticker]) + return total + + def get_position(self, ticker: str) -> Position: + if ticker not in self.positions: + self.positions[ticker] = Position(ticker=ticker) + return self.positions[ticker] + + def apply_fill(self, fill: Fill) -> None: + position = self.get_position(fill.ticker) + old_realized = position.realized_pnl + + position.update_from_fill(fill) + + pnl_change = position.realized_pnl - old_realized + self.total_realized_pnl += pnl_change + self.total_commission_paid += fill.commission + + if fill.side == OrderSide.BUY: + self.cash -= fill.total_cost + else: + self.cash += fill.total_value - fill.commission + + def add_cash_transaction(self, transaction: CashTransaction) -> None: + self.cash_transactions.append(transaction) + if transaction.transaction_type in ( + TransactionType.DEPOSIT, + TransactionType.DIVIDEND, + TransactionType.INTEREST, + TransactionType.TRANSFER_IN, + ): + self.cash += transaction.amount + else: + self.cash -= abs(transaction.amount) + + def can_afford( + self, ticker: str, quantity: int, price: Decimal, config: PortfolioConfig + ) -> bool: + execution_price = config.calculate_slippage(price, OrderSide.BUY) + commission = config.calculate_commission(quantity, execution_price) + total_cost = (quantity * execution_price) + commission + return self.cash >= total_cost + + def max_shares_affordable( + self, ticker: str, price: Decimal, config: PortfolioConfig + ) -> int: + if price <= 0: + return 0 + + execution_price = config.calculate_slippage(price, OrderSide.BUY) + available = self.cash + + low, high = 0, int(available / execution_price) + 1 + result = 0 + + while low <= high: + mid = (low + high) // 2 + commission = config.calculate_commission(mid, execution_price) + total_cost = (mid * execution_price) + commission + + if total_cost <= available: + result = mid + low = mid + 1 + else: + high = mid - 1 + + return result + + def to_dict(self, prices: dict[str, Decimal]) -> dict: + return { + "timestamp": self.timestamp.isoformat(), + "cash": float(self.cash), + "positions_value": float(self.positions_value(prices)), + "total_equity": float(self.total_equity(prices)), + "position_count": self.position_count, + "total_realized_pnl": float(self.total_realized_pnl), + "total_unrealized_pnl": float(self.total_unrealized_pnl(prices)), + "total_commission_paid": float(self.total_commission_paid), + } diff --git a/tradingagents/models/trading.py b/tradingagents/models/trading.py new file mode 100644 index 00000000..abb3bcb9 --- /dev/null +++ b/tradingagents/models/trading.py @@ -0,0 +1,201 @@ +from datetime import datetime +from decimal import Decimal +from enum import Enum +from typing import Optional +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field, computed_field + + +class OrderSide(str, Enum): + BUY = "buy" + SELL = "sell" + + +class OrderType(str, Enum): + MARKET = "market" + LIMIT = "limit" + STOP = "stop" + STOP_LIMIT = "stop_limit" + + +class OrderStatus(str, Enum): + PENDING = "pending" + SUBMITTED = "submitted" + PARTIAL = "partial" + FILLED = "filled" + CANCELLED = "cancelled" + REJECTED = "rejected" + EXPIRED = "expired" + + +class PositionSide(str, Enum): + LONG = "long" + SHORT = "short" + FLAT = "flat" + + +class Order(BaseModel): + id: UUID = Field(default_factory=uuid4) + ticker: str = Field(min_length=1, max_length=10) + side: OrderSide + order_type: OrderType = Field(default=OrderType.MARKET) + quantity: int = Field(gt=0) + limit_price: Optional[Decimal] = Field(default=None, gt=0) + stop_price: Optional[Decimal] = Field(default=None, gt=0) + status: OrderStatus = Field(default=OrderStatus.PENDING) + created_at: datetime = Field(default_factory=datetime.now) + submitted_at: Optional[datetime] = None + filled_at: Optional[datetime] = None + filled_quantity: int = Field(default=0, ge=0) + filled_avg_price: Optional[Decimal] = None + commission: Decimal = Field(default=Decimal("0")) + notes: Optional[str] = None + + @computed_field + @property + def remaining_quantity(self) -> int: + return self.quantity - self.filled_quantity + + @computed_field + @property + def is_complete(self) -> bool: + return self.status in ( + OrderStatus.FILLED, + OrderStatus.CANCELLED, + OrderStatus.REJECTED, + OrderStatus.EXPIRED, + ) + + +class Fill(BaseModel): + id: UUID = Field(default_factory=uuid4) + order_id: UUID + ticker: str + side: OrderSide + quantity: int = Field(gt=0) + price: Decimal = Field(gt=0) + commission: Decimal = Field(default=Decimal("0"), ge=0) + timestamp: datetime = Field(default_factory=datetime.now) + + @computed_field + @property + def total_value(self) -> Decimal: + return self.price * self.quantity + + @computed_field + @property + def total_cost(self) -> Decimal: + if self.side == OrderSide.BUY: + return self.total_value + self.commission + return self.total_value - self.commission + + +class Position(BaseModel): + ticker: str = Field(min_length=1, max_length=10) + quantity: int = Field(default=0) + avg_cost: Decimal = Field(default=Decimal("0"), ge=0) + realized_pnl: Decimal = Field(default=Decimal("0")) + opened_at: Optional[datetime] = None + last_updated: datetime = Field(default_factory=datetime.now) + + @computed_field + @property + def side(self) -> PositionSide: + if self.quantity > 0: + return PositionSide.LONG + elif self.quantity < 0: + return PositionSide.SHORT + return PositionSide.FLAT + + @computed_field + @property + def cost_basis(self) -> Decimal: + return abs(self.quantity) * self.avg_cost + + def unrealized_pnl(self, current_price: Decimal) -> Decimal: + if self.quantity == 0: + return Decimal("0") + market_value = self.quantity * current_price + return market_value - (self.quantity * self.avg_cost) + + def market_value(self, current_price: Decimal) -> Decimal: + return abs(self.quantity) * current_price + + def update_from_fill(self, fill: Fill) -> None: + if fill.side == OrderSide.BUY: + if self.quantity >= 0: + total_cost = (self.quantity * self.avg_cost) + fill.total_value + self.quantity += fill.quantity + self.avg_cost = total_cost / self.quantity if self.quantity else Decimal("0") + else: + close_qty = min(fill.quantity, abs(self.quantity)) + pnl = close_qty * (self.avg_cost - fill.price) + self.realized_pnl += pnl + self.quantity += fill.quantity + if self.quantity > 0: + self.avg_cost = fill.price + else: + if self.quantity <= 0: + total_cost = (abs(self.quantity) * self.avg_cost) + fill.total_value + self.quantity -= fill.quantity + self.avg_cost = total_cost / abs(self.quantity) if self.quantity else Decimal("0") + else: + close_qty = min(fill.quantity, self.quantity) + pnl = close_qty * (fill.price - self.avg_cost) + self.realized_pnl += pnl + self.quantity -= fill.quantity + if self.quantity < 0: + self.avg_cost = fill.price + + if self.quantity != 0 and self.opened_at is None: + self.opened_at = fill.timestamp + elif self.quantity == 0: + self.opened_at = None + + self.last_updated = fill.timestamp + + +class Trade(BaseModel): + id: UUID = Field(default_factory=uuid4) + ticker: str + side: OrderSide + entry_price: Decimal = Field(gt=0) + entry_quantity: int = Field(gt=0) + entry_time: datetime + exit_price: Optional[Decimal] = Field(default=None, gt=0) + exit_quantity: Optional[int] = Field(default=None, gt=0) + exit_time: Optional[datetime] = None + commission: Decimal = Field(default=Decimal("0"), ge=0) + entry_order_id: Optional[UUID] = None + exit_order_id: Optional[UUID] = None + notes: Optional[str] = None + tags: list[str] = Field(default_factory=list) + + @computed_field + @property + def is_closed(self) -> bool: + return self.exit_price is not None and self.exit_quantity is not None + + @computed_field + @property + def pnl(self) -> Optional[Decimal]: + if not self.is_closed: + return None + if self.side == OrderSide.BUY: + return (self.exit_price - self.entry_price) * self.exit_quantity - self.commission + return (self.entry_price - self.exit_price) * self.exit_quantity - self.commission + + @computed_field + @property + def pnl_percent(self) -> Optional[Decimal]: + if not self.is_closed or self.entry_price == 0: + return None + return (self.pnl / (self.entry_price * self.entry_quantity)) * 100 + + @computed_field + @property + def holding_period(self) -> Optional[int]: + if not self.exit_time: + return None + return (self.exit_time - self.entry_time).days