feat: add backtesting framework and fix code quality issues
- Add complete backtesting engine with portfolio simulation, metrics calculation (Sharpe, Sortino, max drawdown), and agent integration - Add Pydantic data models for market data, trading, portfolio, and backtest results - Add backtest CLI command with SMA, RSI, and hold strategies - Fix 24+ bare exception handlers with specific exception types - Fix hardcoded path in default_config.py (use TRADINGAGENTS_DATA_DIR env var) - Fix unclosed file handle in local.py with context manager - Disable store=True in OpenAI API calls for data privacy - Fix typo: rename aggresive_debator.py to aggressive_debator.py - Add request timeouts (30s) to alpha_vantage_common.py and googlenews_utils.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
69c3a13883
commit
f70874982a
192
cli/main.py
192
cli/main.py
|
|
@ -34,6 +34,9 @@ from tradingagents.agents.discovery.models import (
|
|||
EventCategory,
|
||||
)
|
||||
from tradingagents.agents.discovery.persistence import save_discovery_result
|
||||
from tradingagents.backtesting import SimpleBacktestEngine, DataLoader
|
||||
from tradingagents.models.backtest import BacktestConfig
|
||||
from tradingagents.models.portfolio import PortfolioConfig
|
||||
from cli.models import AnalystType
|
||||
from cli.utils import (
|
||||
ANALYST_ORDER,
|
||||
|
|
@ -1715,6 +1718,195 @@ def menu():
|
|||
discover_trending_flow()
|
||||
|
||||
|
||||
@app.command()
|
||||
def backtest(
|
||||
ticker: str = typer.Option(None, "--ticker", "-t", help="Ticker symbol to backtest"),
|
||||
start_date: str = typer.Option(None, "--start", "-s", help="Start date (YYYY-MM-DD)"),
|
||||
end_date: str = typer.Option(None, "--end", "-e", help="End date (YYYY-MM-DD)"),
|
||||
initial_cash: float = typer.Option(100000.0, "--cash", "-c", help="Initial portfolio cash"),
|
||||
strategy: str = typer.Option("sma", "--strategy", help="Strategy: sma, rsi, or hold"),
|
||||
):
|
||||
from decimal import Decimal
|
||||
from datetime import date as date_type
|
||||
|
||||
if not ticker:
|
||||
console.print(create_question_box("Ticker Symbol", "Enter the ticker symbol to backtest", "AAPL"))
|
||||
ticker = typer.prompt("", default="AAPL")
|
||||
|
||||
if not start_date:
|
||||
default_start = (datetime.datetime.now() - datetime.timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
console.print(create_question_box("Start Date", "Enter backtest start date (YYYY-MM-DD)", default_start))
|
||||
start_date = typer.prompt("", default=default_start)
|
||||
|
||||
if not end_date:
|
||||
default_end = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
console.print(create_question_box("End Date", "Enter backtest end date (YYYY-MM-DD)", default_end))
|
||||
end_date = typer.prompt("", default=default_end)
|
||||
|
||||
try:
|
||||
start = datetime.datetime.strptime(start_date, "%Y-%m-%d").date()
|
||||
end = datetime.datetime.strptime(end_date, "%Y-%m-%d").date()
|
||||
except ValueError:
|
||||
console.print("[red]Invalid date format. Use YYYY-MM-DD[/red]")
|
||||
return
|
||||
|
||||
if start >= end:
|
||||
console.print("[red]Start date must be before end date[/red]")
|
||||
return
|
||||
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
f"[bold]Backtest Configuration[/bold]\n\n"
|
||||
f"Ticker: [cyan]{ticker.upper()}[/cyan]\n"
|
||||
f"Period: [cyan]{start_date}[/cyan] to [cyan]{end_date}[/cyan]\n"
|
||||
f"Initial Cash: [cyan]${initial_cash:,.2f}[/cyan]\n"
|
||||
f"Strategy: [cyan]{strategy}[/cyan]",
|
||||
title="Configuration",
|
||||
border_style="blue",
|
||||
))
|
||||
console.print()
|
||||
|
||||
def sma_buy(t, trading_date, ctx):
|
||||
loader = ctx["data_loader"]
|
||||
ohlcv = loader.load_ohlcv(t, date_type(2020, 1, 1), trading_date)
|
||||
if len(ohlcv.bars) < 20:
|
||||
return False
|
||||
prices = [float(b.close) for b in ohlcv.bars[-20:]]
|
||||
sma = sum(prices) / len(prices)
|
||||
current = float(ohlcv.bars[-1].close)
|
||||
return current > sma * 1.02
|
||||
|
||||
def sma_sell(t, trading_date, ctx):
|
||||
loader = ctx["data_loader"]
|
||||
ohlcv = loader.load_ohlcv(t, date_type(2020, 1, 1), trading_date)
|
||||
if len(ohlcv.bars) < 20:
|
||||
return False
|
||||
prices = [float(b.close) for b in ohlcv.bars[-20:]]
|
||||
sma = sum(prices) / len(prices)
|
||||
current = float(ohlcv.bars[-1].close)
|
||||
return current < sma * 0.98
|
||||
|
||||
def rsi_buy(t, trading_date, ctx):
|
||||
loader = ctx["data_loader"]
|
||||
ohlcv = loader.load_ohlcv(t, date_type(2020, 1, 1), trading_date)
|
||||
if len(ohlcv.bars) < 15:
|
||||
return False
|
||||
changes = []
|
||||
for i in range(1, min(15, len(ohlcv.bars))):
|
||||
changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i-1].close))
|
||||
gains = [c for c in changes if c > 0]
|
||||
losses = [-c for c in changes if c < 0]
|
||||
avg_gain = sum(gains) / 14 if gains else 0.001
|
||||
avg_loss = sum(losses) / 14 if losses else 0.001
|
||||
rs = avg_gain / avg_loss if avg_loss else 100
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
return rsi < 30
|
||||
|
||||
def rsi_sell(t, trading_date, ctx):
|
||||
loader = ctx["data_loader"]
|
||||
ohlcv = loader.load_ohlcv(t, date_type(2020, 1, 1), trading_date)
|
||||
if len(ohlcv.bars) < 15:
|
||||
return False
|
||||
changes = []
|
||||
for i in range(1, min(15, len(ohlcv.bars))):
|
||||
changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i-1].close))
|
||||
gains = [c for c in changes if c > 0]
|
||||
losses = [-c for c in changes if c < 0]
|
||||
avg_gain = sum(gains) / 14 if gains else 0.001
|
||||
avg_loss = sum(losses) / 14 if losses else 0.001
|
||||
rs = avg_gain / avg_loss if avg_loss else 100
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
return rsi > 70
|
||||
|
||||
def hold_buy(t, trading_date, ctx):
|
||||
return ctx.get("day_index", 0) == 5
|
||||
|
||||
def hold_sell(t, trading_date, ctx):
|
||||
return False
|
||||
|
||||
strategies = {
|
||||
"sma": (sma_buy, sma_sell),
|
||||
"rsi": (rsi_buy, rsi_sell),
|
||||
"hold": (hold_buy, hold_sell),
|
||||
}
|
||||
|
||||
if strategy not in strategies:
|
||||
console.print(f"[red]Unknown strategy: {strategy}. Use: sma, rsi, or hold[/red]")
|
||||
return
|
||||
|
||||
buy_fn, sell_fn = strategies[strategy]
|
||||
|
||||
config = BacktestConfig(
|
||||
name=f"{strategy.upper()} Backtest - {ticker.upper()}",
|
||||
tickers=[ticker.upper()],
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
portfolio_config=PortfolioConfig(
|
||||
initial_cash=Decimal(str(initial_cash)),
|
||||
commission_per_trade=Decimal("1"),
|
||||
slippage_percent=Decimal("0.05"),
|
||||
),
|
||||
warmup_period=5,
|
||||
)
|
||||
|
||||
with loading("Running backtest...", show_elapsed=True):
|
||||
engine = SimpleBacktestEngine(config, buy_signal=buy_fn, sell_signal=sell_fn)
|
||||
result = engine.run()
|
||||
|
||||
console.print()
|
||||
|
||||
if result.status == "failed":
|
||||
console.print(f"[red]Backtest failed: {result.error_message}[/red]")
|
||||
return
|
||||
|
||||
metrics = result.metrics
|
||||
trade_log = result.trade_log
|
||||
|
||||
performance_table = Table(title="Performance Metrics", box=box.ROUNDED)
|
||||
performance_table.add_column("Metric", style="cyan")
|
||||
performance_table.add_column("Value", style="green")
|
||||
|
||||
performance_table.add_row("Total Return", f"${float(metrics.total_return):,.2f}")
|
||||
performance_table.add_row("Total Return %", f"{float(metrics.total_return_percent):.2f}%")
|
||||
performance_table.add_row("Annualized Return", f"{float(metrics.annualized_return):.2f}%")
|
||||
performance_table.add_row("Sharpe Ratio", f"{float(metrics.sharpe_ratio):.2f}" if metrics.sharpe_ratio else "N/A")
|
||||
performance_table.add_row("Sortino Ratio", f"{float(metrics.sortino_ratio):.2f}" if metrics.sortino_ratio else "N/A")
|
||||
performance_table.add_row("Max Drawdown", f"{float(metrics.max_drawdown_percent):.2f}%")
|
||||
performance_table.add_row("Volatility (Ann.)", f"{float(metrics.annualized_volatility):.2f}%")
|
||||
|
||||
console.print(performance_table)
|
||||
console.print()
|
||||
|
||||
trading_table = Table(title="Trading Statistics", box=box.ROUNDED)
|
||||
trading_table.add_column("Metric", style="cyan")
|
||||
trading_table.add_column("Value", style="green")
|
||||
|
||||
trading_table.add_row("Total Trades", str(trade_log.total_trades))
|
||||
trading_table.add_row("Winning Trades", str(trade_log.winning_trades))
|
||||
trading_table.add_row("Losing Trades", str(trade_log.losing_trades))
|
||||
trading_table.add_row("Win Rate", f"{float(trade_log.win_rate):.1f}%" if trade_log.win_rate else "N/A")
|
||||
trading_table.add_row("Profit Factor", f"{float(trade_log.profit_factor):.2f}" if trade_log.profit_factor else "N/A")
|
||||
trading_table.add_row("Avg Win", f"${float(trade_log.avg_win):,.2f}" if trade_log.avg_win else "N/A")
|
||||
trading_table.add_row("Avg Loss", f"${float(trade_log.avg_loss):,.2f}" if trade_log.avg_loss else "N/A")
|
||||
|
||||
console.print(trading_table)
|
||||
console.print()
|
||||
|
||||
summary_table = Table(title="Portfolio Summary", box=box.ROUNDED)
|
||||
summary_table.add_column("Metric", style="cyan")
|
||||
summary_table.add_column("Value", style="green")
|
||||
|
||||
summary_table.add_row("Start Equity", f"${float(metrics.start_equity):,.2f}")
|
||||
summary_table.add_row("End Equity", f"${float(metrics.end_equity):,.2f}")
|
||||
summary_table.add_row("Trading Days", str(metrics.trading_days))
|
||||
summary_table.add_row("Duration", f"{result.duration_seconds:.1f} seconds")
|
||||
|
||||
console.print(summary_table)
|
||||
console.print()
|
||||
|
||||
console.print(f"[green]Backtest completed successfully![/green]")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
choice = show_main_menu()
|
||||
if choice == "analyze":
|
||||
|
|
|
|||
|
|
@ -0,0 +1,318 @@
|
|||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.models.backtest import (
|
||||
BacktestConfig,
|
||||
BacktestResult,
|
||||
BacktestMetrics,
|
||||
EquityCurvePoint,
|
||||
TradeLog,
|
||||
)
|
||||
from tradingagents.models.portfolio import PortfolioConfig
|
||||
from tradingagents.models.trading import Trade, OrderSide
|
||||
|
||||
|
||||
class TestBacktestConfig:
|
||||
def test_basic_config(self):
|
||||
config = BacktestConfig(
|
||||
tickers=["AAPL"],
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 6, 30),
|
||||
)
|
||||
assert config.tickers == ["AAPL"]
|
||||
assert config.interval == "1d"
|
||||
assert config.benchmark_ticker == "SPY"
|
||||
|
||||
def test_multi_ticker_config(self):
|
||||
config = BacktestConfig(
|
||||
name="Multi-Stock Test",
|
||||
tickers=["aapl", "googl", "msft"],
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 12, 31),
|
||||
)
|
||||
assert config.tickers == ["AAPL", "GOOGL", "MSFT"]
|
||||
|
||||
def test_invalid_date_range(self):
|
||||
with pytest.raises(ValueError):
|
||||
BacktestConfig(
|
||||
tickers=["AAPL"],
|
||||
start_date=date(2024, 6, 30),
|
||||
end_date=date(2024, 1, 1),
|
||||
)
|
||||
|
||||
def test_same_start_end_date(self):
|
||||
with pytest.raises(ValueError):
|
||||
BacktestConfig(
|
||||
tickers=["AAPL"],
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 1, 1),
|
||||
)
|
||||
|
||||
def test_empty_tickers(self):
|
||||
with pytest.raises(ValueError):
|
||||
BacktestConfig(
|
||||
tickers=[],
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 6, 30),
|
||||
)
|
||||
|
||||
def test_trading_days_estimate(self):
|
||||
config = BacktestConfig(
|
||||
tickers=["AAPL"],
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 12, 31),
|
||||
)
|
||||
assert config.trading_days_estimate > 200
|
||||
assert config.trading_days_estimate < 260
|
||||
|
||||
def test_custom_portfolio_config(self):
|
||||
portfolio_config = PortfolioConfig(
|
||||
initial_cash=Decimal("50000"),
|
||||
commission_per_trade=Decimal("5"),
|
||||
)
|
||||
config = BacktestConfig(
|
||||
tickers=["AAPL"],
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 6, 30),
|
||||
portfolio_config=portfolio_config,
|
||||
)
|
||||
assert config.portfolio_config.initial_cash == Decimal("50000")
|
||||
|
||||
|
||||
class TestTradeLog:
|
||||
def test_empty_trade_log(self):
|
||||
log = TradeLog()
|
||||
assert log.total_trades == 0
|
||||
assert log.win_rate is None
|
||||
assert log.profit_factor is None
|
||||
|
||||
def test_add_winning_trade(self):
|
||||
log = TradeLog()
|
||||
trade = Trade(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
entry_price=Decimal("100"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15),
|
||||
exit_price=Decimal("110"),
|
||||
exit_quantity=100,
|
||||
exit_time=datetime(2024, 1, 20),
|
||||
)
|
||||
log.add_trade(trade)
|
||||
assert log.total_trades == 1
|
||||
assert log.winning_trades == 1
|
||||
assert log.win_rate == Decimal("100")
|
||||
|
||||
def test_add_losing_trade(self):
|
||||
log = TradeLog()
|
||||
trade = Trade(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
entry_price=Decimal("100"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15),
|
||||
exit_price=Decimal("90"),
|
||||
exit_quantity=100,
|
||||
exit_time=datetime(2024, 1, 20),
|
||||
)
|
||||
log.add_trade(trade)
|
||||
assert log.total_trades == 1
|
||||
assert log.losing_trades == 1
|
||||
assert log.win_rate == Decimal("0")
|
||||
|
||||
def test_mixed_trades(self):
|
||||
log = TradeLog()
|
||||
|
||||
win_trade = Trade(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
entry_price=Decimal("100"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15),
|
||||
exit_price=Decimal("120"),
|
||||
exit_quantity=100,
|
||||
exit_time=datetime(2024, 1, 20),
|
||||
)
|
||||
log.add_trade(win_trade)
|
||||
|
||||
loss_trade = Trade(
|
||||
ticker="GOOGL",
|
||||
side=OrderSide.BUY,
|
||||
entry_price=Decimal("100"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15),
|
||||
exit_price=Decimal("90"),
|
||||
exit_quantity=100,
|
||||
exit_time=datetime(2024, 1, 20),
|
||||
)
|
||||
log.add_trade(loss_trade)
|
||||
|
||||
assert log.total_trades == 2
|
||||
assert log.winning_trades == 1
|
||||
assert log.losing_trades == 1
|
||||
assert log.win_rate == Decimal("50")
|
||||
|
||||
def test_gross_profit_loss(self):
|
||||
log = TradeLog()
|
||||
|
||||
win_trade = Trade(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
entry_price=Decimal("100"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15),
|
||||
exit_price=Decimal("120"),
|
||||
exit_quantity=100,
|
||||
exit_time=datetime(2024, 1, 20),
|
||||
)
|
||||
log.add_trade(win_trade)
|
||||
|
||||
loss_trade = Trade(
|
||||
ticker="GOOGL",
|
||||
side=OrderSide.BUY,
|
||||
entry_price=Decimal("100"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15),
|
||||
exit_price=Decimal("90"),
|
||||
exit_quantity=100,
|
||||
exit_time=datetime(2024, 1, 20),
|
||||
)
|
||||
log.add_trade(loss_trade)
|
||||
|
||||
assert log.gross_profit == Decimal("2000")
|
||||
assert log.gross_loss == Decimal("1000")
|
||||
assert log.profit_factor == Decimal("2")
|
||||
|
||||
def test_avg_win_loss(self):
|
||||
log = TradeLog()
|
||||
|
||||
for i in range(3):
|
||||
trade = Trade(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
entry_price=Decimal("100"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15),
|
||||
exit_price=Decimal("110"),
|
||||
exit_quantity=100,
|
||||
exit_time=datetime(2024, 1, 20),
|
||||
)
|
||||
log.add_trade(trade)
|
||||
|
||||
assert log.avg_win == Decimal("1000")
|
||||
|
||||
|
||||
class TestBacktestMetrics:
|
||||
def test_basic_metrics(self):
|
||||
metrics = BacktestMetrics(
|
||||
total_return=Decimal("10000"),
|
||||
total_return_percent=Decimal("10"),
|
||||
annualized_return=Decimal("15"),
|
||||
max_drawdown=Decimal("5000"),
|
||||
max_drawdown_percent=Decimal("5"),
|
||||
sharpe_ratio=Decimal("1.5"),
|
||||
total_trades=50,
|
||||
win_rate=Decimal("60"),
|
||||
start_equity=Decimal("100000"),
|
||||
end_equity=Decimal("110000"),
|
||||
)
|
||||
assert metrics.total_return_percent == Decimal("10")
|
||||
assert metrics.sharpe_ratio == Decimal("1.5")
|
||||
|
||||
def test_to_summary_dict(self):
|
||||
metrics = BacktestMetrics(
|
||||
total_return=Decimal("10000"),
|
||||
total_return_percent=Decimal("10"),
|
||||
annualized_return=Decimal("15"),
|
||||
max_drawdown=Decimal("5000"),
|
||||
max_drawdown_percent=Decimal("5"),
|
||||
sharpe_ratio=Decimal("1.5"),
|
||||
sortino_ratio=Decimal("2.0"),
|
||||
volatility=Decimal("10"),
|
||||
annualized_volatility=Decimal("15"),
|
||||
total_trades=50,
|
||||
win_rate=Decimal("60"),
|
||||
profit_factor=Decimal("1.8"),
|
||||
total_commission=Decimal("500"),
|
||||
total_slippage=Decimal("200"),
|
||||
start_equity=Decimal("100000"),
|
||||
end_equity=Decimal("110000"),
|
||||
)
|
||||
summary = metrics.to_summary_dict()
|
||||
|
||||
assert "Performance" in summary
|
||||
assert "Risk" in summary
|
||||
assert "Trading" in summary
|
||||
assert "Costs" in summary
|
||||
assert summary["Performance"]["Sharpe Ratio"] == "1.50"
|
||||
|
||||
|
||||
class TestEquityCurvePoint:
|
||||
def test_equity_point(self):
|
||||
point = EquityCurvePoint(
|
||||
timestamp=datetime(2024, 1, 15),
|
||||
equity=Decimal("105000"),
|
||||
cash=Decimal("50000"),
|
||||
positions_value=Decimal("55000"),
|
||||
drawdown=Decimal("2000"),
|
||||
drawdown_percent=Decimal("1.9"),
|
||||
)
|
||||
assert point.equity == Decimal("105000")
|
||||
assert point.drawdown_percent == Decimal("1.9")
|
||||
|
||||
|
||||
class TestBacktestResult:
|
||||
def test_backtest_result(self):
|
||||
config = BacktestConfig(
|
||||
tickers=["AAPL"],
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 6, 30),
|
||||
)
|
||||
metrics = BacktestMetrics(
|
||||
total_return=Decimal("10000"),
|
||||
total_return_percent=Decimal("10"),
|
||||
start_equity=Decimal("100000"),
|
||||
end_equity=Decimal("110000"),
|
||||
)
|
||||
trade_log = TradeLog(total_trades=10, winning_trades=6, losing_trades=4)
|
||||
|
||||
result = BacktestResult(
|
||||
config=config,
|
||||
metrics=metrics,
|
||||
trade_log=trade_log,
|
||||
started_at=datetime(2024, 7, 1, 10, 0, 0),
|
||||
completed_at=datetime(2024, 7, 1, 10, 5, 30),
|
||||
)
|
||||
|
||||
assert result.duration_seconds == 330.0
|
||||
assert result.status == "completed"
|
||||
|
||||
def test_to_dict(self):
|
||||
config = BacktestConfig(
|
||||
name="Test Backtest",
|
||||
tickers=["AAPL"],
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 6, 30),
|
||||
)
|
||||
metrics = BacktestMetrics(
|
||||
total_return=Decimal("10000"),
|
||||
total_return_percent=Decimal("10"),
|
||||
start_equity=Decimal("100000"),
|
||||
end_equity=Decimal("110000"),
|
||||
)
|
||||
trade_log = TradeLog(total_trades=10, winning_trades=6, losing_trades=4)
|
||||
|
||||
result = BacktestResult(
|
||||
config=config,
|
||||
metrics=metrics,
|
||||
trade_log=trade_log,
|
||||
started_at=datetime(2024, 7, 1, 10, 0, 0),
|
||||
completed_at=datetime(2024, 7, 1, 10, 5, 0),
|
||||
)
|
||||
|
||||
result_dict = result.to_dict()
|
||||
assert result_dict["config"]["name"] == "Test Backtest"
|
||||
assert result_dict["trade_summary"]["total_trades"] == 10
|
||||
assert result_dict["trade_summary"]["win_rate"] == 60.0
|
||||
|
|
@ -0,0 +1,216 @@
|
|||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.models.market_data import (
|
||||
OHLCVBar,
|
||||
OHLCV,
|
||||
TechnicalIndicators,
|
||||
MarketSnapshot,
|
||||
HistoricalDataRequest,
|
||||
HistoricalDataResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestOHLCVBar:
|
||||
def test_valid_bar(self):
|
||||
bar = OHLCVBar(
|
||||
timestamp=datetime(2024, 1, 15, 9, 30),
|
||||
open=Decimal("100.00"),
|
||||
high=Decimal("105.00"),
|
||||
low=Decimal("99.00"),
|
||||
close=Decimal("103.50"),
|
||||
volume=1000000,
|
||||
)
|
||||
assert bar.open == Decimal("100.00")
|
||||
assert bar.high == Decimal("105.00")
|
||||
assert bar.volume == 1000000
|
||||
|
||||
def test_bar_with_adjusted_close(self):
|
||||
bar = OHLCVBar(
|
||||
timestamp=datetime(2024, 1, 15),
|
||||
open=Decimal("100"),
|
||||
high=Decimal("105"),
|
||||
low=Decimal("99"),
|
||||
close=Decimal("103"),
|
||||
volume=1000000,
|
||||
adjusted_close=Decimal("102.50"),
|
||||
)
|
||||
assert bar.adjusted_close == Decimal("102.50")
|
||||
|
||||
def test_invalid_negative_price(self):
|
||||
with pytest.raises(ValueError):
|
||||
OHLCVBar(
|
||||
timestamp=datetime(2024, 1, 15),
|
||||
open=Decimal("-100"),
|
||||
high=Decimal("105"),
|
||||
low=Decimal("99"),
|
||||
close=Decimal("103"),
|
||||
volume=1000000,
|
||||
)
|
||||
|
||||
def test_invalid_negative_volume(self):
|
||||
with pytest.raises(ValueError):
|
||||
OHLCVBar(
|
||||
timestamp=datetime(2024, 1, 15),
|
||||
open=Decimal("100"),
|
||||
high=Decimal("105"),
|
||||
low=Decimal("99"),
|
||||
close=Decimal("103"),
|
||||
volume=-1000,
|
||||
)
|
||||
|
||||
|
||||
class TestOHLCV:
|
||||
@pytest.fixture
|
||||
def sample_bars(self):
|
||||
return [
|
||||
OHLCVBar(
|
||||
timestamp=datetime(2024, 1, 15),
|
||||
open=Decimal("100"),
|
||||
high=Decimal("105"),
|
||||
low=Decimal("99"),
|
||||
close=Decimal("103"),
|
||||
volume=1000000,
|
||||
),
|
||||
OHLCVBar(
|
||||
timestamp=datetime(2024, 1, 16),
|
||||
open=Decimal("103"),
|
||||
high=Decimal("108"),
|
||||
low=Decimal("102"),
|
||||
close=Decimal("107"),
|
||||
volume=1200000,
|
||||
),
|
||||
OHLCVBar(
|
||||
timestamp=datetime(2024, 1, 17),
|
||||
open=Decimal("107"),
|
||||
high=Decimal("110"),
|
||||
low=Decimal("105"),
|
||||
close=Decimal("109"),
|
||||
volume=900000,
|
||||
),
|
||||
]
|
||||
|
||||
def test_ohlcv_creation(self, sample_bars):
|
||||
ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
|
||||
assert ohlcv.ticker == "AAPL"
|
||||
assert len(ohlcv.bars) == 3
|
||||
assert ohlcv.interval == "1d"
|
||||
assert ohlcv.currency == "USD"
|
||||
|
||||
def test_start_end_dates(self, sample_bars):
|
||||
ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
|
||||
assert ohlcv.start_date == datetime(2024, 1, 15)
|
||||
assert ohlcv.end_date == datetime(2024, 1, 17)
|
||||
|
||||
def test_empty_ohlcv(self):
|
||||
ohlcv = OHLCV(ticker="AAPL", bars=[])
|
||||
assert ohlcv.start_date is None
|
||||
assert ohlcv.end_date is None
|
||||
|
||||
def test_get_bar(self, sample_bars):
|
||||
ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
|
||||
bar = ohlcv.get_bar(datetime(2024, 1, 16))
|
||||
assert bar is not None
|
||||
assert bar.close == Decimal("107")
|
||||
|
||||
def test_get_bar_not_found(self, sample_bars):
|
||||
ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
|
||||
bar = ohlcv.get_bar(datetime(2024, 1, 20))
|
||||
assert bar is None
|
||||
|
||||
def test_slice(self, sample_bars):
|
||||
ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
|
||||
sliced = ohlcv.slice(datetime(2024, 1, 15), datetime(2024, 1, 16))
|
||||
assert len(sliced.bars) == 2
|
||||
assert sliced.ticker == "AAPL"
|
||||
|
||||
def test_invalid_ticker(self):
|
||||
with pytest.raises(ValueError):
|
||||
OHLCV(ticker="", bars=[])
|
||||
|
||||
|
||||
class TestTechnicalIndicators:
|
||||
def test_valid_indicators(self):
|
||||
indicators = TechnicalIndicators(
|
||||
timestamp=datetime(2024, 1, 15),
|
||||
ticker="AAPL",
|
||||
sma_50=Decimal("150.00"),
|
||||
rsi_14=Decimal("65.5"),
|
||||
macd=Decimal("2.5"),
|
||||
)
|
||||
assert indicators.sma_50 == Decimal("150.00")
|
||||
assert indicators.rsi_14 == Decimal("65.5")
|
||||
|
||||
def test_rsi_bounds(self):
|
||||
with pytest.raises(ValueError):
|
||||
TechnicalIndicators(
|
||||
timestamp=datetime(2024, 1, 15),
|
||||
ticker="AAPL",
|
||||
rsi_14=Decimal("150"),
|
||||
)
|
||||
|
||||
def test_mfi_bounds(self):
|
||||
with pytest.raises(ValueError):
|
||||
TechnicalIndicators(
|
||||
timestamp=datetime(2024, 1, 15),
|
||||
ticker="AAPL",
|
||||
mfi_14=Decimal("-10"),
|
||||
)
|
||||
|
||||
|
||||
class TestMarketSnapshot:
|
||||
def test_snapshot_change_calculation(self):
|
||||
bar = OHLCVBar(
|
||||
timestamp=datetime(2024, 1, 15),
|
||||
open=Decimal("100"),
|
||||
high=Decimal("105"),
|
||||
low=Decimal("99"),
|
||||
close=Decimal("103"),
|
||||
volume=1000000,
|
||||
)
|
||||
snapshot = MarketSnapshot(
|
||||
ticker="AAPL",
|
||||
timestamp=datetime(2024, 1, 15),
|
||||
bar=bar,
|
||||
prev_close=Decimal("100"),
|
||||
)
|
||||
assert snapshot.change == Decimal("3")
|
||||
assert snapshot.change_percent == Decimal("3")
|
||||
|
||||
def test_snapshot_no_prev_close(self):
|
||||
bar = OHLCVBar(
|
||||
timestamp=datetime(2024, 1, 15),
|
||||
open=Decimal("100"),
|
||||
high=Decimal("105"),
|
||||
low=Decimal("99"),
|
||||
close=Decimal("103"),
|
||||
volume=1000000,
|
||||
)
|
||||
snapshot = MarketSnapshot(
|
||||
ticker="AAPL",
|
||||
timestamp=datetime(2024, 1, 15),
|
||||
bar=bar,
|
||||
)
|
||||
assert snapshot.change is None
|
||||
assert snapshot.change_percent is None
|
||||
|
||||
|
||||
class TestHistoricalDataRequest:
|
||||
def test_valid_request(self):
|
||||
request = HistoricalDataRequest(
|
||||
ticker="AAPL",
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 6, 30),
|
||||
)
|
||||
assert request.ticker == "AAPL"
|
||||
assert request.include_indicators is True
|
||||
|
||||
def test_invalid_date_range(self):
|
||||
with pytest.raises(ValueError):
|
||||
HistoricalDataRequest(
|
||||
ticker="AAPL",
|
||||
start_date=date(2024, 6, 30),
|
||||
end_date=date(2024, 1, 1),
|
||||
)
|
||||
|
|
@ -0,0 +1,207 @@
|
|||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.models.portfolio import (
|
||||
PortfolioConfig,
|
||||
PortfolioSnapshot,
|
||||
CashTransaction,
|
||||
TransactionType,
|
||||
)
|
||||
from tradingagents.models.trading import OrderSide, Fill, Position
|
||||
|
||||
|
||||
class TestPortfolioConfig:
|
||||
def test_default_config(self):
|
||||
config = PortfolioConfig()
|
||||
assert config.initial_cash == Decimal("100000")
|
||||
assert config.commission_per_share == Decimal("0")
|
||||
assert config.slippage_percent == Decimal("0")
|
||||
|
||||
def test_custom_config(self):
|
||||
config = PortfolioConfig(
|
||||
initial_cash=Decimal("50000"),
|
||||
commission_per_trade=Decimal("5.00"),
|
||||
slippage_percent=Decimal("0.1"),
|
||||
)
|
||||
assert config.initial_cash == Decimal("50000")
|
||||
assert config.commission_per_trade == Decimal("5.00")
|
||||
|
||||
def test_calculate_commission_flat(self):
|
||||
config = PortfolioConfig(commission_per_trade=Decimal("5.00"))
|
||||
commission = config.calculate_commission(100, Decimal("150.00"))
|
||||
assert commission == Decimal("5.00")
|
||||
|
||||
def test_calculate_commission_per_share(self):
|
||||
config = PortfolioConfig(commission_per_share=Decimal("0.01"))
|
||||
commission = config.calculate_commission(100, Decimal("150.00"))
|
||||
assert commission == Decimal("1.00")
|
||||
|
||||
def test_calculate_commission_percent(self):
|
||||
config = PortfolioConfig(commission_percent=Decimal("0.1"))
|
||||
commission = config.calculate_commission(100, Decimal("100.00"))
|
||||
assert commission == Decimal("10.00")
|
||||
|
||||
def test_calculate_commission_minimum(self):
|
||||
config = PortfolioConfig(
|
||||
commission_per_trade=Decimal("1.00"),
|
||||
min_commission=Decimal("5.00"),
|
||||
)
|
||||
commission = config.calculate_commission(10, Decimal("10.00"))
|
||||
assert commission == Decimal("5.00")
|
||||
|
||||
def test_calculate_commission_maximum(self):
|
||||
config = PortfolioConfig(
|
||||
commission_percent=Decimal("1"),
|
||||
max_commission=Decimal("50.00"),
|
||||
)
|
||||
commission = config.calculate_commission(1000, Decimal("100.00"))
|
||||
assert commission == Decimal("50.00")
|
||||
|
||||
def test_calculate_slippage_buy(self):
|
||||
config = PortfolioConfig(slippage_percent=Decimal("0.1"))
|
||||
price = config.calculate_slippage(Decimal("100.00"), OrderSide.BUY)
|
||||
assert price == Decimal("100.10")
|
||||
|
||||
def test_calculate_slippage_sell(self):
|
||||
config = PortfolioConfig(slippage_percent=Decimal("0.1"))
|
||||
price = config.calculate_slippage(Decimal("100.00"), OrderSide.SELL)
|
||||
assert price == Decimal("99.90")
|
||||
|
||||
|
||||
class TestPortfolioSnapshot:
|
||||
def test_new_portfolio(self):
|
||||
portfolio = PortfolioSnapshot(cash=Decimal("100000"))
|
||||
assert portfolio.cash == Decimal("100000")
|
||||
assert portfolio.position_count == 0
|
||||
assert len(portfolio.positions) == 0
|
||||
|
||||
def test_get_position_creates_new(self):
|
||||
portfolio = PortfolioSnapshot(cash=Decimal("100000"))
|
||||
position = portfolio.get_position("AAPL")
|
||||
assert position.ticker == "AAPL"
|
||||
assert position.quantity == 0
|
||||
assert "AAPL" in portfolio.positions
|
||||
|
||||
def test_positions_value(self):
|
||||
portfolio = PortfolioSnapshot(
|
||||
cash=Decimal("50000"),
|
||||
positions={
|
||||
"AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")),
|
||||
"GOOGL": Position(ticker="GOOGL", quantity=50, avg_cost=Decimal("100")),
|
||||
},
|
||||
)
|
||||
prices = {"AAPL": Decimal("160"), "GOOGL": Decimal("110")}
|
||||
assert portfolio.positions_value(prices) == Decimal("21500")
|
||||
|
||||
def test_total_equity(self):
|
||||
portfolio = PortfolioSnapshot(
|
||||
cash=Decimal("50000"),
|
||||
positions={
|
||||
"AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")),
|
||||
},
|
||||
)
|
||||
prices = {"AAPL": Decimal("160")}
|
||||
assert portfolio.total_equity(prices) == Decimal("66000")
|
||||
|
||||
def test_total_unrealized_pnl(self):
|
||||
portfolio = PortfolioSnapshot(
|
||||
cash=Decimal("50000"),
|
||||
positions={
|
||||
"AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")),
|
||||
},
|
||||
)
|
||||
prices = {"AAPL": Decimal("160")}
|
||||
assert portfolio.total_unrealized_pnl(prices) == Decimal("1000")
|
||||
|
||||
def test_apply_buy_fill(self):
|
||||
portfolio = PortfolioSnapshot(cash=Decimal("100000"))
|
||||
fill = Fill(
|
||||
order_id=uuid4(),
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
quantity=100,
|
||||
price=Decimal("150.00"),
|
||||
commission=Decimal("5.00"),
|
||||
)
|
||||
portfolio.apply_fill(fill)
|
||||
assert portfolio.cash == Decimal("84995.00")
|
||||
assert portfolio.positions["AAPL"].quantity == 100
|
||||
assert portfolio.total_commission_paid == Decimal("5.00")
|
||||
|
||||
def test_apply_sell_fill_with_profit(self):
|
||||
portfolio = PortfolioSnapshot(
|
||||
cash=Decimal("50000"),
|
||||
positions={
|
||||
"AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")),
|
||||
},
|
||||
)
|
||||
fill = Fill(
|
||||
order_id=uuid4(),
|
||||
ticker="AAPL",
|
||||
side=OrderSide.SELL,
|
||||
quantity=100,
|
||||
price=Decimal("160.00"),
|
||||
commission=Decimal("5.00"),
|
||||
)
|
||||
portfolio.apply_fill(fill)
|
||||
assert portfolio.cash == Decimal("65995.00")
|
||||
assert portfolio.positions["AAPL"].quantity == 0
|
||||
assert portfolio.total_realized_pnl == Decimal("1000.00")
|
||||
|
||||
def test_add_deposit(self):
|
||||
portfolio = PortfolioSnapshot(cash=Decimal("100000"))
|
||||
transaction = CashTransaction(
|
||||
transaction_type=TransactionType.DEPOSIT,
|
||||
amount=Decimal("10000"),
|
||||
)
|
||||
portfolio.add_cash_transaction(transaction)
|
||||
assert portfolio.cash == Decimal("110000")
|
||||
assert len(portfolio.cash_transactions) == 1
|
||||
|
||||
def test_add_withdrawal(self):
|
||||
portfolio = PortfolioSnapshot(cash=Decimal("100000"))
|
||||
transaction = CashTransaction(
|
||||
transaction_type=TransactionType.WITHDRAWAL,
|
||||
amount=Decimal("10000"),
|
||||
)
|
||||
portfolio.add_cash_transaction(transaction)
|
||||
assert portfolio.cash == Decimal("90000")
|
||||
|
||||
def test_can_afford(self):
|
||||
portfolio = PortfolioSnapshot(cash=Decimal("10000"))
|
||||
config = PortfolioConfig(commission_per_trade=Decimal("5"))
|
||||
|
||||
assert portfolio.can_afford("AAPL", 10, Decimal("100"), config)
|
||||
assert not portfolio.can_afford("AAPL", 100, Decimal("100"), config)
|
||||
|
||||
def test_max_shares_affordable(self):
|
||||
portfolio = PortfolioSnapshot(cash=Decimal("10000"))
|
||||
config = PortfolioConfig(commission_per_trade=Decimal("0"))
|
||||
|
||||
max_shares = portfolio.max_shares_affordable("AAPL", Decimal("100"), config)
|
||||
assert max_shares == 100
|
||||
|
||||
def test_max_shares_affordable_with_commission(self):
|
||||
portfolio = PortfolioSnapshot(cash=Decimal("10050"))
|
||||
config = PortfolioConfig(commission_per_trade=Decimal("50"))
|
||||
|
||||
max_shares = portfolio.max_shares_affordable("AAPL", Decimal("100"), config)
|
||||
assert max_shares == 100
|
||||
|
||||
def test_to_dict(self):
|
||||
portfolio = PortfolioSnapshot(
|
||||
cash=Decimal("50000"),
|
||||
positions={
|
||||
"AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")),
|
||||
},
|
||||
)
|
||||
prices = {"AAPL": Decimal("160")}
|
||||
result = portfolio.to_dict(prices)
|
||||
|
||||
assert result["cash"] == 50000.0
|
||||
assert result["positions_value"] == 16000.0
|
||||
assert result["total_equity"] == 66000.0
|
||||
assert result["position_count"] == 1
|
||||
|
|
@ -0,0 +1,297 @@
|
|||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.models.trading import (
|
||||
OrderSide,
|
||||
OrderType,
|
||||
OrderStatus,
|
||||
PositionSide,
|
||||
Order,
|
||||
Fill,
|
||||
Position,
|
||||
Trade,
|
||||
)
|
||||
|
||||
|
||||
class TestOrder:
|
||||
def test_market_order_creation(self):
|
||||
order = Order(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
quantity=100,
|
||||
)
|
||||
assert order.ticker == "AAPL"
|
||||
assert order.side == OrderSide.BUY
|
||||
assert order.order_type == OrderType.MARKET
|
||||
assert order.quantity == 100
|
||||
assert order.status == OrderStatus.PENDING
|
||||
assert order.remaining_quantity == 100
|
||||
assert not order.is_complete
|
||||
|
||||
def test_limit_order_creation(self):
|
||||
order = Order(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.SELL,
|
||||
order_type=OrderType.LIMIT,
|
||||
quantity=50,
|
||||
limit_price=Decimal("150.00"),
|
||||
)
|
||||
assert order.order_type == OrderType.LIMIT
|
||||
assert order.limit_price == Decimal("150.00")
|
||||
|
||||
def test_order_partial_fill(self):
|
||||
order = Order(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
quantity=100,
|
||||
status=OrderStatus.PARTIAL,
|
||||
filled_quantity=30,
|
||||
)
|
||||
assert order.remaining_quantity == 70
|
||||
assert not order.is_complete
|
||||
|
||||
def test_order_complete_states(self):
|
||||
for status in [OrderStatus.FILLED, OrderStatus.CANCELLED, OrderStatus.REJECTED]:
|
||||
order = Order(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
quantity=100,
|
||||
status=status,
|
||||
)
|
||||
assert order.is_complete
|
||||
|
||||
def test_invalid_quantity(self):
|
||||
with pytest.raises(ValueError):
|
||||
Order(ticker="AAPL", side=OrderSide.BUY, quantity=0)
|
||||
|
||||
def test_invalid_limit_price(self):
|
||||
with pytest.raises(ValueError):
|
||||
Order(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
order_type=OrderType.LIMIT,
|
||||
quantity=100,
|
||||
limit_price=Decimal("-10"),
|
||||
)
|
||||
|
||||
|
||||
class TestFill:
|
||||
def test_buy_fill(self):
|
||||
order_id = uuid4()
|
||||
fill = Fill(
|
||||
order_id=order_id,
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
quantity=100,
|
||||
price=Decimal("150.00"),
|
||||
commission=Decimal("1.00"),
|
||||
)
|
||||
assert fill.total_value == Decimal("15000.00")
|
||||
assert fill.total_cost == Decimal("15001.00")
|
||||
|
||||
def test_sell_fill(self):
|
||||
order_id = uuid4()
|
||||
fill = Fill(
|
||||
order_id=order_id,
|
||||
ticker="AAPL",
|
||||
side=OrderSide.SELL,
|
||||
quantity=100,
|
||||
price=Decimal("150.00"),
|
||||
commission=Decimal("1.00"),
|
||||
)
|
||||
assert fill.total_value == Decimal("15000.00")
|
||||
assert fill.total_cost == Decimal("14999.00")
|
||||
|
||||
|
||||
class TestPosition:
|
||||
def test_new_position(self):
|
||||
position = Position(ticker="AAPL")
|
||||
assert position.quantity == 0
|
||||
assert position.side == PositionSide.FLAT
|
||||
assert position.cost_basis == Decimal("0")
|
||||
|
||||
def test_long_position(self):
|
||||
position = Position(
|
||||
ticker="AAPL",
|
||||
quantity=100,
|
||||
avg_cost=Decimal("150.00"),
|
||||
)
|
||||
assert position.side == PositionSide.LONG
|
||||
assert position.cost_basis == Decimal("15000.00")
|
||||
|
||||
def test_short_position(self):
|
||||
position = Position(
|
||||
ticker="AAPL",
|
||||
quantity=-100,
|
||||
avg_cost=Decimal("150.00"),
|
||||
)
|
||||
assert position.side == PositionSide.SHORT
|
||||
assert position.cost_basis == Decimal("15000.00")
|
||||
|
||||
def test_unrealized_pnl_long(self):
|
||||
position = Position(
|
||||
ticker="AAPL",
|
||||
quantity=100,
|
||||
avg_cost=Decimal("150.00"),
|
||||
)
|
||||
pnl = position.unrealized_pnl(Decimal("160.00"))
|
||||
assert pnl == Decimal("1000.00")
|
||||
|
||||
def test_unrealized_pnl_short(self):
|
||||
position = Position(
|
||||
ticker="AAPL",
|
||||
quantity=-100,
|
||||
avg_cost=Decimal("150.00"),
|
||||
)
|
||||
pnl = position.unrealized_pnl(Decimal("140.00"))
|
||||
assert pnl == Decimal("1000.00")
|
||||
|
||||
def test_update_from_buy_fill_new_position(self):
|
||||
position = Position(ticker="AAPL")
|
||||
fill = Fill(
|
||||
order_id=uuid4(),
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
quantity=100,
|
||||
price=Decimal("150.00"),
|
||||
)
|
||||
position.update_from_fill(fill)
|
||||
assert position.quantity == 100
|
||||
assert position.avg_cost == Decimal("150.00")
|
||||
|
||||
def test_update_from_buy_fill_add_to_position(self):
|
||||
position = Position(
|
||||
ticker="AAPL",
|
||||
quantity=100,
|
||||
avg_cost=Decimal("150.00"),
|
||||
)
|
||||
fill = Fill(
|
||||
order_id=uuid4(),
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
quantity=100,
|
||||
price=Decimal("160.00"),
|
||||
)
|
||||
position.update_from_fill(fill)
|
||||
assert position.quantity == 200
|
||||
assert position.avg_cost == Decimal("155.00")
|
||||
|
||||
def test_update_from_sell_fill_close_position(self):
|
||||
position = Position(
|
||||
ticker="AAPL",
|
||||
quantity=100,
|
||||
avg_cost=Decimal("150.00"),
|
||||
)
|
||||
fill = Fill(
|
||||
order_id=uuid4(),
|
||||
ticker="AAPL",
|
||||
side=OrderSide.SELL,
|
||||
quantity=100,
|
||||
price=Decimal("160.00"),
|
||||
)
|
||||
position.update_from_fill(fill)
|
||||
assert position.quantity == 0
|
||||
assert position.realized_pnl == Decimal("1000.00")
|
||||
|
||||
def test_update_from_sell_fill_partial_close(self):
|
||||
position = Position(
|
||||
ticker="AAPL",
|
||||
quantity=100,
|
||||
avg_cost=Decimal("150.00"),
|
||||
)
|
||||
fill = Fill(
|
||||
order_id=uuid4(),
|
||||
ticker="AAPL",
|
||||
side=OrderSide.SELL,
|
||||
quantity=50,
|
||||
price=Decimal("160.00"),
|
||||
)
|
||||
position.update_from_fill(fill)
|
||||
assert position.quantity == 50
|
||||
assert position.realized_pnl == Decimal("500.00")
|
||||
assert position.avg_cost == Decimal("150.00")
|
||||
|
||||
|
||||
class TestTrade:
|
||||
def test_open_trade(self):
|
||||
trade = Trade(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
entry_price=Decimal("150.00"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15, 9, 30),
|
||||
)
|
||||
assert not trade.is_closed
|
||||
assert trade.pnl is None
|
||||
assert trade.holding_period is None
|
||||
|
||||
def test_closed_trade_profit(self):
|
||||
trade = Trade(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
entry_price=Decimal("150.00"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15),
|
||||
exit_price=Decimal("160.00"),
|
||||
exit_quantity=100,
|
||||
exit_time=datetime(2024, 1, 25),
|
||||
)
|
||||
assert trade.is_closed
|
||||
assert trade.pnl == Decimal("1000.00")
|
||||
assert trade.holding_period == 10
|
||||
|
||||
def test_closed_trade_loss(self):
|
||||
trade = Trade(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
entry_price=Decimal("150.00"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15),
|
||||
exit_price=Decimal("140.00"),
|
||||
exit_quantity=100,
|
||||
exit_time=datetime(2024, 1, 20),
|
||||
)
|
||||
assert trade.pnl == Decimal("-1000.00")
|
||||
|
||||
def test_trade_with_commission(self):
|
||||
trade = Trade(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
entry_price=Decimal("150.00"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15),
|
||||
exit_price=Decimal("160.00"),
|
||||
exit_quantity=100,
|
||||
exit_time=datetime(2024, 1, 20),
|
||||
commission=Decimal("10.00"),
|
||||
)
|
||||
assert trade.pnl == Decimal("990.00")
|
||||
|
||||
def test_short_trade_profit(self):
|
||||
trade = Trade(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.SELL,
|
||||
entry_price=Decimal("150.00"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15),
|
||||
exit_price=Decimal("140.00"),
|
||||
exit_quantity=100,
|
||||
exit_time=datetime(2024, 1, 20),
|
||||
)
|
||||
assert trade.pnl == Decimal("1000.00")
|
||||
|
||||
def test_pnl_percent(self):
|
||||
trade = Trade(
|
||||
ticker="AAPL",
|
||||
side=OrderSide.BUY,
|
||||
entry_price=Decimal("100.00"),
|
||||
entry_quantity=100,
|
||||
entry_time=datetime(2024, 1, 15),
|
||||
exit_price=Decimal("110.00"),
|
||||
exit_quantity=100,
|
||||
exit_time=datetime(2024, 1, 20),
|
||||
)
|
||||
assert trade.pnl_percent == Decimal("10")
|
||||
|
|
@ -10,7 +10,7 @@ from .analysts.social_media_analyst import create_social_media_analyst
|
|||
from .researchers.bear_researcher import create_bear_researcher
|
||||
from .researchers.bull_researcher import create_bull_researcher
|
||||
|
||||
from .risk_mgmt.aggresive_debator import create_risky_debator
|
||||
from .risk_mgmt.aggressive_debator import create_risky_debator
|
||||
from .risk_mgmt.conservative_debator import create_safe_debator
|
||||
from .risk_mgmt.neutral_debator import create_neutral_debator
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
from .data_loader import DataLoader
|
||||
from .engine import BacktestEngine, SimpleBacktestEngine
|
||||
from .metrics import MetricsCalculator
|
||||
from .agent_integration import AgentBacktestEngine, run_agent_backtest
|
||||
|
||||
__all__ = [
|
||||
"DataLoader",
|
||||
"BacktestEngine",
|
||||
"SimpleBacktestEngine",
|
||||
"MetricsCalculator",
|
||||
"AgentBacktestEngine",
|
||||
"run_agent_backtest",
|
||||
]
|
||||
|
|
@ -0,0 +1,249 @@
|
|||
import logging
|
||||
import re
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.models.backtest import BacktestConfig, BacktestResult
|
||||
from tradingagents.models.decisions import (
|
||||
SignalType,
|
||||
TradingDecision,
|
||||
AnalystReport,
|
||||
AnalystType,
|
||||
)
|
||||
from tradingagents.models.portfolio import PortfolioSnapshot
|
||||
|
||||
from .engine import BacktestEngine
|
||||
from .data_loader import DataLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentBacktestEngine(BacktestEngine):
|
||||
def __init__(
|
||||
self,
|
||||
config: BacktestConfig,
|
||||
agent_config: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.agent_config = agent_config or config.agent_config
|
||||
self.trading_graph: Optional[TradingAgentsGraph] = None
|
||||
self._decision_cache: Dict[str, TradingDecision] = {}
|
||||
|
||||
def _initialize(self):
|
||||
super()._initialize()
|
||||
|
||||
graph_config = {
|
||||
**self.agent_config,
|
||||
}
|
||||
|
||||
self.trading_graph = TradingAgentsGraph(
|
||||
selected_analysts=self.agent_config.get(
|
||||
"selected_analysts",
|
||||
["market", "social", "news", "fundamentals"],
|
||||
),
|
||||
debug=self.agent_config.get("debug", False),
|
||||
config=graph_config if graph_config else None,
|
||||
)
|
||||
|
||||
def _get_decision(
|
||||
self,
|
||||
ticker: str,
|
||||
trading_date: date,
|
||||
day_index: int,
|
||||
) -> Optional[TradingDecision]:
|
||||
cache_key = f"{ticker}_{trading_date}"
|
||||
if cache_key in self._decision_cache:
|
||||
return self._decision_cache[cache_key]
|
||||
|
||||
try:
|
||||
final_state, signal_info = self.trading_graph.propagate(
|
||||
ticker, trading_date
|
||||
)
|
||||
|
||||
decision = self._parse_agent_decision(
|
||||
ticker, trading_date, final_state, signal_info
|
||||
)
|
||||
|
||||
self._decision_cache[cache_key] = decision
|
||||
return decision
|
||||
|
||||
except (ValueError, KeyError, RuntimeError, ConnectionError, TimeoutError) as e:
|
||||
logger.error(
|
||||
"Agent decision failed for %s on %s: %s",
|
||||
ticker, trading_date, e
|
||||
)
|
||||
return None
|
||||
|
||||
def _parse_agent_decision(
|
||||
self,
|
||||
ticker: str,
|
||||
trading_date: date,
|
||||
final_state: Dict[str, Any],
|
||||
signal_info: Dict[str, Any],
|
||||
) -> TradingDecision:
|
||||
signal = self._extract_signal(signal_info)
|
||||
confidence = self._extract_confidence(signal_info)
|
||||
|
||||
analyst_reports = []
|
||||
|
||||
if final_state.get("market_report"):
|
||||
analyst_reports.append(
|
||||
AnalystReport(
|
||||
analyst_type=AnalystType.MARKET,
|
||||
ticker=ticker,
|
||||
report_date=datetime.combine(trading_date, datetime.min.time()),
|
||||
summary=final_state["market_report"][:500],
|
||||
raw_content=final_state["market_report"],
|
||||
)
|
||||
)
|
||||
|
||||
if final_state.get("sentiment_report"):
|
||||
analyst_reports.append(
|
||||
AnalystReport(
|
||||
analyst_type=AnalystType.SENTIMENT,
|
||||
ticker=ticker,
|
||||
report_date=datetime.combine(trading_date, datetime.min.time()),
|
||||
summary=final_state["sentiment_report"][:500],
|
||||
raw_content=final_state["sentiment_report"],
|
||||
)
|
||||
)
|
||||
|
||||
if final_state.get("news_report"):
|
||||
analyst_reports.append(
|
||||
AnalystReport(
|
||||
analyst_type=AnalystType.NEWS,
|
||||
ticker=ticker,
|
||||
report_date=datetime.combine(trading_date, datetime.min.time()),
|
||||
summary=final_state["news_report"][:500],
|
||||
raw_content=final_state["news_report"],
|
||||
)
|
||||
)
|
||||
|
||||
if final_state.get("fundamentals_report"):
|
||||
analyst_reports.append(
|
||||
AnalystReport(
|
||||
analyst_type=AnalystType.FUNDAMENTALS,
|
||||
ticker=ticker,
|
||||
report_date=datetime.combine(trading_date, datetime.min.time()),
|
||||
summary=final_state["fundamentals_report"][:500],
|
||||
raw_content=final_state["fundamentals_report"],
|
||||
)
|
||||
)
|
||||
|
||||
debate_state = final_state.get("investment_debate_state", {})
|
||||
bull_argument = None
|
||||
bear_argument = None
|
||||
|
||||
if debate_state.get("bull_history"):
|
||||
bull_argument = debate_state["bull_history"][-1] if debate_state["bull_history"] else None
|
||||
if debate_state.get("bear_history"):
|
||||
bear_argument = debate_state["bear_history"][-1] if debate_state["bear_history"] else None
|
||||
|
||||
risk_state = final_state.get("risk_debate_state", {})
|
||||
risk_approved = self._extract_risk_approval(risk_state)
|
||||
|
||||
final_decision_text = final_state.get("final_trade_decision", "")
|
||||
recommended_action = self._extract_action(signal_info, final_decision_text)
|
||||
|
||||
return TradingDecision(
|
||||
ticker=ticker,
|
||||
timestamp=datetime.now(),
|
||||
decision_date=datetime.combine(trading_date, datetime.min.time()),
|
||||
signal=signal,
|
||||
confidence=confidence,
|
||||
recommended_action=recommended_action,
|
||||
analyst_reports=analyst_reports,
|
||||
bull_argument=bull_argument,
|
||||
bear_argument=bear_argument,
|
||||
debate_rounds=debate_state.get("count", 0),
|
||||
risk_manager_approved=risk_approved,
|
||||
final_decision=recommended_action,
|
||||
rationale=final_decision_text[:1000] if final_decision_text else "",
|
||||
)
|
||||
|
||||
def _extract_signal(self, signal_info: Dict[str, Any]) -> SignalType:
|
||||
action = signal_info.get("action", "").upper()
|
||||
direction = signal_info.get("direction", "").upper()
|
||||
|
||||
if action == "BUY" or direction == "BULLISH":
|
||||
confidence = signal_info.get("confidence", 0.5)
|
||||
if confidence > 0.8:
|
||||
return SignalType.STRONG_BUY
|
||||
return SignalType.BUY
|
||||
|
||||
elif action == "SELL" or direction == "BEARISH":
|
||||
confidence = signal_info.get("confidence", 0.5)
|
||||
if confidence > 0.8:
|
||||
return SignalType.STRONG_SELL
|
||||
return SignalType.SELL
|
||||
|
||||
return SignalType.HOLD
|
||||
|
||||
def _extract_confidence(self, signal_info: Dict[str, Any]) -> Decimal:
|
||||
confidence = signal_info.get("confidence", 0.5)
|
||||
if isinstance(confidence, str):
|
||||
try:
|
||||
confidence = float(confidence.replace("%", "")) / 100
|
||||
except ValueError:
|
||||
confidence = 0.5
|
||||
|
||||
return Decimal(str(min(max(float(confidence), 0.0), 1.0)))
|
||||
|
||||
def _extract_action(
|
||||
self,
|
||||
signal_info: Dict[str, Any],
|
||||
final_decision_text: str,
|
||||
) -> str:
|
||||
action = signal_info.get("action", "")
|
||||
if action:
|
||||
return action.upper()
|
||||
|
||||
text_upper = final_decision_text.upper()
|
||||
if "BUY" in text_upper and "DON'T BUY" not in text_upper:
|
||||
return "BUY"
|
||||
elif "SELL" in text_upper:
|
||||
return "SELL"
|
||||
|
||||
return "HOLD"
|
||||
|
||||
def _extract_risk_approval(self, risk_state: Dict[str, Any]) -> Optional[bool]:
|
||||
judge_decision = risk_state.get("judge_decision", "")
|
||||
if not judge_decision:
|
||||
return None
|
||||
|
||||
text_upper = judge_decision.upper()
|
||||
if "APPROVE" in text_upper or "ACCEPT" in text_upper:
|
||||
return True
|
||||
elif "REJECT" in text_upper or "DENY" in text_upper:
|
||||
return False
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def run_agent_backtest(
|
||||
tickers: list[str],
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
initial_cash: Decimal = Decimal("100000"),
|
||||
agent_config: Optional[Dict[str, Any]] = None,
|
||||
) -> BacktestResult:
|
||||
from tradingagents.models.portfolio import PortfolioConfig
|
||||
|
||||
config = BacktestConfig(
|
||||
name=f"Agent Backtest - {', '.join(tickers)}",
|
||||
tickers=tickers,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
portfolio_config=PortfolioConfig(
|
||||
initial_cash=initial_cash,
|
||||
commission_per_trade=Decimal("1"),
|
||||
slippage_percent=Decimal("0.05"),
|
||||
),
|
||||
warmup_period=5,
|
||||
agent_config=agent_config or {},
|
||||
)
|
||||
|
||||
engine = AgentBacktestEngine(config, agent_config)
|
||||
return engine.run()
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
import logging
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
import yfinance as yf
|
||||
from stockstats import wrap
|
||||
|
||||
from tradingagents.models.market_data import (
|
||||
OHLCV,
|
||||
OHLCVBar,
|
||||
TechnicalIndicators,
|
||||
HistoricalDataRequest,
|
||||
HistoricalDataResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataLoader:
|
||||
def __init__(self, cache_dir: Optional[str] = None):
|
||||
self.cache_dir = cache_dir
|
||||
self._cache: dict[str, pd.DataFrame] = {}
|
||||
|
||||
def load_ohlcv(
|
||||
self,
|
||||
ticker: str,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
interval: str = "1d",
|
||||
) -> OHLCV:
|
||||
ticker = ticker.upper()
|
||||
cache_key = f"{ticker}_{start_date}_{end_date}_{interval}"
|
||||
|
||||
if cache_key in self._cache:
|
||||
df = self._cache[cache_key]
|
||||
else:
|
||||
df = self._fetch_from_yfinance(ticker, start_date, end_date, interval)
|
||||
self._cache[cache_key] = df
|
||||
|
||||
bars = self._dataframe_to_bars(df)
|
||||
return OHLCV(ticker=ticker, bars=bars, interval=interval)
|
||||
|
||||
def load_historical_data(
|
||||
self,
|
||||
request: HistoricalDataRequest,
|
||||
) -> HistoricalDataResponse:
|
||||
ohlcv = self.load_ohlcv(
|
||||
request.ticker,
|
||||
request.start_date,
|
||||
request.end_date,
|
||||
request.interval,
|
||||
)
|
||||
|
||||
indicators = []
|
||||
if request.include_indicators and ohlcv.bars:
|
||||
indicators = self._calculate_indicators(
|
||||
request.ticker,
|
||||
request.start_date,
|
||||
request.end_date,
|
||||
)
|
||||
|
||||
return HistoricalDataResponse(
|
||||
request=request,
|
||||
ohlcv=ohlcv,
|
||||
indicators=indicators,
|
||||
source="yfinance",
|
||||
)
|
||||
|
||||
def _fetch_from_yfinance(
|
||||
self,
|
||||
ticker: str,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
interval: str,
|
||||
) -> pd.DataFrame:
|
||||
start_str = start_date.strftime("%Y-%m-%d")
|
||||
end_str = (end_date + timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
df = yf.download(
|
||||
ticker,
|
||||
start=start_str,
|
||||
end=end_str,
|
||||
interval=interval,
|
||||
multi_level_index=False,
|
||||
progress=False,
|
||||
auto_adjust=False,
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
logger.warning("No data returned for %s from %s to %s", ticker, start_date, end_date)
|
||||
return pd.DataFrame()
|
||||
|
||||
df = df.reset_index()
|
||||
return df
|
||||
|
||||
def _dataframe_to_bars(self, df: pd.DataFrame) -> list[OHLCVBar]:
|
||||
if df.empty:
|
||||
return []
|
||||
|
||||
bars = []
|
||||
for _, row in df.iterrows():
|
||||
timestamp = row.get("Date") or row.get("Datetime")
|
||||
if isinstance(timestamp, str):
|
||||
timestamp = pd.to_datetime(timestamp)
|
||||
if hasattr(timestamp, "to_pydatetime"):
|
||||
timestamp = timestamp.to_pydatetime()
|
||||
if timestamp.tzinfo is not None:
|
||||
timestamp = timestamp.replace(tzinfo=None)
|
||||
|
||||
bar = OHLCVBar(
|
||||
timestamp=timestamp,
|
||||
open=Decimal(str(round(row["Open"], 4))),
|
||||
high=Decimal(str(round(row["High"], 4))),
|
||||
low=Decimal(str(round(row["Low"], 4))),
|
||||
close=Decimal(str(round(row["Close"], 4))),
|
||||
volume=int(row["Volume"]),
|
||||
adjusted_close=Decimal(str(round(row["Adj Close"], 4))) if "Adj Close" in row else None,
|
||||
)
|
||||
bars.append(bar)
|
||||
|
||||
return bars
|
||||
|
||||
def _calculate_indicators(
|
||||
self,
|
||||
ticker: str,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
) -> list[TechnicalIndicators]:
|
||||
lookback_start = start_date - timedelta(days=250)
|
||||
cache_key = f"{ticker}_{lookback_start}_{end_date}_1d"
|
||||
|
||||
if cache_key in self._cache:
|
||||
df = self._cache[cache_key]
|
||||
else:
|
||||
df = self._fetch_from_yfinance(ticker, lookback_start, end_date, "1d")
|
||||
self._cache[cache_key] = df
|
||||
|
||||
if df.empty:
|
||||
return []
|
||||
|
||||
stock = wrap(df.copy())
|
||||
|
||||
stock["close_20_sma"]
|
||||
stock["close_50_sma"]
|
||||
stock["close_200_sma"]
|
||||
stock["close_10_ema"]
|
||||
stock["close_20_ema"]
|
||||
stock["rsi_14"]
|
||||
stock["macd"]
|
||||
stock["macds"]
|
||||
stock["macdh"]
|
||||
stock["boll"]
|
||||
stock["boll_ub"]
|
||||
stock["boll_lb"]
|
||||
stock["atr_14"]
|
||||
stock["mfi_14"]
|
||||
|
||||
indicators = []
|
||||
for _, row in stock.iterrows():
|
||||
timestamp = row.get("Date") or row.get("Datetime")
|
||||
if isinstance(timestamp, str):
|
||||
timestamp = pd.to_datetime(timestamp)
|
||||
if hasattr(timestamp, "to_pydatetime"):
|
||||
timestamp = timestamp.to_pydatetime()
|
||||
if timestamp.tzinfo is not None:
|
||||
timestamp = timestamp.replace(tzinfo=None)
|
||||
|
||||
if timestamp.date() < start_date or timestamp.date() > end_date:
|
||||
continue
|
||||
|
||||
ind = TechnicalIndicators(
|
||||
timestamp=timestamp,
|
||||
ticker=ticker,
|
||||
sma_20=self._safe_decimal(row.get("close_20_sma")),
|
||||
sma_50=self._safe_decimal(row.get("close_50_sma")),
|
||||
sma_200=self._safe_decimal(row.get("close_200_sma")),
|
||||
ema_10=self._safe_decimal(row.get("close_10_ema")),
|
||||
ema_20=self._safe_decimal(row.get("close_20_ema")),
|
||||
rsi_14=self._safe_decimal(row.get("rsi_14")),
|
||||
macd=self._safe_decimal(row.get("macd")),
|
||||
macd_signal=self._safe_decimal(row.get("macds")),
|
||||
macd_histogram=self._safe_decimal(row.get("macdh")),
|
||||
bollinger_middle=self._safe_decimal(row.get("boll")),
|
||||
bollinger_upper=self._safe_decimal(row.get("boll_ub")),
|
||||
bollinger_lower=self._safe_decimal(row.get("boll_lb")),
|
||||
atr_14=self._safe_decimal(row.get("atr_14")),
|
||||
mfi_14=self._safe_decimal(row.get("mfi_14")),
|
||||
)
|
||||
indicators.append(ind)
|
||||
|
||||
return indicators
|
||||
|
||||
@staticmethod
|
||||
def _safe_decimal(value) -> Optional[Decimal]:
|
||||
if value is None or pd.isna(value):
|
||||
return None
|
||||
return Decimal(str(round(float(value), 4)))
|
||||
|
||||
def get_price_on_date(
|
||||
self,
|
||||
ticker: str,
|
||||
target_date: date,
|
||||
ohlcv: Optional[OHLCV] = None,
|
||||
) -> Optional[Decimal]:
|
||||
if ohlcv is None:
|
||||
ohlcv = self.load_ohlcv(ticker, target_date - timedelta(days=5), target_date)
|
||||
|
||||
target_datetime = datetime.combine(target_date, datetime.min.time())
|
||||
bar = ohlcv.get_bar(target_datetime)
|
||||
|
||||
if bar:
|
||||
return bar.close
|
||||
|
||||
for b in reversed(ohlcv.bars):
|
||||
if b.timestamp.date() <= target_date:
|
||||
return b.close
|
||||
|
||||
return None
|
||||
|
||||
def get_prices_dict(
|
||||
self,
|
||||
tickers: list[str],
|
||||
target_date: date,
|
||||
) -> dict[str, Decimal]:
|
||||
prices = {}
|
||||
for ticker in tickers:
|
||||
price = self.get_price_on_date(ticker, target_date)
|
||||
if price is not None:
|
||||
prices[ticker] = price
|
||||
return prices
|
||||
|
||||
def get_trading_days(
|
||||
self,
|
||||
ticker: str,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
) -> list[date]:
|
||||
ohlcv = self.load_ohlcv(ticker, start_date, end_date)
|
||||
return [bar.timestamp.date() for bar in ohlcv.bars]
|
||||
|
||||
def clear_cache(self):
|
||||
self._cache.clear()
|
||||
|
|
@ -0,0 +1,366 @@
|
|||
import logging
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Callable
|
||||
from uuid import uuid4
|
||||
|
||||
from tradingagents.models.backtest import (
|
||||
BacktestConfig,
|
||||
BacktestResult,
|
||||
EquityCurvePoint,
|
||||
TradeLog,
|
||||
)
|
||||
from tradingagents.models.decisions import SignalType, TradingDecision
|
||||
from tradingagents.models.portfolio import PortfolioSnapshot
|
||||
from tradingagents.models.trading import Order, OrderSide, OrderStatus, Fill, Trade
|
||||
|
||||
from .data_loader import DataLoader
|
||||
from .metrics import MetricsCalculator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BacktestEngine:
|
||||
def __init__(
|
||||
self,
|
||||
config: BacktestConfig,
|
||||
decision_callback: Optional[Callable[[str, date, dict], TradingDecision]] = None,
|
||||
):
|
||||
self.config = config
|
||||
self.decision_callback = decision_callback
|
||||
self.data_loader = DataLoader()
|
||||
self.metrics_calculator = MetricsCalculator(config.risk_free_rate)
|
||||
|
||||
self.portfolio: Optional[PortfolioSnapshot] = None
|
||||
self.trade_log: Optional[TradeLog] = None
|
||||
self.equity_curve: list[EquityCurvePoint] = []
|
||||
self.daily_returns: list[Decimal] = []
|
||||
self.decisions: list[TradingDecision] = []
|
||||
self.open_trades: dict[str, Trade] = {}
|
||||
|
||||
def run(self) -> BacktestResult:
|
||||
started_at = datetime.now()
|
||||
|
||||
try:
|
||||
self._initialize()
|
||||
self._preload_data()
|
||||
trading_days = self._get_trading_days()
|
||||
|
||||
for i, trading_date in enumerate(trading_days):
|
||||
if i < self.config.warmup_period:
|
||||
continue
|
||||
|
||||
self._process_day(trading_date, i)
|
||||
|
||||
self._close_all_positions(trading_days[-1] if trading_days else self.config.end_date)
|
||||
|
||||
metrics = self.metrics_calculator.calculate_metrics(
|
||||
self.equity_curve,
|
||||
self.trade_log,
|
||||
)
|
||||
|
||||
completed_at = datetime.now()
|
||||
|
||||
return BacktestResult(
|
||||
config=self.config,
|
||||
metrics=metrics,
|
||||
trade_log=self.trade_log,
|
||||
equity_curve=self.equity_curve,
|
||||
daily_returns=self.daily_returns,
|
||||
started_at=started_at,
|
||||
completed_at=completed_at,
|
||||
status="completed",
|
||||
)
|
||||
|
||||
except (ValueError, KeyError, RuntimeError, FileNotFoundError, OSError) as e:
|
||||
logger.exception("Backtest failed: %s", e)
|
||||
completed_at = datetime.now()
|
||||
|
||||
return BacktestResult(
|
||||
config=self.config,
|
||||
metrics=self._empty_metrics(),
|
||||
trade_log=self.trade_log or TradeLog(),
|
||||
equity_curve=self.equity_curve,
|
||||
daily_returns=self.daily_returns,
|
||||
started_at=started_at,
|
||||
completed_at=completed_at,
|
||||
status="failed",
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
def _initialize(self):
|
||||
self.portfolio = PortfolioSnapshot(
|
||||
cash=self.config.portfolio_config.initial_cash,
|
||||
)
|
||||
self.trade_log = TradeLog()
|
||||
self.equity_curve = []
|
||||
self.daily_returns = []
|
||||
self.decisions = []
|
||||
self.open_trades = {}
|
||||
|
||||
def _preload_data(self):
|
||||
logger.info("Preloading data for %s tickers", len(self.config.tickers))
|
||||
for ticker in self.config.tickers:
|
||||
self.data_loader.load_ohlcv(
|
||||
ticker,
|
||||
self.config.start_date - timedelta(days=self.config.warmup_period + 10),
|
||||
self.config.end_date,
|
||||
)
|
||||
|
||||
def _get_trading_days(self) -> list[date]:
|
||||
primary_ticker = self.config.tickers[0]
|
||||
return self.data_loader.get_trading_days(
|
||||
primary_ticker,
|
||||
self.config.start_date,
|
||||
self.config.end_date,
|
||||
)
|
||||
|
||||
def _process_day(self, trading_date: date, day_index: int):
|
||||
prices = self.data_loader.get_prices_dict(self.config.tickers, trading_date)
|
||||
|
||||
if not prices:
|
||||
logger.debug("No prices available for %s", trading_date)
|
||||
return
|
||||
|
||||
for ticker in self.config.tickers:
|
||||
if ticker not in prices:
|
||||
continue
|
||||
|
||||
decision = self._get_decision(ticker, trading_date, day_index)
|
||||
if decision:
|
||||
self.decisions.append(decision)
|
||||
self._execute_decision(decision, prices[ticker], trading_date)
|
||||
|
||||
self._record_equity(trading_date, prices)
|
||||
|
||||
def _get_decision(
|
||||
self,
|
||||
ticker: str,
|
||||
trading_date: date,
|
||||
day_index: int,
|
||||
) -> Optional[TradingDecision]:
|
||||
if self.decision_callback:
|
||||
context = {
|
||||
"day_index": day_index,
|
||||
"portfolio": self.portfolio,
|
||||
"open_trade": self.open_trades.get(ticker),
|
||||
}
|
||||
return self.decision_callback(ticker, trading_date, context)
|
||||
|
||||
return self._simple_strategy(ticker, trading_date)
|
||||
|
||||
def _simple_strategy(
|
||||
self,
|
||||
ticker: str,
|
||||
trading_date: date,
|
||||
) -> Optional[TradingDecision]:
|
||||
return None
|
||||
|
||||
def _execute_decision(
|
||||
self,
|
||||
decision: TradingDecision,
|
||||
price: Decimal,
|
||||
trading_date: date,
|
||||
):
|
||||
ticker = decision.ticker
|
||||
config = self.config.portfolio_config
|
||||
position = self.portfolio.get_position(ticker)
|
||||
|
||||
if decision.is_buy and position.quantity == 0:
|
||||
execution_price = config.calculate_slippage(price, OrderSide.BUY)
|
||||
|
||||
if decision.recommended_quantity:
|
||||
quantity = decision.recommended_quantity
|
||||
else:
|
||||
max_position_value = self.portfolio.cash * (config.max_position_size_percent / 100)
|
||||
quantity = int(max_position_value / execution_price)
|
||||
|
||||
if quantity <= 0:
|
||||
return
|
||||
|
||||
if not self.portfolio.can_afford(ticker, quantity, execution_price, config):
|
||||
quantity = self.portfolio.max_shares_affordable(ticker, execution_price, config)
|
||||
|
||||
if quantity <= 0:
|
||||
return
|
||||
|
||||
commission = config.calculate_commission(quantity, execution_price)
|
||||
|
||||
order = Order(
|
||||
ticker=ticker,
|
||||
side=OrderSide.BUY,
|
||||
quantity=quantity,
|
||||
status=OrderStatus.FILLED,
|
||||
filled_quantity=quantity,
|
||||
filled_avg_price=execution_price,
|
||||
filled_at=datetime.combine(trading_date, datetime.min.time()),
|
||||
commission=commission,
|
||||
)
|
||||
|
||||
fill = Fill(
|
||||
order_id=order.id,
|
||||
ticker=ticker,
|
||||
side=OrderSide.BUY,
|
||||
quantity=quantity,
|
||||
price=execution_price,
|
||||
commission=commission,
|
||||
timestamp=datetime.combine(trading_date, datetime.min.time()),
|
||||
)
|
||||
|
||||
self.portfolio.apply_fill(fill)
|
||||
|
||||
trade = Trade(
|
||||
ticker=ticker,
|
||||
side=OrderSide.BUY,
|
||||
entry_price=execution_price,
|
||||
entry_quantity=quantity,
|
||||
entry_time=datetime.combine(trading_date, datetime.min.time()),
|
||||
entry_order_id=order.id,
|
||||
)
|
||||
self.open_trades[ticker] = trade
|
||||
|
||||
logger.debug(
|
||||
"BUY %s: %d shares @ $%.2f on %s",
|
||||
ticker, quantity, execution_price, trading_date
|
||||
)
|
||||
|
||||
elif decision.is_sell and position.quantity > 0:
|
||||
execution_price = config.calculate_slippage(price, OrderSide.SELL)
|
||||
quantity = position.quantity
|
||||
commission = config.calculate_commission(quantity, execution_price)
|
||||
|
||||
order = Order(
|
||||
ticker=ticker,
|
||||
side=OrderSide.SELL,
|
||||
quantity=quantity,
|
||||
status=OrderStatus.FILLED,
|
||||
filled_quantity=quantity,
|
||||
filled_avg_price=execution_price,
|
||||
filled_at=datetime.combine(trading_date, datetime.min.time()),
|
||||
commission=commission,
|
||||
)
|
||||
|
||||
fill = Fill(
|
||||
order_id=order.id,
|
||||
ticker=ticker,
|
||||
side=OrderSide.SELL,
|
||||
quantity=quantity,
|
||||
price=execution_price,
|
||||
commission=commission,
|
||||
timestamp=datetime.combine(trading_date, datetime.min.time()),
|
||||
)
|
||||
|
||||
self.portfolio.apply_fill(fill)
|
||||
|
||||
if ticker in self.open_trades:
|
||||
trade = self.open_trades[ticker]
|
||||
trade.exit_price = execution_price
|
||||
trade.exit_quantity = quantity
|
||||
trade.exit_time = datetime.combine(trading_date, datetime.min.time())
|
||||
trade.exit_order_id = order.id
|
||||
trade.commission = (
|
||||
config.calculate_commission(trade.entry_quantity, trade.entry_price) +
|
||||
commission
|
||||
)
|
||||
self.trade_log.add_trade(trade)
|
||||
del self.open_trades[ticker]
|
||||
|
||||
logger.debug(
|
||||
"SELL %s: %d shares @ $%.2f on %s",
|
||||
ticker, quantity, execution_price, trading_date
|
||||
)
|
||||
|
||||
def _record_equity(self, trading_date: date, prices: dict[str, Decimal]):
|
||||
equity = self.portfolio.total_equity(prices)
|
||||
positions_value = self.portfolio.positions_value(prices)
|
||||
|
||||
point = EquityCurvePoint(
|
||||
timestamp=datetime.combine(trading_date, datetime.min.time()),
|
||||
equity=equity,
|
||||
cash=self.portfolio.cash,
|
||||
positions_value=positions_value,
|
||||
)
|
||||
self.equity_curve.append(point)
|
||||
|
||||
if len(self.equity_curve) > 1:
|
||||
prev_equity = self.equity_curve[-2].equity
|
||||
if prev_equity > 0:
|
||||
daily_return = (equity - prev_equity) / prev_equity
|
||||
self.daily_returns.append(daily_return)
|
||||
|
||||
def _close_all_positions(self, final_date: date):
|
||||
prices = self.data_loader.get_prices_dict(self.config.tickers, final_date)
|
||||
|
||||
for ticker, trade in list(self.open_trades.items()):
|
||||
if ticker in prices:
|
||||
decision = TradingDecision(
|
||||
ticker=ticker,
|
||||
timestamp=datetime.now(),
|
||||
decision_date=datetime.combine(final_date, datetime.min.time()),
|
||||
signal=SignalType.SELL,
|
||||
confidence=Decimal("1.0"),
|
||||
recommended_action="SELL",
|
||||
final_decision="SELL - End of backtest",
|
||||
rationale="Closing position at end of backtest period",
|
||||
)
|
||||
self._execute_decision(decision, prices[ticker], final_date)
|
||||
|
||||
def _empty_metrics(self):
|
||||
from tradingagents.models.backtest import BacktestMetrics
|
||||
return BacktestMetrics(
|
||||
start_equity=self.config.portfolio_config.initial_cash,
|
||||
end_equity=self.portfolio.cash if self.portfolio else self.config.portfolio_config.initial_cash,
|
||||
)
|
||||
|
||||
|
||||
class SimpleBacktestEngine(BacktestEngine):
|
||||
def __init__(
|
||||
self,
|
||||
config: BacktestConfig,
|
||||
buy_signal: Callable[[str, date, dict], bool] = None,
|
||||
sell_signal: Callable[[str, date, dict], bool] = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.buy_signal = buy_signal
|
||||
self.sell_signal = sell_signal
|
||||
|
||||
def _get_decision(
|
||||
self,
|
||||
ticker: str,
|
||||
trading_date: date,
|
||||
day_index: int,
|
||||
) -> Optional[TradingDecision]:
|
||||
context = {
|
||||
"day_index": day_index,
|
||||
"portfolio": self.portfolio,
|
||||
"data_loader": self.data_loader,
|
||||
"open_trade": self.open_trades.get(ticker),
|
||||
}
|
||||
|
||||
position = self.portfolio.get_position(ticker)
|
||||
|
||||
if position.quantity == 0 and self.buy_signal and self.buy_signal(ticker, trading_date, context):
|
||||
return TradingDecision(
|
||||
ticker=ticker,
|
||||
timestamp=datetime.now(),
|
||||
decision_date=datetime.combine(trading_date, datetime.min.time()),
|
||||
signal=SignalType.BUY,
|
||||
confidence=Decimal("0.7"),
|
||||
recommended_action="BUY",
|
||||
final_decision="BUY",
|
||||
rationale="Buy signal triggered",
|
||||
)
|
||||
|
||||
if position.quantity > 0 and self.sell_signal and self.sell_signal(ticker, trading_date, context):
|
||||
return TradingDecision(
|
||||
ticker=ticker,
|
||||
timestamp=datetime.now(),
|
||||
decision_date=datetime.combine(trading_date, datetime.min.time()),
|
||||
signal=SignalType.SELL,
|
||||
confidence=Decimal("0.7"),
|
||||
recommended_action="SELL",
|
||||
final_decision="SELL",
|
||||
rationale="Sell signal triggered",
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
@ -0,0 +1,281 @@
|
|||
import math
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
|
||||
from tradingagents.models.backtest import BacktestMetrics, EquityCurvePoint, TradeLog
|
||||
|
||||
|
||||
class MetricsCalculator:
|
||||
TRADING_DAYS_PER_YEAR = 252
|
||||
|
||||
def __init__(self, risk_free_rate: Decimal = Decimal("0.05")):
|
||||
self.risk_free_rate = risk_free_rate
|
||||
|
||||
def calculate_metrics(
|
||||
self,
|
||||
equity_curve: list[EquityCurvePoint],
|
||||
trade_log: TradeLog,
|
||||
benchmark_curve: Optional[list[EquityCurvePoint]] = None,
|
||||
) -> BacktestMetrics:
|
||||
if not equity_curve:
|
||||
raise ValueError("Equity curve cannot be empty")
|
||||
|
||||
start_equity = equity_curve[0].equity
|
||||
end_equity = equity_curve[-1].equity
|
||||
trading_days = len(equity_curve)
|
||||
|
||||
total_return = end_equity - start_equity
|
||||
total_return_percent = (total_return / start_equity) * 100
|
||||
|
||||
years = Decimal(trading_days) / Decimal(self.TRADING_DAYS_PER_YEAR)
|
||||
if years > 0:
|
||||
annualized_return = ((end_equity / start_equity) ** (1 / years) - 1) * 100
|
||||
else:
|
||||
annualized_return = Decimal("0")
|
||||
|
||||
daily_returns = self._calculate_daily_returns(equity_curve)
|
||||
volatility = self._calculate_volatility(daily_returns)
|
||||
annualized_volatility = volatility * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR))
|
||||
|
||||
downside_returns = [r for r in daily_returns if r < 0]
|
||||
downside_volatility = self._calculate_volatility(downside_returns) if downside_returns else Decimal("0")
|
||||
annualized_downside_vol = downside_volatility * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR))
|
||||
|
||||
max_dd, max_dd_pct, max_dd_duration, avg_dd = self._calculate_drawdown_metrics(equity_curve)
|
||||
|
||||
sharpe = self._calculate_sharpe_ratio(annualized_return, annualized_volatility)
|
||||
sortino = self._calculate_sortino_ratio(annualized_return, annualized_downside_vol)
|
||||
calmar = self._calculate_calmar_ratio(annualized_return, max_dd_pct)
|
||||
|
||||
benchmark_return = None
|
||||
benchmark_return_percent = None
|
||||
alpha = None
|
||||
beta = None
|
||||
information_ratio = None
|
||||
|
||||
if benchmark_curve and len(benchmark_curve) == len(equity_curve):
|
||||
benchmark_return = benchmark_curve[-1].equity - benchmark_curve[0].equity
|
||||
benchmark_return_percent = (benchmark_return / benchmark_curve[0].equity) * 100
|
||||
|
||||
benchmark_daily = self._calculate_daily_returns(benchmark_curve)
|
||||
alpha, beta = self._calculate_alpha_beta(daily_returns, benchmark_daily)
|
||||
information_ratio = self._calculate_information_ratio(
|
||||
daily_returns, benchmark_daily
|
||||
)
|
||||
|
||||
all_pnls = [t.pnl for t in trade_log.trades if t.is_closed and t.pnl is not None]
|
||||
avg_trade_pnl = sum(all_pnls) / len(all_pnls) if all_pnls else None
|
||||
largest_win = max((p for p in all_pnls if p > 0), default=None)
|
||||
largest_loss = min((p for p in all_pnls if p < 0), default=None)
|
||||
|
||||
return BacktestMetrics(
|
||||
total_return=total_return,
|
||||
total_return_percent=total_return_percent,
|
||||
annualized_return=annualized_return,
|
||||
benchmark_return=benchmark_return,
|
||||
benchmark_return_percent=benchmark_return_percent,
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
volatility=volatility * 100,
|
||||
annualized_volatility=annualized_volatility * 100,
|
||||
downside_volatility=annualized_downside_vol * 100,
|
||||
sharpe_ratio=sharpe,
|
||||
sortino_ratio=sortino,
|
||||
calmar_ratio=calmar,
|
||||
information_ratio=information_ratio,
|
||||
max_drawdown=max_dd,
|
||||
max_drawdown_percent=max_dd_pct,
|
||||
max_drawdown_duration=max_dd_duration,
|
||||
avg_drawdown=avg_dd,
|
||||
total_trades=trade_log.total_trades,
|
||||
win_rate=trade_log.win_rate,
|
||||
profit_factor=trade_log.profit_factor,
|
||||
avg_trade_pnl=avg_trade_pnl,
|
||||
avg_win=trade_log.avg_win,
|
||||
avg_loss=trade_log.avg_loss,
|
||||
largest_win=largest_win,
|
||||
largest_loss=largest_loss,
|
||||
avg_holding_period_days=trade_log.avg_holding_period,
|
||||
trading_days=trading_days,
|
||||
start_equity=start_equity,
|
||||
end_equity=end_equity,
|
||||
)
|
||||
|
||||
def _calculate_daily_returns(
|
||||
self,
|
||||
equity_curve: list[EquityCurvePoint],
|
||||
) -> list[Decimal]:
|
||||
returns = []
|
||||
for i in range(1, len(equity_curve)):
|
||||
prev_equity = equity_curve[i - 1].equity
|
||||
curr_equity = equity_curve[i].equity
|
||||
if prev_equity > 0:
|
||||
daily_return = (curr_equity - prev_equity) / prev_equity
|
||||
returns.append(daily_return)
|
||||
return returns
|
||||
|
||||
def _calculate_volatility(self, returns: list[Decimal]) -> Decimal:
|
||||
if len(returns) < 2:
|
||||
return Decimal("0")
|
||||
|
||||
mean = sum(returns) / len(returns)
|
||||
variance = sum((r - mean) ** 2 for r in returns) / (len(returns) - 1)
|
||||
return Decimal(str(math.sqrt(float(variance))))
|
||||
|
||||
def _calculate_drawdown_metrics(
|
||||
self,
|
||||
equity_curve: list[EquityCurvePoint],
|
||||
) -> tuple[Decimal, Decimal, Optional[int], Decimal]:
|
||||
if not equity_curve:
|
||||
return Decimal("0"), Decimal("0"), None, Decimal("0")
|
||||
|
||||
peak = equity_curve[0].equity
|
||||
max_drawdown = Decimal("0")
|
||||
max_drawdown_percent = Decimal("0")
|
||||
drawdown_start = 0
|
||||
max_drawdown_duration = 0
|
||||
current_drawdown_start = 0
|
||||
in_drawdown = False
|
||||
drawdowns = []
|
||||
|
||||
for i, point in enumerate(equity_curve):
|
||||
equity = point.equity
|
||||
|
||||
if equity > peak:
|
||||
if in_drawdown:
|
||||
duration = i - current_drawdown_start
|
||||
max_drawdown_duration = max(max_drawdown_duration, duration)
|
||||
in_drawdown = False
|
||||
peak = equity
|
||||
current_drawdown_start = i
|
||||
else:
|
||||
if not in_drawdown:
|
||||
in_drawdown = True
|
||||
current_drawdown_start = i
|
||||
|
||||
drawdown = peak - equity
|
||||
drawdown_pct = (drawdown / peak) * 100 if peak > 0 else Decimal("0")
|
||||
drawdowns.append(drawdown_pct)
|
||||
|
||||
if drawdown > max_drawdown:
|
||||
max_drawdown = drawdown
|
||||
max_drawdown_percent = drawdown_pct
|
||||
|
||||
point.drawdown = drawdown
|
||||
point.drawdown_percent = drawdown_pct
|
||||
|
||||
if in_drawdown:
|
||||
duration = len(equity_curve) - current_drawdown_start
|
||||
max_drawdown_duration = max(max_drawdown_duration, duration)
|
||||
|
||||
avg_drawdown = sum(drawdowns) / len(drawdowns) if drawdowns else Decimal("0")
|
||||
|
||||
return max_drawdown, max_drawdown_percent, max_drawdown_duration or None, avg_drawdown
|
||||
|
||||
def _calculate_sharpe_ratio(
|
||||
self,
|
||||
annualized_return: Decimal,
|
||||
annualized_volatility: Decimal,
|
||||
) -> Optional[Decimal]:
|
||||
if annualized_volatility == 0:
|
||||
return None
|
||||
|
||||
excess_return = annualized_return - (self.risk_free_rate * 100)
|
||||
return excess_return / annualized_volatility
|
||||
|
||||
def _calculate_sortino_ratio(
|
||||
self,
|
||||
annualized_return: Decimal,
|
||||
annualized_downside_vol: Decimal,
|
||||
) -> Optional[Decimal]:
|
||||
if annualized_downside_vol == 0:
|
||||
return None
|
||||
|
||||
excess_return = annualized_return - (self.risk_free_rate * 100)
|
||||
return excess_return / annualized_downside_vol
|
||||
|
||||
def _calculate_calmar_ratio(
|
||||
self,
|
||||
annualized_return: Decimal,
|
||||
max_drawdown_percent: Decimal,
|
||||
) -> Optional[Decimal]:
|
||||
if max_drawdown_percent == 0:
|
||||
return None
|
||||
|
||||
return annualized_return / max_drawdown_percent
|
||||
|
||||
def _calculate_alpha_beta(
|
||||
self,
|
||||
returns: list[Decimal],
|
||||
benchmark_returns: list[Decimal],
|
||||
) -> tuple[Optional[Decimal], Optional[Decimal]]:
|
||||
if len(returns) != len(benchmark_returns) or len(returns) < 2:
|
||||
return None, None
|
||||
|
||||
n = len(returns)
|
||||
sum_x = sum(benchmark_returns)
|
||||
sum_y = sum(returns)
|
||||
sum_xy = sum(r * b for r, b in zip(returns, benchmark_returns))
|
||||
sum_xx = sum(b * b for b in benchmark_returns)
|
||||
|
||||
denominator = n * sum_xx - sum_x * sum_x
|
||||
if denominator == 0:
|
||||
return None, None
|
||||
|
||||
beta = (n * sum_xy - sum_x * sum_y) / denominator
|
||||
alpha = (sum_y - beta * sum_x) / n
|
||||
|
||||
alpha_annualized = alpha * self.TRADING_DAYS_PER_YEAR
|
||||
|
||||
return alpha_annualized, beta
|
||||
|
||||
def _calculate_information_ratio(
|
||||
self,
|
||||
returns: list[Decimal],
|
||||
benchmark_returns: list[Decimal],
|
||||
) -> Optional[Decimal]:
|
||||
if len(returns) != len(benchmark_returns) or len(returns) < 2:
|
||||
return None
|
||||
|
||||
excess_returns = [r - b for r, b in zip(returns, benchmark_returns)]
|
||||
mean_excess = sum(excess_returns) / len(excess_returns)
|
||||
tracking_error = self._calculate_volatility(excess_returns)
|
||||
|
||||
if tracking_error == 0:
|
||||
return None
|
||||
|
||||
annualized_tracking_error = tracking_error * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR))
|
||||
annualized_excess = mean_excess * self.TRADING_DAYS_PER_YEAR
|
||||
|
||||
return annualized_excess / annualized_tracking_error
|
||||
|
||||
def calculate_rolling_metrics(
|
||||
self,
|
||||
equity_curve: list[EquityCurvePoint],
|
||||
window: int = 20,
|
||||
) -> dict[str, list[Decimal]]:
|
||||
if len(equity_curve) < window:
|
||||
return {"rolling_sharpe": [], "rolling_volatility": []}
|
||||
|
||||
rolling_sharpe = []
|
||||
rolling_volatility = []
|
||||
|
||||
daily_returns = self._calculate_daily_returns(equity_curve)
|
||||
|
||||
for i in range(window - 1, len(daily_returns)):
|
||||
window_returns = daily_returns[i - window + 1:i + 1]
|
||||
vol = self._calculate_volatility(window_returns)
|
||||
annualized_vol = vol * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR))
|
||||
|
||||
mean_return = sum(window_returns) / len(window_returns)
|
||||
annualized_return = mean_return * self.TRADING_DAYS_PER_YEAR * 100
|
||||
|
||||
sharpe = self._calculate_sharpe_ratio(annualized_return, annualized_vol * 100)
|
||||
|
||||
rolling_sharpe.append(sharpe if sharpe else Decimal("0"))
|
||||
rolling_volatility.append(annualized_vol * 100)
|
||||
|
||||
return {
|
||||
"rolling_sharpe": rolling_sharpe,
|
||||
"rolling_volatility": rolling_volatility,
|
||||
}
|
||||
|
|
@ -53,7 +53,7 @@ def _make_api_request(function_name: str, params: dict) -> dict | str:
|
|||
elif "entitlement" in api_params:
|
||||
api_params.pop("entitlement", None)
|
||||
|
||||
response = requests.get(API_BASE_URL, params=api_params)
|
||||
response = requests.get(API_BASE_URL, params=api_params, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
response_text = response.text
|
||||
|
|
@ -88,6 +88,6 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) ->
|
|||
|
||||
return filtered_df.to_csv(index=False)
|
||||
|
||||
except Exception as e:
|
||||
except (pd.errors.ParserError, KeyError, ValueError) as e:
|
||||
logger.warning("Failed to filter CSV data by date range: %s", e)
|
||||
return csv_data
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import requests
|
||||
from .alpha_vantage_common import _make_api_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -193,6 +194,6 @@ def get_indicator(
|
|||
|
||||
return result_str
|
||||
|
||||
except Exception as e:
|
||||
except (ValueError, KeyError, IndexError, requests.RequestException) as e:
|
||||
logger.error("Error getting Alpha Vantage indicator data for %s: %s", indicator, e)
|
||||
return f"Error retrieving {indicator} data: {str(e)}"
|
||||
|
|
|
|||
|
|
@ -77,7 +77,8 @@ def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]:
|
|||
"content_snippet": item.get("summary", "")[:500],
|
||||
}
|
||||
articles.append(article)
|
||||
except Exception:
|
||||
except (KeyError, TypeError, AttributeError) as e:
|
||||
logger.debug("Error parsing news article: %s", e)
|
||||
continue
|
||||
|
||||
return articles
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ def get_bulk_news_brave(lookback_hours: int) -> List[Dict[str, Any]]:
|
|||
except requests.exceptions.RequestException as e:
|
||||
logger.debug("Brave search request failed for '%s': %s", query, e)
|
||||
continue
|
||||
except Exception as e:
|
||||
except (KeyError, TypeError, ValueError) as e:
|
||||
logger.debug("Brave search failed for query '%s': %s", query, e)
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
import logging
|
||||
import re
|
||||
import requests
|
||||
from typing import Annotated, List, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from dateutil import parser as dateutil_parser
|
||||
from .googlenews_utils import getNewsData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_google_news_date(date_str: str) -> datetime:
|
||||
if not date_str:
|
||||
|
|
@ -108,7 +112,8 @@ def get_bulk_news_google(lookback_hours: int) -> List[Dict[str, Any]]:
|
|||
}
|
||||
all_articles.append(article)
|
||||
|
||||
except Exception:
|
||||
except (TypeError, KeyError, AttributeError, requests.RequestException) as e:
|
||||
logger.debug("Google News search failed for query '%s': %s", query, e)
|
||||
continue
|
||||
|
||||
return all_articles
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ def is_rate_limited(response):
|
|||
)
|
||||
def make_request(url, headers):
|
||||
time.sleep(random.uniform(2, 6))
|
||||
response = requests.get(url, headers=headers)
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
return response
|
||||
|
||||
|
||||
|
|
@ -81,7 +81,7 @@ def getNewsData(query, start_date, end_date):
|
|||
"source": source,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
except (TypeError, AttributeError, KeyError) as e:
|
||||
logger.debug("Error processing result: %s", e)
|
||||
continue
|
||||
|
||||
|
|
@ -91,7 +91,7 @@ def getNewsData(query, start_date, end_date):
|
|||
|
||||
page += 1
|
||||
|
||||
except Exception as e:
|
||||
except (requests.RequestException, ConnectionError, TimeoutError) as e:
|
||||
logger.debug("Failed after multiple retries: %s", e)
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -191,7 +191,8 @@ def _convert_to_news_articles(raw_articles: List[Dict[str, Any]]) -> List[NewsAr
|
|||
ticker_mentions=[],
|
||||
)
|
||||
articles.append(article)
|
||||
except Exception:
|
||||
except (KeyError, TypeError, ValueError) as e:
|
||||
logger.debug("Error converting article to NewsArticle: %s", e)
|
||||
continue
|
||||
return articles
|
||||
|
||||
|
|
@ -218,7 +219,7 @@ def _fetch_bulk_news_from_vendor(lookback_period: str) -> List[Dict[str, Any]]:
|
|||
except AlphaVantageRateLimitError as e:
|
||||
logger.warning("Alpha Vantage rate limit exceeded: %s", e)
|
||||
continue
|
||||
except Exception as e:
|
||||
except (RuntimeError, ConnectionError, TimeoutError, ValueError, OSError) as e:
|
||||
logger.error("Vendor '%s' failed: %s", vendor, e)
|
||||
continue
|
||||
|
||||
|
|
@ -316,7 +317,7 @@ def route_to_vendor(method: str, *args, **kwargs):
|
|||
logger.warning("Alpha Vantage rate limit exceeded, falling back to next available vendor")
|
||||
logger.debug("Rate limit details: %s", e)
|
||||
continue
|
||||
except Exception as e:
|
||||
except (RuntimeError, ConnectionError, TimeoutError, ValueError, KeyError, OSError) as e:
|
||||
logger.error("%s from vendor '%s' failed: %s", impl_func.__name__, vendor_name, e)
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -170,8 +170,8 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
|
|||
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json"
|
||||
)
|
||||
|
||||
data = open(data_path, "r")
|
||||
data = json.load(data)
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
filtered_data = {}
|
||||
for key, value in data.items():
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ def get_stock_news_openai(query, start_date, end_date):
|
|||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
store=False,
|
||||
)
|
||||
|
||||
return _extract_response_text(response) or ""
|
||||
|
|
@ -89,7 +89,7 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5):
|
|||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
store=False,
|
||||
)
|
||||
|
||||
return _extract_response_text(response) or ""
|
||||
|
|
@ -124,7 +124,7 @@ def get_fundamentals_openai(ticker, curr_date):
|
|||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
store=False,
|
||||
)
|
||||
|
||||
return _extract_response_text(response) or ""
|
||||
|
|
@ -187,7 +187,7 @@ Return ONLY the JSON array, no additional text."""
|
|||
temperature=0.5,
|
||||
max_output_tokens=8192,
|
||||
top_p=1,
|
||||
store=True,
|
||||
store=False,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def _search_with_retry(client, query: str, search_depth: str, topic: str, time_r
|
|||
max_results=max_results,
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
except (RuntimeError, ConnectionError, TimeoutError, OSError) as e:
|
||||
error_str = str(e).lower()
|
||||
if "rate" in error_str or "limit" in error_str or "429" in error_str:
|
||||
wait_time = RETRY_BACKOFF * (attempt + 1) * 2
|
||||
|
|
@ -123,7 +123,7 @@ def get_bulk_news_tavily(lookback_hours: int) -> List[Dict[str, Any]]:
|
|||
}
|
||||
all_articles.append(article)
|
||||
|
||||
except Exception as e:
|
||||
except (RuntimeError, ConnectionError, TimeoutError, OSError, ValueError) as e:
|
||||
logger.debug("Tavily search failed for query '%s': %s", query, e)
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -261,7 +261,7 @@ def classify_sector(ticker: str) -> str:
|
|||
_sector_cache[ticker_upper] = sector
|
||||
logger.info("Classified %s as %s via LLM", ticker, sector)
|
||||
return sector
|
||||
except Exception as e:
|
||||
except (KeyError, ValueError, RuntimeError, ConnectionError, TimeoutError) as e:
|
||||
logger.error("LLM sector classification failed for %s: %s", ticker, str(e))
|
||||
_sector_cache[ticker_upper] = "other"
|
||||
return "other"
|
||||
|
|
|
|||
|
|
@ -462,7 +462,7 @@ def _search_yfinance_ticker(company_name: str) -> Optional[str]:
|
|||
info = search_result.info
|
||||
if info and "symbol" in info:
|
||||
return info["symbol"]
|
||||
except Exception as e:
|
||||
except (KeyError, ValueError, AttributeError, RuntimeError) as e:
|
||||
logger.debug("yfinance search failed for %s: %s", company_name, str(e))
|
||||
|
||||
try:
|
||||
|
|
@ -471,7 +471,7 @@ def _search_yfinance_ticker(company_name: str) -> Optional[str]:
|
|||
for quote in search.quotes:
|
||||
if "symbol" in quote:
|
||||
return quote["symbol"]
|
||||
except Exception as e:
|
||||
except (KeyError, ValueError, AttributeError, RuntimeError) as e:
|
||||
logger.debug("yfinance Search failed for %s: %s", company_name, str(e))
|
||||
|
||||
return None
|
||||
|
|
@ -495,7 +495,7 @@ def validate_us_ticker(ticker: str) -> bool:
|
|||
|
||||
logger.warning("Validation failed for %s: exchange %s is not a US exchange", ticker, exchange)
|
||||
return False
|
||||
except Exception as e:
|
||||
except (KeyError, ValueError, AttributeError, RuntimeError) as e:
|
||||
logger.warning("Validation failed for %s: %s", ticker, str(e))
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ def get_stock_stats_indicators_window(
|
|||
for date_str, value in date_values:
|
||||
ind_string += f"{date_str}: {value}\n"
|
||||
|
||||
except Exception as e:
|
||||
except (KeyError, ValueError, FileNotFoundError) as e:
|
||||
logger.error("Error getting bulk stockstats data: %s", e)
|
||||
ind_string = ""
|
||||
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
|
|
@ -260,7 +260,7 @@ def get_stockstats_indicator(
|
|||
indicator,
|
||||
curr_date,
|
||||
)
|
||||
except Exception as e:
|
||||
except (KeyError, ValueError, IndexError) as e:
|
||||
logger.error(
|
||||
"Error getting stockstats indicator data for indicator %s on %s: %s",
|
||||
indicator, curr_date, e
|
||||
|
|
@ -293,7 +293,7 @@ def get_balance_sheet(
|
|||
|
||||
return header + csv_string
|
||||
|
||||
except Exception as e:
|
||||
except (ValueError, KeyError, AttributeError) as e:
|
||||
return f"Error retrieving balance sheet for {ticker}: {str(e)}"
|
||||
|
||||
|
||||
|
|
@ -320,7 +320,7 @@ def get_cashflow(
|
|||
|
||||
return header + csv_string
|
||||
|
||||
except Exception as e:
|
||||
except (ValueError, KeyError, AttributeError) as e:
|
||||
return f"Error retrieving cash flow for {ticker}: {str(e)}"
|
||||
|
||||
|
||||
|
|
@ -347,7 +347,7 @@ def get_income_statement(
|
|||
|
||||
return header + csv_string
|
||||
|
||||
except Exception as e:
|
||||
except (ValueError, KeyError, AttributeError) as e:
|
||||
return f"Error retrieving income statement for {ticker}: {str(e)}"
|
||||
|
||||
|
||||
|
|
@ -368,5 +368,5 @@ def get_insider_transactions(
|
|||
|
||||
return header + csv_string
|
||||
|
||||
except Exception as e:
|
||||
except (ValueError, KeyError, AttributeError) as e:
|
||||
return f"Error retrieving insider transactions for {ticker}: {str(e)}"
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import os
|
|||
DEFAULT_CONFIG = {
|
||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
||||
"data_dir": os.getenv("TRADINGAGENTS_DATA_DIR", "./data"),
|
||||
"data_cache_dir": os.path.join(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"dataflows/data_cache",
|
||||
|
|
|
|||
|
|
@ -280,7 +280,7 @@ class TradingAgentsGraph:
|
|||
)
|
||||
|
||||
discovery_result["stocks"] = trending_stocks
|
||||
except Exception as e:
|
||||
except (ValueError, KeyError, RuntimeError, ConnectionError, TimeoutError) as e:
|
||||
discovery_result["error"] = str(e)
|
||||
|
||||
discovery_thread = threading.Thread(target=run_discovery)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,69 @@
|
|||
from .market_data import (
|
||||
OHLCV,
|
||||
OHLCVBar,
|
||||
TechnicalIndicators,
|
||||
MarketSnapshot,
|
||||
HistoricalDataRequest,
|
||||
HistoricalDataResponse,
|
||||
)
|
||||
from .trading import (
|
||||
OrderSide,
|
||||
OrderType,
|
||||
OrderStatus,
|
||||
PositionSide,
|
||||
Order,
|
||||
Fill,
|
||||
Position,
|
||||
Trade,
|
||||
)
|
||||
from .portfolio import (
|
||||
PortfolioSnapshot,
|
||||
PortfolioConfig,
|
||||
CashTransaction,
|
||||
TransactionType,
|
||||
)
|
||||
from .backtest import (
|
||||
BacktestConfig,
|
||||
BacktestResult,
|
||||
BacktestMetrics,
|
||||
EquityCurvePoint,
|
||||
TradeLog,
|
||||
)
|
||||
from .decisions import (
|
||||
SignalType,
|
||||
TradingSignal,
|
||||
TradingDecision,
|
||||
RiskAssessment,
|
||||
AnalystReport,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OHLCV",
|
||||
"OHLCVBar",
|
||||
"TechnicalIndicators",
|
||||
"MarketSnapshot",
|
||||
"HistoricalDataRequest",
|
||||
"HistoricalDataResponse",
|
||||
"OrderSide",
|
||||
"OrderType",
|
||||
"OrderStatus",
|
||||
"PositionSide",
|
||||
"Order",
|
||||
"Fill",
|
||||
"Position",
|
||||
"Trade",
|
||||
"PortfolioSnapshot",
|
||||
"PortfolioConfig",
|
||||
"CashTransaction",
|
||||
"TransactionType",
|
||||
"BacktestConfig",
|
||||
"BacktestResult",
|
||||
"BacktestMetrics",
|
||||
"EquityCurvePoint",
|
||||
"TradeLog",
|
||||
"SignalType",
|
||||
"TradingSignal",
|
||||
"TradingDecision",
|
||||
"RiskAssessment",
|
||||
"AnalystReport",
|
||||
]
|
||||
|
|
@ -0,0 +1,242 @@
|
|||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
|
||||
from .portfolio import PortfolioConfig
|
||||
from .trading import Trade
|
||||
|
||||
|
||||
class BacktestConfig(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
name: str = Field(default="Backtest")
|
||||
description: Optional[str] = None
|
||||
|
||||
tickers: list[str] = Field(min_length=1)
|
||||
start_date: date
|
||||
end_date: date
|
||||
interval: str = Field(default="1d")
|
||||
|
||||
portfolio_config: PortfolioConfig = Field(default_factory=PortfolioConfig)
|
||||
|
||||
warmup_period: int = Field(default=20, ge=0)
|
||||
rebalance_frequency: Optional[str] = Field(default=None)
|
||||
|
||||
use_agent_pipeline: bool = Field(default=True)
|
||||
agent_config: dict = Field(default_factory=dict)
|
||||
|
||||
benchmark_ticker: Optional[str] = Field(default="SPY")
|
||||
risk_free_rate: Decimal = Field(default=Decimal("0.05"), ge=0)
|
||||
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
@field_validator("end_date")
|
||||
@classmethod
|
||||
def end_after_start(cls, v: date, info) -> date:
|
||||
if "start_date" in info.data and v <= info.data["start_date"]:
|
||||
raise ValueError("end_date must be > start_date")
|
||||
return v
|
||||
|
||||
@field_validator("tickers")
|
||||
@classmethod
|
||||
def validate_tickers(cls, v: list[str]) -> list[str]:
|
||||
return [t.upper().strip() for t in v]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def trading_days_estimate(self) -> int:
|
||||
delta = self.end_date - self.start_date
|
||||
return int(delta.days * 252 / 365)
|
||||
|
||||
|
||||
class EquityCurvePoint(BaseModel):
|
||||
timestamp: datetime
|
||||
equity: Decimal
|
||||
cash: Decimal
|
||||
positions_value: Decimal
|
||||
benchmark_value: Optional[Decimal] = None
|
||||
drawdown: Decimal = Field(default=Decimal("0"))
|
||||
drawdown_percent: Decimal = Field(default=Decimal("0"))
|
||||
|
||||
|
||||
class TradeLog(BaseModel):
|
||||
trades: list[Trade] = Field(default_factory=list)
|
||||
total_trades: int = Field(default=0, ge=0)
|
||||
winning_trades: int = Field(default=0, ge=0)
|
||||
losing_trades: int = Field(default=0, ge=0)
|
||||
break_even_trades: int = Field(default=0, ge=0)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def win_rate(self) -> Optional[Decimal]:
|
||||
if self.total_trades == 0:
|
||||
return None
|
||||
return Decimal(self.winning_trades) / Decimal(self.total_trades) * 100
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def loss_rate(self) -> Optional[Decimal]:
|
||||
if self.total_trades == 0:
|
||||
return None
|
||||
return Decimal(self.losing_trades) / Decimal(self.total_trades) * 100
|
||||
|
||||
def add_trade(self, trade: Trade) -> None:
|
||||
self.trades.append(trade)
|
||||
self.total_trades += 1
|
||||
if trade.is_closed and trade.pnl is not None:
|
||||
if trade.pnl > 0:
|
||||
self.winning_trades += 1
|
||||
elif trade.pnl < 0:
|
||||
self.losing_trades += 1
|
||||
else:
|
||||
self.break_even_trades += 1
|
||||
|
||||
@property
|
||||
def gross_profit(self) -> Decimal:
|
||||
return sum(
|
||||
t.pnl for t in self.trades if t.is_closed and t.pnl and t.pnl > 0
|
||||
) or Decimal("0")
|
||||
|
||||
@property
|
||||
def gross_loss(self) -> Decimal:
|
||||
return abs(
|
||||
sum(t.pnl for t in self.trades if t.is_closed and t.pnl and t.pnl < 0)
|
||||
or Decimal("0")
|
||||
)
|
||||
|
||||
@property
|
||||
def profit_factor(self) -> Optional[Decimal]:
|
||||
if self.gross_loss == 0:
|
||||
return None
|
||||
return self.gross_profit / self.gross_loss
|
||||
|
||||
@property
|
||||
def avg_win(self) -> Optional[Decimal]:
|
||||
wins = [t.pnl for t in self.trades if t.is_closed and t.pnl and t.pnl > 0]
|
||||
if not wins:
|
||||
return None
|
||||
return sum(wins) / len(wins)
|
||||
|
||||
@property
|
||||
def avg_loss(self) -> Optional[Decimal]:
|
||||
losses = [t.pnl for t in self.trades if t.is_closed and t.pnl and t.pnl < 0]
|
||||
if not losses:
|
||||
return None
|
||||
return sum(losses) / len(losses)
|
||||
|
||||
@property
|
||||
def avg_holding_period(self) -> Optional[float]:
|
||||
periods = [t.holding_period for t in self.trades if t.holding_period is not None]
|
||||
if not periods:
|
||||
return None
|
||||
return sum(periods) / len(periods)
|
||||
|
||||
|
||||
class BacktestMetrics(BaseModel):
|
||||
total_return: Decimal = Field(default=Decimal("0"))
|
||||
total_return_percent: Decimal = Field(default=Decimal("0"))
|
||||
annualized_return: Decimal = Field(default=Decimal("0"))
|
||||
|
||||
benchmark_return: Optional[Decimal] = None
|
||||
benchmark_return_percent: Optional[Decimal] = None
|
||||
alpha: Optional[Decimal] = None
|
||||
beta: Optional[Decimal] = None
|
||||
|
||||
volatility: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
annualized_volatility: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
downside_volatility: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
|
||||
sharpe_ratio: Optional[Decimal] = None
|
||||
sortino_ratio: Optional[Decimal] = None
|
||||
calmar_ratio: Optional[Decimal] = None
|
||||
information_ratio: Optional[Decimal] = None
|
||||
|
||||
max_drawdown: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
max_drawdown_percent: Decimal = Field(default=Decimal("0"), ge=0, le=100)
|
||||
max_drawdown_duration: Optional[int] = None
|
||||
avg_drawdown: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
|
||||
total_trades: int = Field(default=0, ge=0)
|
||||
win_rate: Optional[Decimal] = Field(default=None, ge=0, le=100)
|
||||
profit_factor: Optional[Decimal] = None
|
||||
avg_trade_pnl: Optional[Decimal] = None
|
||||
avg_win: Optional[Decimal] = None
|
||||
avg_loss: Optional[Decimal] = None
|
||||
largest_win: Optional[Decimal] = None
|
||||
largest_loss: Optional[Decimal] = None
|
||||
avg_holding_period_days: Optional[float] = None
|
||||
|
||||
total_commission: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
total_slippage: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
|
||||
trading_days: int = Field(default=0, ge=0)
|
||||
start_equity: Decimal = Field(gt=0)
|
||||
end_equity: Decimal = Field(gt=0)
|
||||
|
||||
def to_summary_dict(self) -> dict:
|
||||
return {
|
||||
"Performance": {
|
||||
"Total Return": f"{self.total_return_percent:.2f}%",
|
||||
"Annualized Return": f"{self.annualized_return:.2f}%",
|
||||
"Sharpe Ratio": f"{self.sharpe_ratio:.2f}" if self.sharpe_ratio else "N/A",
|
||||
"Sortino Ratio": f"{self.sortino_ratio:.2f}" if self.sortino_ratio else "N/A",
|
||||
"Max Drawdown": f"{self.max_drawdown_percent:.2f}%",
|
||||
},
|
||||
"Risk": {
|
||||
"Volatility (Ann.)": f"{self.annualized_volatility:.2f}%",
|
||||
"Calmar Ratio": f"{self.calmar_ratio:.2f}" if self.calmar_ratio else "N/A",
|
||||
"Beta": f"{self.beta:.2f}" if self.beta else "N/A",
|
||||
},
|
||||
"Trading": {
|
||||
"Total Trades": self.total_trades,
|
||||
"Win Rate": f"{self.win_rate:.1f}%" if self.win_rate else "N/A",
|
||||
"Profit Factor": f"{self.profit_factor:.2f}" if self.profit_factor else "N/A",
|
||||
"Avg Holding Period": f"{self.avg_holding_period_days:.1f} days" if self.avg_holding_period_days else "N/A",
|
||||
},
|
||||
"Costs": {
|
||||
"Total Commission": f"${self.total_commission:.2f}",
|
||||
"Total Slippage": f"${self.total_slippage:.2f}",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class BacktestResult(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
config: BacktestConfig
|
||||
metrics: BacktestMetrics
|
||||
trade_log: TradeLog
|
||||
equity_curve: list[EquityCurvePoint] = Field(default_factory=list)
|
||||
daily_returns: list[Decimal] = Field(default_factory=list)
|
||||
started_at: datetime
|
||||
completed_at: datetime
|
||||
status: str = Field(default="completed")
|
||||
error_message: Optional[str] = None
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def duration_seconds(self) -> float:
|
||||
return (self.completed_at - self.started_at).total_seconds()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"config": {
|
||||
"name": self.config.name,
|
||||
"tickers": self.config.tickers,
|
||||
"start_date": self.config.start_date.isoformat(),
|
||||
"end_date": self.config.end_date.isoformat(),
|
||||
"initial_cash": float(self.config.portfolio_config.initial_cash),
|
||||
},
|
||||
"metrics": self.metrics.to_summary_dict(),
|
||||
"trade_summary": {
|
||||
"total_trades": self.trade_log.total_trades,
|
||||
"winning_trades": self.trade_log.winning_trades,
|
||||
"losing_trades": self.trade_log.losing_trades,
|
||||
"win_rate": float(self.trade_log.win_rate) if self.trade_log.win_rate else None,
|
||||
},
|
||||
"duration_seconds": self.duration_seconds,
|
||||
"status": self.status,
|
||||
}
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SignalType(str, Enum):
|
||||
STRONG_BUY = "strong_buy"
|
||||
BUY = "buy"
|
||||
HOLD = "hold"
|
||||
SELL = "sell"
|
||||
STRONG_SELL = "strong_sell"
|
||||
|
||||
|
||||
class AnalystType(str, Enum):
|
||||
MARKET = "market"
|
||||
SENTIMENT = "sentiment"
|
||||
NEWS = "news"
|
||||
FUNDAMENTALS = "fundamentals"
|
||||
|
||||
|
||||
class AnalystReport(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
analyst_type: AnalystType
|
||||
ticker: str
|
||||
report_date: datetime
|
||||
signal: Optional[SignalType] = None
|
||||
confidence: Decimal = Field(default=Decimal("0.5"), ge=0, le=1)
|
||||
summary: str
|
||||
key_findings: list[str] = Field(default_factory=list)
|
||||
raw_content: Optional[str] = None
|
||||
data_sources: list[str] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class TradingSignal(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
ticker: str
|
||||
timestamp: datetime
|
||||
signal: SignalType
|
||||
strength: Decimal = Field(ge=0, le=1)
|
||||
source: str
|
||||
timeframe: str = Field(default="1d")
|
||||
price_at_signal: Optional[Decimal] = None
|
||||
target_price: Optional[Decimal] = None
|
||||
stop_loss: Optional[Decimal] = None
|
||||
expiry: Optional[datetime] = None
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class RiskAssessment(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
ticker: str
|
||||
timestamp: datetime
|
||||
overall_risk_score: Decimal = Field(ge=0, le=1)
|
||||
|
||||
market_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1)
|
||||
liquidity_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1)
|
||||
volatility_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1)
|
||||
concentration_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1)
|
||||
event_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1)
|
||||
|
||||
max_position_size: Optional[Decimal] = None
|
||||
recommended_stop_loss: Optional[Decimal] = None
|
||||
var_95: Optional[Decimal] = None
|
||||
expected_shortfall: Optional[Decimal] = None
|
||||
|
||||
risk_factors: list[str] = Field(default_factory=list)
|
||||
mitigations: list[str] = Field(default_factory=list)
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class TradingDecision(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
ticker: str
|
||||
timestamp: datetime
|
||||
decision_date: datetime
|
||||
|
||||
signal: SignalType
|
||||
confidence: Decimal = Field(ge=0, le=1)
|
||||
|
||||
recommended_action: str
|
||||
recommended_quantity: Optional[int] = None
|
||||
recommended_price: Optional[Decimal] = None
|
||||
stop_loss: Optional[Decimal] = None
|
||||
take_profit: Optional[Decimal] = None
|
||||
|
||||
analyst_reports: list[AnalystReport] = Field(default_factory=list)
|
||||
signals: list[TradingSignal] = Field(default_factory=list)
|
||||
risk_assessment: Optional[RiskAssessment] = None
|
||||
|
||||
bull_argument: Optional[str] = None
|
||||
bear_argument: Optional[str] = None
|
||||
debate_rounds: int = Field(default=0, ge=0)
|
||||
debate_winner: Optional[str] = None
|
||||
|
||||
risk_manager_approved: Optional[bool] = None
|
||||
risk_manager_notes: Optional[str] = None
|
||||
|
||||
final_decision: str
|
||||
rationale: str
|
||||
|
||||
execution_price: Optional[Decimal] = None
|
||||
executed_at: Optional[datetime] = None
|
||||
execution_notes: Optional[str] = None
|
||||
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
@property
|
||||
def is_buy(self) -> bool:
|
||||
return self.signal in (SignalType.BUY, SignalType.STRONG_BUY)
|
||||
|
||||
@property
|
||||
def is_sell(self) -> bool:
|
||||
return self.signal in (SignalType.SELL, SignalType.STRONG_SELL)
|
||||
|
||||
@property
|
||||
def is_hold(self) -> bool:
|
||||
return self.signal == SignalType.HOLD
|
||||
|
||||
def get_analyst_report(self, analyst_type: AnalystType) -> Optional[AnalystReport]:
|
||||
for report in self.analyst_reports:
|
||||
if report.analyst_type == analyst_type:
|
||||
return report
|
||||
return None
|
||||
|
||||
def to_summary(self) -> dict:
|
||||
return {
|
||||
"ticker": self.ticker,
|
||||
"date": self.decision_date.isoformat(),
|
||||
"signal": self.signal.value,
|
||||
"confidence": float(self.confidence),
|
||||
"final_decision": self.final_decision,
|
||||
"risk_approved": self.risk_manager_approved,
|
||||
"debate_rounds": self.debate_rounds,
|
||||
"analyst_consensus": self._calculate_consensus(),
|
||||
}
|
||||
|
||||
def _calculate_consensus(self) -> Optional[str]:
|
||||
if not self.analyst_reports:
|
||||
return None
|
||||
|
||||
signals = [r.signal for r in self.analyst_reports if r.signal]
|
||||
if not signals:
|
||||
return None
|
||||
|
||||
buy_count = sum(1 for s in signals if s in (SignalType.BUY, SignalType.STRONG_BUY))
|
||||
sell_count = sum(1 for s in signals if s in (SignalType.SELL, SignalType.STRONG_SELL))
|
||||
|
||||
if buy_count > sell_count:
|
||||
return "bullish"
|
||||
elif sell_count > buy_count:
|
||||
return "bearish"
|
||||
return "neutral"
|
||||
|
|
@ -0,0 +1,144 @@
|
|||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class OHLCVBar(BaseModel):
|
||||
timestamp: datetime
|
||||
open: Decimal = Field(gt=0)
|
||||
high: Decimal = Field(gt=0)
|
||||
low: Decimal = Field(gt=0)
|
||||
close: Decimal = Field(gt=0)
|
||||
volume: int = Field(ge=0)
|
||||
adjusted_close: Optional[Decimal] = Field(default=None, gt=0)
|
||||
|
||||
@field_validator("high")
|
||||
@classmethod
|
||||
def high_gte_low(cls, v: Decimal, info) -> Decimal:
|
||||
if "low" in info.data and v < info.data["low"]:
|
||||
raise ValueError("high must be >= low")
|
||||
return v
|
||||
|
||||
@field_validator("high")
|
||||
@classmethod
|
||||
def high_gte_open_close(cls, v: Decimal, info) -> Decimal:
|
||||
if "open" in info.data and v < info.data["open"]:
|
||||
raise ValueError("high must be >= open")
|
||||
if "close" in info.data and v < info.data["close"]:
|
||||
raise ValueError("high must be >= close")
|
||||
return v
|
||||
|
||||
@field_validator("low")
|
||||
@classmethod
|
||||
def low_lte_open_close(cls, v: Decimal, info) -> Decimal:
|
||||
if "open" in info.data and v > info.data["open"]:
|
||||
raise ValueError("low must be <= open")
|
||||
if "close" in info.data and v > info.data["close"]:
|
||||
raise ValueError("low must be <= close")
|
||||
return v
|
||||
|
||||
|
||||
class OHLCV(BaseModel):
|
||||
ticker: str = Field(min_length=1, max_length=10)
|
||||
bars: list[OHLCVBar] = Field(default_factory=list)
|
||||
interval: str = Field(default="1d")
|
||||
currency: str = Field(default="USD")
|
||||
|
||||
@property
|
||||
def start_date(self) -> Optional[datetime]:
|
||||
return self.bars[0].timestamp if self.bars else None
|
||||
|
||||
@property
|
||||
def end_date(self) -> Optional[datetime]:
|
||||
return self.bars[-1].timestamp if self.bars else None
|
||||
|
||||
def get_bar(self, dt: datetime) -> Optional[OHLCVBar]:
|
||||
for bar in self.bars:
|
||||
if bar.timestamp.date() == dt.date():
|
||||
return bar
|
||||
return None
|
||||
|
||||
def slice(self, start: datetime, end: datetime) -> "OHLCV":
|
||||
filtered = [b for b in self.bars if start <= b.timestamp <= end]
|
||||
return OHLCV(
|
||||
ticker=self.ticker,
|
||||
bars=filtered,
|
||||
interval=self.interval,
|
||||
currency=self.currency,
|
||||
)
|
||||
|
||||
|
||||
class TechnicalIndicators(BaseModel):
|
||||
timestamp: datetime
|
||||
ticker: str
|
||||
|
||||
sma_20: Optional[Decimal] = None
|
||||
sma_50: Optional[Decimal] = None
|
||||
sma_200: Optional[Decimal] = None
|
||||
|
||||
ema_10: Optional[Decimal] = None
|
||||
ema_20: Optional[Decimal] = None
|
||||
|
||||
rsi_14: Optional[Decimal] = Field(default=None, ge=0, le=100)
|
||||
|
||||
macd: Optional[Decimal] = None
|
||||
macd_signal: Optional[Decimal] = None
|
||||
macd_histogram: Optional[Decimal] = None
|
||||
|
||||
bollinger_upper: Optional[Decimal] = None
|
||||
bollinger_middle: Optional[Decimal] = None
|
||||
bollinger_lower: Optional[Decimal] = None
|
||||
|
||||
atr_14: Optional[Decimal] = Field(default=None, ge=0)
|
||||
|
||||
mfi_14: Optional[Decimal] = Field(default=None, ge=0, le=100)
|
||||
|
||||
vwap: Optional[Decimal] = None
|
||||
|
||||
obv: Optional[int] = None
|
||||
|
||||
|
||||
class MarketSnapshot(BaseModel):
|
||||
ticker: str
|
||||
timestamp: datetime
|
||||
bar: OHLCVBar
|
||||
indicators: Optional[TechnicalIndicators] = None
|
||||
prev_close: Optional[Decimal] = None
|
||||
|
||||
@property
|
||||
def change(self) -> Optional[Decimal]:
|
||||
if self.prev_close:
|
||||
return self.bar.close - self.prev_close
|
||||
return None
|
||||
|
||||
@property
|
||||
def change_percent(self) -> Optional[Decimal]:
|
||||
if self.prev_close and self.prev_close > 0:
|
||||
return ((self.bar.close - self.prev_close) / self.prev_close) * 100
|
||||
return None
|
||||
|
||||
|
||||
class HistoricalDataRequest(BaseModel):
|
||||
ticker: str = Field(min_length=1, max_length=10)
|
||||
start_date: date
|
||||
end_date: date
|
||||
interval: str = Field(default="1d")
|
||||
include_indicators: bool = Field(default=True)
|
||||
adjusted: bool = Field(default=True)
|
||||
|
||||
@field_validator("end_date")
|
||||
@classmethod
|
||||
def end_after_start(cls, v: date, info) -> date:
|
||||
if "start_date" in info.data and v < info.data["start_date"]:
|
||||
raise ValueError("end_date must be >= start_date")
|
||||
return v
|
||||
|
||||
|
||||
class HistoricalDataResponse(BaseModel):
|
||||
request: HistoricalDataRequest
|
||||
ohlcv: OHLCV
|
||||
indicators: list[TechnicalIndicators] = Field(default_factory=list)
|
||||
fetched_at: datetime = Field(default_factory=datetime.now)
|
||||
source: str = Field(default="unknown")
|
||||
|
|
@ -0,0 +1,172 @@
|
|||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
from .trading import Position, Fill, OrderSide
|
||||
|
||||
|
||||
class TransactionType(str, Enum):
|
||||
DEPOSIT = "deposit"
|
||||
WITHDRAWAL = "withdrawal"
|
||||
DIVIDEND = "dividend"
|
||||
INTEREST = "interest"
|
||||
FEE = "fee"
|
||||
TRANSFER_IN = "transfer_in"
|
||||
TRANSFER_OUT = "transfer_out"
|
||||
|
||||
|
||||
class CashTransaction(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
transaction_type: TransactionType
|
||||
amount: Decimal
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
description: Optional[str] = None
|
||||
reference_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class PortfolioConfig(BaseModel):
|
||||
initial_cash: Decimal = Field(default=Decimal("100000"), gt=0)
|
||||
commission_per_share: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
commission_per_trade: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
commission_percent: Decimal = Field(default=Decimal("0"), ge=0, le=100)
|
||||
min_commission: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
max_commission: Optional[Decimal] = Field(default=None, ge=0)
|
||||
slippage_percent: Decimal = Field(default=Decimal("0"), ge=0, le=100)
|
||||
margin_enabled: bool = Field(default=False)
|
||||
margin_rate: Decimal = Field(default=Decimal("0.05"), ge=0)
|
||||
max_position_size_percent: Decimal = Field(default=Decimal("100"), gt=0, le=100)
|
||||
allow_fractional_shares: bool = Field(default=False)
|
||||
|
||||
def calculate_commission(self, quantity: int, price: Decimal) -> Decimal:
|
||||
trade_value = quantity * price
|
||||
commission = Decimal("0")
|
||||
|
||||
commission += self.commission_per_trade
|
||||
commission += self.commission_per_share * quantity
|
||||
commission += trade_value * (self.commission_percent / 100)
|
||||
|
||||
if commission < self.min_commission:
|
||||
commission = self.min_commission
|
||||
if self.max_commission and commission > self.max_commission:
|
||||
commission = self.max_commission
|
||||
|
||||
return commission
|
||||
|
||||
def calculate_slippage(self, price: Decimal, side: OrderSide) -> Decimal:
|
||||
slippage = price * (self.slippage_percent / 100)
|
||||
if side == OrderSide.BUY:
|
||||
return price + slippage
|
||||
return price - slippage
|
||||
|
||||
|
||||
class PortfolioSnapshot(BaseModel):
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
cash: Decimal = Field(default=Decimal("0"))
|
||||
positions: dict[str, Position] = Field(default_factory=dict)
|
||||
pending_orders: int = Field(default=0, ge=0)
|
||||
total_commission_paid: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
total_realized_pnl: Decimal = Field(default=Decimal("0"))
|
||||
cash_transactions: list[CashTransaction] = Field(default_factory=list)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def position_count(self) -> int:
|
||||
return len([p for p in self.positions.values() if p.quantity != 0])
|
||||
|
||||
def positions_value(self, prices: dict[str, Decimal]) -> Decimal:
|
||||
total = Decimal("0")
|
||||
for ticker, position in self.positions.items():
|
||||
if ticker in prices and position.quantity != 0:
|
||||
total += position.market_value(prices[ticker])
|
||||
return total
|
||||
|
||||
def total_equity(self, prices: dict[str, Decimal]) -> Decimal:
|
||||
return self.cash + self.positions_value(prices)
|
||||
|
||||
def total_unrealized_pnl(self, prices: dict[str, Decimal]) -> Decimal:
|
||||
total = Decimal("0")
|
||||
for ticker, position in self.positions.items():
|
||||
if ticker in prices:
|
||||
total += position.unrealized_pnl(prices[ticker])
|
||||
return total
|
||||
|
||||
def get_position(self, ticker: str) -> Position:
|
||||
if ticker not in self.positions:
|
||||
self.positions[ticker] = Position(ticker=ticker)
|
||||
return self.positions[ticker]
|
||||
|
||||
def apply_fill(self, fill: Fill) -> None:
|
||||
position = self.get_position(fill.ticker)
|
||||
old_realized = position.realized_pnl
|
||||
|
||||
position.update_from_fill(fill)
|
||||
|
||||
pnl_change = position.realized_pnl - old_realized
|
||||
self.total_realized_pnl += pnl_change
|
||||
self.total_commission_paid += fill.commission
|
||||
|
||||
if fill.side == OrderSide.BUY:
|
||||
self.cash -= fill.total_cost
|
||||
else:
|
||||
self.cash += fill.total_value - fill.commission
|
||||
|
||||
def add_cash_transaction(self, transaction: CashTransaction) -> None:
|
||||
self.cash_transactions.append(transaction)
|
||||
if transaction.transaction_type in (
|
||||
TransactionType.DEPOSIT,
|
||||
TransactionType.DIVIDEND,
|
||||
TransactionType.INTEREST,
|
||||
TransactionType.TRANSFER_IN,
|
||||
):
|
||||
self.cash += transaction.amount
|
||||
else:
|
||||
self.cash -= abs(transaction.amount)
|
||||
|
||||
def can_afford(
|
||||
self, ticker: str, quantity: int, price: Decimal, config: PortfolioConfig
|
||||
) -> bool:
|
||||
execution_price = config.calculate_slippage(price, OrderSide.BUY)
|
||||
commission = config.calculate_commission(quantity, execution_price)
|
||||
total_cost = (quantity * execution_price) + commission
|
||||
return self.cash >= total_cost
|
||||
|
||||
def max_shares_affordable(
|
||||
self, ticker: str, price: Decimal, config: PortfolioConfig
|
||||
) -> int:
|
||||
if price <= 0:
|
||||
return 0
|
||||
|
||||
execution_price = config.calculate_slippage(price, OrderSide.BUY)
|
||||
available = self.cash
|
||||
|
||||
low, high = 0, int(available / execution_price) + 1
|
||||
result = 0
|
||||
|
||||
while low <= high:
|
||||
mid = (low + high) // 2
|
||||
commission = config.calculate_commission(mid, execution_price)
|
||||
total_cost = (mid * execution_price) + commission
|
||||
|
||||
if total_cost <= available:
|
||||
result = mid
|
||||
low = mid + 1
|
||||
else:
|
||||
high = mid - 1
|
||||
|
||||
return result
|
||||
|
||||
def to_dict(self, prices: dict[str, Decimal]) -> dict:
|
||||
return {
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"cash": float(self.cash),
|
||||
"positions_value": float(self.positions_value(prices)),
|
||||
"total_equity": float(self.total_equity(prices)),
|
||||
"position_count": self.position_count,
|
||||
"total_realized_pnl": float(self.total_realized_pnl),
|
||||
"total_unrealized_pnl": float(self.total_unrealized_pnl(prices)),
|
||||
"total_commission_paid": float(self.total_commission_paid),
|
||||
}
|
||||
|
|
@ -0,0 +1,201 @@
|
|||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
|
||||
class OrderSide(str, Enum):
|
||||
BUY = "buy"
|
||||
SELL = "sell"
|
||||
|
||||
|
||||
class OrderType(str, Enum):
|
||||
MARKET = "market"
|
||||
LIMIT = "limit"
|
||||
STOP = "stop"
|
||||
STOP_LIMIT = "stop_limit"
|
||||
|
||||
|
||||
class OrderStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
SUBMITTED = "submitted"
|
||||
PARTIAL = "partial"
|
||||
FILLED = "filled"
|
||||
CANCELLED = "cancelled"
|
||||
REJECTED = "rejected"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class PositionSide(str, Enum):
|
||||
LONG = "long"
|
||||
SHORT = "short"
|
||||
FLAT = "flat"
|
||||
|
||||
|
||||
class Order(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
ticker: str = Field(min_length=1, max_length=10)
|
||||
side: OrderSide
|
||||
order_type: OrderType = Field(default=OrderType.MARKET)
|
||||
quantity: int = Field(gt=0)
|
||||
limit_price: Optional[Decimal] = Field(default=None, gt=0)
|
||||
stop_price: Optional[Decimal] = Field(default=None, gt=0)
|
||||
status: OrderStatus = Field(default=OrderStatus.PENDING)
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
submitted_at: Optional[datetime] = None
|
||||
filled_at: Optional[datetime] = None
|
||||
filled_quantity: int = Field(default=0, ge=0)
|
||||
filled_avg_price: Optional[Decimal] = None
|
||||
commission: Decimal = Field(default=Decimal("0"))
|
||||
notes: Optional[str] = None
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def remaining_quantity(self) -> int:
|
||||
return self.quantity - self.filled_quantity
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
return self.status in (
|
||||
OrderStatus.FILLED,
|
||||
OrderStatus.CANCELLED,
|
||||
OrderStatus.REJECTED,
|
||||
OrderStatus.EXPIRED,
|
||||
)
|
||||
|
||||
|
||||
class Fill(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
order_id: UUID
|
||||
ticker: str
|
||||
side: OrderSide
|
||||
quantity: int = Field(gt=0)
|
||||
price: Decimal = Field(gt=0)
|
||||
commission: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def total_value(self) -> Decimal:
|
||||
return self.price * self.quantity
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def total_cost(self) -> Decimal:
|
||||
if self.side == OrderSide.BUY:
|
||||
return self.total_value + self.commission
|
||||
return self.total_value - self.commission
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
ticker: str = Field(min_length=1, max_length=10)
|
||||
quantity: int = Field(default=0)
|
||||
avg_cost: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
realized_pnl: Decimal = Field(default=Decimal("0"))
|
||||
opened_at: Optional[datetime] = None
|
||||
last_updated: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def side(self) -> PositionSide:
|
||||
if self.quantity > 0:
|
||||
return PositionSide.LONG
|
||||
elif self.quantity < 0:
|
||||
return PositionSide.SHORT
|
||||
return PositionSide.FLAT
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def cost_basis(self) -> Decimal:
|
||||
return abs(self.quantity) * self.avg_cost
|
||||
|
||||
def unrealized_pnl(self, current_price: Decimal) -> Decimal:
|
||||
if self.quantity == 0:
|
||||
return Decimal("0")
|
||||
market_value = self.quantity * current_price
|
||||
return market_value - (self.quantity * self.avg_cost)
|
||||
|
||||
def market_value(self, current_price: Decimal) -> Decimal:
|
||||
return abs(self.quantity) * current_price
|
||||
|
||||
def update_from_fill(self, fill: Fill) -> None:
|
||||
if fill.side == OrderSide.BUY:
|
||||
if self.quantity >= 0:
|
||||
total_cost = (self.quantity * self.avg_cost) + fill.total_value
|
||||
self.quantity += fill.quantity
|
||||
self.avg_cost = total_cost / self.quantity if self.quantity else Decimal("0")
|
||||
else:
|
||||
close_qty = min(fill.quantity, abs(self.quantity))
|
||||
pnl = close_qty * (self.avg_cost - fill.price)
|
||||
self.realized_pnl += pnl
|
||||
self.quantity += fill.quantity
|
||||
if self.quantity > 0:
|
||||
self.avg_cost = fill.price
|
||||
else:
|
||||
if self.quantity <= 0:
|
||||
total_cost = (abs(self.quantity) * self.avg_cost) + fill.total_value
|
||||
self.quantity -= fill.quantity
|
||||
self.avg_cost = total_cost / abs(self.quantity) if self.quantity else Decimal("0")
|
||||
else:
|
||||
close_qty = min(fill.quantity, self.quantity)
|
||||
pnl = close_qty * (fill.price - self.avg_cost)
|
||||
self.realized_pnl += pnl
|
||||
self.quantity -= fill.quantity
|
||||
if self.quantity < 0:
|
||||
self.avg_cost = fill.price
|
||||
|
||||
if self.quantity != 0 and self.opened_at is None:
|
||||
self.opened_at = fill.timestamp
|
||||
elif self.quantity == 0:
|
||||
self.opened_at = None
|
||||
|
||||
self.last_updated = fill.timestamp
|
||||
|
||||
|
||||
class Trade(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4)
|
||||
ticker: str
|
||||
side: OrderSide
|
||||
entry_price: Decimal = Field(gt=0)
|
||||
entry_quantity: int = Field(gt=0)
|
||||
entry_time: datetime
|
||||
exit_price: Optional[Decimal] = Field(default=None, gt=0)
|
||||
exit_quantity: Optional[int] = Field(default=None, gt=0)
|
||||
exit_time: Optional[datetime] = None
|
||||
commission: Decimal = Field(default=Decimal("0"), ge=0)
|
||||
entry_order_id: Optional[UUID] = None
|
||||
exit_order_id: Optional[UUID] = None
|
||||
notes: Optional[str] = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self.exit_price is not None and self.exit_quantity is not None
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def pnl(self) -> Optional[Decimal]:
|
||||
if not self.is_closed:
|
||||
return None
|
||||
if self.side == OrderSide.BUY:
|
||||
return (self.exit_price - self.entry_price) * self.exit_quantity - self.commission
|
||||
return (self.entry_price - self.exit_price) * self.exit_quantity - self.commission
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def pnl_percent(self) -> Optional[Decimal]:
|
||||
if not self.is_closed or self.entry_price == 0:
|
||||
return None
|
||||
return (self.pnl / (self.entry_price * self.entry_quantity)) * 100
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def holding_period(self) -> Optional[int]:
|
||||
if not self.exit_time:
|
||||
return None
|
||||
return (self.exit_time - self.entry_time).days
|
||||
Loading…
Reference in New Issue