feat: add backtesting framework and fix code quality issues

- Add complete backtesting engine with portfolio simulation, metrics calculation
  (Sharpe, Sortino, max drawdown), and agent integration
- Add Pydantic data models for market data, trading, portfolio, and backtest results
- Add backtest CLI command with SMA, RSI, and hold strategies
- Fix 24+ bare exception handlers with specific exception types
- Fix hardcoded path in default_config.py (use TRADINGAGENTS_DATA_DIR env var)
- Fix unclosed file handle in local.py with context manager
- Disable store=True in OpenAI API calls for data privacy
- Fix typo: rename aggresive_debator.py to aggressive_debator.py
- Add request timeouts (30s) to alpha_vantage_common.py and googlenews_utils.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Joseph O'Brien 2025-12-03 02:55:28 -05:00
parent 69c3a13883
commit f70874982a
34 changed files with 3409 additions and 33 deletions

View File

@ -34,6 +34,9 @@ from tradingagents.agents.discovery.models import (
EventCategory,
)
from tradingagents.agents.discovery.persistence import save_discovery_result
from tradingagents.backtesting import SimpleBacktestEngine, DataLoader
from tradingagents.models.backtest import BacktestConfig
from tradingagents.models.portfolio import PortfolioConfig
from cli.models import AnalystType
from cli.utils import (
ANALYST_ORDER,
@ -1715,6 +1718,195 @@ def menu():
discover_trending_flow()
@app.command()
def backtest(
ticker: str = typer.Option(None, "--ticker", "-t", help="Ticker symbol to backtest"),
start_date: str = typer.Option(None, "--start", "-s", help="Start date (YYYY-MM-DD)"),
end_date: str = typer.Option(None, "--end", "-e", help="End date (YYYY-MM-DD)"),
initial_cash: float = typer.Option(100000.0, "--cash", "-c", help="Initial portfolio cash"),
strategy: str = typer.Option("sma", "--strategy", help="Strategy: sma, rsi, or hold"),
):
from decimal import Decimal
from datetime import date as date_type
if not ticker:
console.print(create_question_box("Ticker Symbol", "Enter the ticker symbol to backtest", "AAPL"))
ticker = typer.prompt("", default="AAPL")
if not start_date:
default_start = (datetime.datetime.now() - datetime.timedelta(days=365)).strftime("%Y-%m-%d")
console.print(create_question_box("Start Date", "Enter backtest start date (YYYY-MM-DD)", default_start))
start_date = typer.prompt("", default=default_start)
if not end_date:
default_end = datetime.datetime.now().strftime("%Y-%m-%d")
console.print(create_question_box("End Date", "Enter backtest end date (YYYY-MM-DD)", default_end))
end_date = typer.prompt("", default=default_end)
try:
start = datetime.datetime.strptime(start_date, "%Y-%m-%d").date()
end = datetime.datetime.strptime(end_date, "%Y-%m-%d").date()
except ValueError:
console.print("[red]Invalid date format. Use YYYY-MM-DD[/red]")
return
if start >= end:
console.print("[red]Start date must be before end date[/red]")
return
console.print()
console.print(Panel(
f"[bold]Backtest Configuration[/bold]\n\n"
f"Ticker: [cyan]{ticker.upper()}[/cyan]\n"
f"Period: [cyan]{start_date}[/cyan] to [cyan]{end_date}[/cyan]\n"
f"Initial Cash: [cyan]${initial_cash:,.2f}[/cyan]\n"
f"Strategy: [cyan]{strategy}[/cyan]",
title="Configuration",
border_style="blue",
))
console.print()
def sma_buy(t, trading_date, ctx):
loader = ctx["data_loader"]
ohlcv = loader.load_ohlcv(t, date_type(2020, 1, 1), trading_date)
if len(ohlcv.bars) < 20:
return False
prices = [float(b.close) for b in ohlcv.bars[-20:]]
sma = sum(prices) / len(prices)
current = float(ohlcv.bars[-1].close)
return current > sma * 1.02
def sma_sell(t, trading_date, ctx):
loader = ctx["data_loader"]
ohlcv = loader.load_ohlcv(t, date_type(2020, 1, 1), trading_date)
if len(ohlcv.bars) < 20:
return False
prices = [float(b.close) for b in ohlcv.bars[-20:]]
sma = sum(prices) / len(prices)
current = float(ohlcv.bars[-1].close)
return current < sma * 0.98
def rsi_buy(t, trading_date, ctx):
loader = ctx["data_loader"]
ohlcv = loader.load_ohlcv(t, date_type(2020, 1, 1), trading_date)
if len(ohlcv.bars) < 15:
return False
changes = []
for i in range(1, min(15, len(ohlcv.bars))):
changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i-1].close))
gains = [c for c in changes if c > 0]
losses = [-c for c in changes if c < 0]
avg_gain = sum(gains) / 14 if gains else 0.001
avg_loss = sum(losses) / 14 if losses else 0.001
rs = avg_gain / avg_loss if avg_loss else 100
rsi = 100 - (100 / (1 + rs))
return rsi < 30
def rsi_sell(t, trading_date, ctx):
loader = ctx["data_loader"]
ohlcv = loader.load_ohlcv(t, date_type(2020, 1, 1), trading_date)
if len(ohlcv.bars) < 15:
return False
changes = []
for i in range(1, min(15, len(ohlcv.bars))):
changes.append(float(ohlcv.bars[-i].close) - float(ohlcv.bars[-i-1].close))
gains = [c for c in changes if c > 0]
losses = [-c for c in changes if c < 0]
avg_gain = sum(gains) / 14 if gains else 0.001
avg_loss = sum(losses) / 14 if losses else 0.001
rs = avg_gain / avg_loss if avg_loss else 100
rsi = 100 - (100 / (1 + rs))
return rsi > 70
def hold_buy(t, trading_date, ctx):
return ctx.get("day_index", 0) == 5
def hold_sell(t, trading_date, ctx):
return False
strategies = {
"sma": (sma_buy, sma_sell),
"rsi": (rsi_buy, rsi_sell),
"hold": (hold_buy, hold_sell),
}
if strategy not in strategies:
console.print(f"[red]Unknown strategy: {strategy}. Use: sma, rsi, or hold[/red]")
return
buy_fn, sell_fn = strategies[strategy]
config = BacktestConfig(
name=f"{strategy.upper()} Backtest - {ticker.upper()}",
tickers=[ticker.upper()],
start_date=start,
end_date=end,
portfolio_config=PortfolioConfig(
initial_cash=Decimal(str(initial_cash)),
commission_per_trade=Decimal("1"),
slippage_percent=Decimal("0.05"),
),
warmup_period=5,
)
with loading("Running backtest...", show_elapsed=True):
engine = SimpleBacktestEngine(config, buy_signal=buy_fn, sell_signal=sell_fn)
result = engine.run()
console.print()
if result.status == "failed":
console.print(f"[red]Backtest failed: {result.error_message}[/red]")
return
metrics = result.metrics
trade_log = result.trade_log
performance_table = Table(title="Performance Metrics", box=box.ROUNDED)
performance_table.add_column("Metric", style="cyan")
performance_table.add_column("Value", style="green")
performance_table.add_row("Total Return", f"${float(metrics.total_return):,.2f}")
performance_table.add_row("Total Return %", f"{float(metrics.total_return_percent):.2f}%")
performance_table.add_row("Annualized Return", f"{float(metrics.annualized_return):.2f}%")
performance_table.add_row("Sharpe Ratio", f"{float(metrics.sharpe_ratio):.2f}" if metrics.sharpe_ratio else "N/A")
performance_table.add_row("Sortino Ratio", f"{float(metrics.sortino_ratio):.2f}" if metrics.sortino_ratio else "N/A")
performance_table.add_row("Max Drawdown", f"{float(metrics.max_drawdown_percent):.2f}%")
performance_table.add_row("Volatility (Ann.)", f"{float(metrics.annualized_volatility):.2f}%")
console.print(performance_table)
console.print()
trading_table = Table(title="Trading Statistics", box=box.ROUNDED)
trading_table.add_column("Metric", style="cyan")
trading_table.add_column("Value", style="green")
trading_table.add_row("Total Trades", str(trade_log.total_trades))
trading_table.add_row("Winning Trades", str(trade_log.winning_trades))
trading_table.add_row("Losing Trades", str(trade_log.losing_trades))
trading_table.add_row("Win Rate", f"{float(trade_log.win_rate):.1f}%" if trade_log.win_rate else "N/A")
trading_table.add_row("Profit Factor", f"{float(trade_log.profit_factor):.2f}" if trade_log.profit_factor else "N/A")
trading_table.add_row("Avg Win", f"${float(trade_log.avg_win):,.2f}" if trade_log.avg_win else "N/A")
trading_table.add_row("Avg Loss", f"${float(trade_log.avg_loss):,.2f}" if trade_log.avg_loss else "N/A")
console.print(trading_table)
console.print()
summary_table = Table(title="Portfolio Summary", box=box.ROUNDED)
summary_table.add_column("Metric", style="cyan")
summary_table.add_column("Value", style="green")
summary_table.add_row("Start Equity", f"${float(metrics.start_equity):,.2f}")
summary_table.add_row("End Equity", f"${float(metrics.end_equity):,.2f}")
summary_table.add_row("Trading Days", str(metrics.trading_days))
summary_table.add_row("Duration", f"{result.duration_seconds:.1f} seconds")
console.print(summary_table)
console.print()
console.print(f"[green]Backtest completed successfully![/green]")
if __name__ == "__main__":
choice = show_main_menu()
if choice == "analyze":

0
tests/models/__init__.py Normal file
View File

View File

@ -0,0 +1,318 @@
from datetime import date, datetime
from decimal import Decimal
import pytest
from tradingagents.models.backtest import (
BacktestConfig,
BacktestResult,
BacktestMetrics,
EquityCurvePoint,
TradeLog,
)
from tradingagents.models.portfolio import PortfolioConfig
from tradingagents.models.trading import Trade, OrderSide
class TestBacktestConfig:
def test_basic_config(self):
config = BacktestConfig(
tickers=["AAPL"],
start_date=date(2024, 1, 1),
end_date=date(2024, 6, 30),
)
assert config.tickers == ["AAPL"]
assert config.interval == "1d"
assert config.benchmark_ticker == "SPY"
def test_multi_ticker_config(self):
config = BacktestConfig(
name="Multi-Stock Test",
tickers=["aapl", "googl", "msft"],
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
)
assert config.tickers == ["AAPL", "GOOGL", "MSFT"]
def test_invalid_date_range(self):
with pytest.raises(ValueError):
BacktestConfig(
tickers=["AAPL"],
start_date=date(2024, 6, 30),
end_date=date(2024, 1, 1),
)
def test_same_start_end_date(self):
with pytest.raises(ValueError):
BacktestConfig(
tickers=["AAPL"],
start_date=date(2024, 1, 1),
end_date=date(2024, 1, 1),
)
def test_empty_tickers(self):
with pytest.raises(ValueError):
BacktestConfig(
tickers=[],
start_date=date(2024, 1, 1),
end_date=date(2024, 6, 30),
)
def test_trading_days_estimate(self):
config = BacktestConfig(
tickers=["AAPL"],
start_date=date(2024, 1, 1),
end_date=date(2024, 12, 31),
)
assert config.trading_days_estimate > 200
assert config.trading_days_estimate < 260
def test_custom_portfolio_config(self):
portfolio_config = PortfolioConfig(
initial_cash=Decimal("50000"),
commission_per_trade=Decimal("5"),
)
config = BacktestConfig(
tickers=["AAPL"],
start_date=date(2024, 1, 1),
end_date=date(2024, 6, 30),
portfolio_config=portfolio_config,
)
assert config.portfolio_config.initial_cash == Decimal("50000")
class TestTradeLog:
def test_empty_trade_log(self):
log = TradeLog()
assert log.total_trades == 0
assert log.win_rate is None
assert log.profit_factor is None
def test_add_winning_trade(self):
log = TradeLog()
trade = Trade(
ticker="AAPL",
side=OrderSide.BUY,
entry_price=Decimal("100"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15),
exit_price=Decimal("110"),
exit_quantity=100,
exit_time=datetime(2024, 1, 20),
)
log.add_trade(trade)
assert log.total_trades == 1
assert log.winning_trades == 1
assert log.win_rate == Decimal("100")
def test_add_losing_trade(self):
log = TradeLog()
trade = Trade(
ticker="AAPL",
side=OrderSide.BUY,
entry_price=Decimal("100"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15),
exit_price=Decimal("90"),
exit_quantity=100,
exit_time=datetime(2024, 1, 20),
)
log.add_trade(trade)
assert log.total_trades == 1
assert log.losing_trades == 1
assert log.win_rate == Decimal("0")
def test_mixed_trades(self):
log = TradeLog()
win_trade = Trade(
ticker="AAPL",
side=OrderSide.BUY,
entry_price=Decimal("100"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15),
exit_price=Decimal("120"),
exit_quantity=100,
exit_time=datetime(2024, 1, 20),
)
log.add_trade(win_trade)
loss_trade = Trade(
ticker="GOOGL",
side=OrderSide.BUY,
entry_price=Decimal("100"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15),
exit_price=Decimal("90"),
exit_quantity=100,
exit_time=datetime(2024, 1, 20),
)
log.add_trade(loss_trade)
assert log.total_trades == 2
assert log.winning_trades == 1
assert log.losing_trades == 1
assert log.win_rate == Decimal("50")
def test_gross_profit_loss(self):
log = TradeLog()
win_trade = Trade(
ticker="AAPL",
side=OrderSide.BUY,
entry_price=Decimal("100"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15),
exit_price=Decimal("120"),
exit_quantity=100,
exit_time=datetime(2024, 1, 20),
)
log.add_trade(win_trade)
loss_trade = Trade(
ticker="GOOGL",
side=OrderSide.BUY,
entry_price=Decimal("100"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15),
exit_price=Decimal("90"),
exit_quantity=100,
exit_time=datetime(2024, 1, 20),
)
log.add_trade(loss_trade)
assert log.gross_profit == Decimal("2000")
assert log.gross_loss == Decimal("1000")
assert log.profit_factor == Decimal("2")
def test_avg_win_loss(self):
log = TradeLog()
for i in range(3):
trade = Trade(
ticker="AAPL",
side=OrderSide.BUY,
entry_price=Decimal("100"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15),
exit_price=Decimal("110"),
exit_quantity=100,
exit_time=datetime(2024, 1, 20),
)
log.add_trade(trade)
assert log.avg_win == Decimal("1000")
class TestBacktestMetrics:
def test_basic_metrics(self):
metrics = BacktestMetrics(
total_return=Decimal("10000"),
total_return_percent=Decimal("10"),
annualized_return=Decimal("15"),
max_drawdown=Decimal("5000"),
max_drawdown_percent=Decimal("5"),
sharpe_ratio=Decimal("1.5"),
total_trades=50,
win_rate=Decimal("60"),
start_equity=Decimal("100000"),
end_equity=Decimal("110000"),
)
assert metrics.total_return_percent == Decimal("10")
assert metrics.sharpe_ratio == Decimal("1.5")
def test_to_summary_dict(self):
metrics = BacktestMetrics(
total_return=Decimal("10000"),
total_return_percent=Decimal("10"),
annualized_return=Decimal("15"),
max_drawdown=Decimal("5000"),
max_drawdown_percent=Decimal("5"),
sharpe_ratio=Decimal("1.5"),
sortino_ratio=Decimal("2.0"),
volatility=Decimal("10"),
annualized_volatility=Decimal("15"),
total_trades=50,
win_rate=Decimal("60"),
profit_factor=Decimal("1.8"),
total_commission=Decimal("500"),
total_slippage=Decimal("200"),
start_equity=Decimal("100000"),
end_equity=Decimal("110000"),
)
summary = metrics.to_summary_dict()
assert "Performance" in summary
assert "Risk" in summary
assert "Trading" in summary
assert "Costs" in summary
assert summary["Performance"]["Sharpe Ratio"] == "1.50"
class TestEquityCurvePoint:
def test_equity_point(self):
point = EquityCurvePoint(
timestamp=datetime(2024, 1, 15),
equity=Decimal("105000"),
cash=Decimal("50000"),
positions_value=Decimal("55000"),
drawdown=Decimal("2000"),
drawdown_percent=Decimal("1.9"),
)
assert point.equity == Decimal("105000")
assert point.drawdown_percent == Decimal("1.9")
class TestBacktestResult:
def test_backtest_result(self):
config = BacktestConfig(
tickers=["AAPL"],
start_date=date(2024, 1, 1),
end_date=date(2024, 6, 30),
)
metrics = BacktestMetrics(
total_return=Decimal("10000"),
total_return_percent=Decimal("10"),
start_equity=Decimal("100000"),
end_equity=Decimal("110000"),
)
trade_log = TradeLog(total_trades=10, winning_trades=6, losing_trades=4)
result = BacktestResult(
config=config,
metrics=metrics,
trade_log=trade_log,
started_at=datetime(2024, 7, 1, 10, 0, 0),
completed_at=datetime(2024, 7, 1, 10, 5, 30),
)
assert result.duration_seconds == 330.0
assert result.status == "completed"
def test_to_dict(self):
config = BacktestConfig(
name="Test Backtest",
tickers=["AAPL"],
start_date=date(2024, 1, 1),
end_date=date(2024, 6, 30),
)
metrics = BacktestMetrics(
total_return=Decimal("10000"),
total_return_percent=Decimal("10"),
start_equity=Decimal("100000"),
end_equity=Decimal("110000"),
)
trade_log = TradeLog(total_trades=10, winning_trades=6, losing_trades=4)
result = BacktestResult(
config=config,
metrics=metrics,
trade_log=trade_log,
started_at=datetime(2024, 7, 1, 10, 0, 0),
completed_at=datetime(2024, 7, 1, 10, 5, 0),
)
result_dict = result.to_dict()
assert result_dict["config"]["name"] == "Test Backtest"
assert result_dict["trade_summary"]["total_trades"] == 10
assert result_dict["trade_summary"]["win_rate"] == 60.0

View File

@ -0,0 +1,216 @@
from datetime import datetime, date
from decimal import Decimal
import pytest
from tradingagents.models.market_data import (
OHLCVBar,
OHLCV,
TechnicalIndicators,
MarketSnapshot,
HistoricalDataRequest,
HistoricalDataResponse,
)
class TestOHLCVBar:
def test_valid_bar(self):
bar = OHLCVBar(
timestamp=datetime(2024, 1, 15, 9, 30),
open=Decimal("100.00"),
high=Decimal("105.00"),
low=Decimal("99.00"),
close=Decimal("103.50"),
volume=1000000,
)
assert bar.open == Decimal("100.00")
assert bar.high == Decimal("105.00")
assert bar.volume == 1000000
def test_bar_with_adjusted_close(self):
bar = OHLCVBar(
timestamp=datetime(2024, 1, 15),
open=Decimal("100"),
high=Decimal("105"),
low=Decimal("99"),
close=Decimal("103"),
volume=1000000,
adjusted_close=Decimal("102.50"),
)
assert bar.adjusted_close == Decimal("102.50")
def test_invalid_negative_price(self):
with pytest.raises(ValueError):
OHLCVBar(
timestamp=datetime(2024, 1, 15),
open=Decimal("-100"),
high=Decimal("105"),
low=Decimal("99"),
close=Decimal("103"),
volume=1000000,
)
def test_invalid_negative_volume(self):
with pytest.raises(ValueError):
OHLCVBar(
timestamp=datetime(2024, 1, 15),
open=Decimal("100"),
high=Decimal("105"),
low=Decimal("99"),
close=Decimal("103"),
volume=-1000,
)
class TestOHLCV:
@pytest.fixture
def sample_bars(self):
return [
OHLCVBar(
timestamp=datetime(2024, 1, 15),
open=Decimal("100"),
high=Decimal("105"),
low=Decimal("99"),
close=Decimal("103"),
volume=1000000,
),
OHLCVBar(
timestamp=datetime(2024, 1, 16),
open=Decimal("103"),
high=Decimal("108"),
low=Decimal("102"),
close=Decimal("107"),
volume=1200000,
),
OHLCVBar(
timestamp=datetime(2024, 1, 17),
open=Decimal("107"),
high=Decimal("110"),
low=Decimal("105"),
close=Decimal("109"),
volume=900000,
),
]
def test_ohlcv_creation(self, sample_bars):
ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
assert ohlcv.ticker == "AAPL"
assert len(ohlcv.bars) == 3
assert ohlcv.interval == "1d"
assert ohlcv.currency == "USD"
def test_start_end_dates(self, sample_bars):
ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
assert ohlcv.start_date == datetime(2024, 1, 15)
assert ohlcv.end_date == datetime(2024, 1, 17)
def test_empty_ohlcv(self):
ohlcv = OHLCV(ticker="AAPL", bars=[])
assert ohlcv.start_date is None
assert ohlcv.end_date is None
def test_get_bar(self, sample_bars):
ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
bar = ohlcv.get_bar(datetime(2024, 1, 16))
assert bar is not None
assert bar.close == Decimal("107")
def test_get_bar_not_found(self, sample_bars):
ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
bar = ohlcv.get_bar(datetime(2024, 1, 20))
assert bar is None
def test_slice(self, sample_bars):
ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
sliced = ohlcv.slice(datetime(2024, 1, 15), datetime(2024, 1, 16))
assert len(sliced.bars) == 2
assert sliced.ticker == "AAPL"
def test_invalid_ticker(self):
with pytest.raises(ValueError):
OHLCV(ticker="", bars=[])
class TestTechnicalIndicators:
def test_valid_indicators(self):
indicators = TechnicalIndicators(
timestamp=datetime(2024, 1, 15),
ticker="AAPL",
sma_50=Decimal("150.00"),
rsi_14=Decimal("65.5"),
macd=Decimal("2.5"),
)
assert indicators.sma_50 == Decimal("150.00")
assert indicators.rsi_14 == Decimal("65.5")
def test_rsi_bounds(self):
with pytest.raises(ValueError):
TechnicalIndicators(
timestamp=datetime(2024, 1, 15),
ticker="AAPL",
rsi_14=Decimal("150"),
)
def test_mfi_bounds(self):
with pytest.raises(ValueError):
TechnicalIndicators(
timestamp=datetime(2024, 1, 15),
ticker="AAPL",
mfi_14=Decimal("-10"),
)
class TestMarketSnapshot:
def test_snapshot_change_calculation(self):
bar = OHLCVBar(
timestamp=datetime(2024, 1, 15),
open=Decimal("100"),
high=Decimal("105"),
low=Decimal("99"),
close=Decimal("103"),
volume=1000000,
)
snapshot = MarketSnapshot(
ticker="AAPL",
timestamp=datetime(2024, 1, 15),
bar=bar,
prev_close=Decimal("100"),
)
assert snapshot.change == Decimal("3")
assert snapshot.change_percent == Decimal("3")
def test_snapshot_no_prev_close(self):
bar = OHLCVBar(
timestamp=datetime(2024, 1, 15),
open=Decimal("100"),
high=Decimal("105"),
low=Decimal("99"),
close=Decimal("103"),
volume=1000000,
)
snapshot = MarketSnapshot(
ticker="AAPL",
timestamp=datetime(2024, 1, 15),
bar=bar,
)
assert snapshot.change is None
assert snapshot.change_percent is None
class TestHistoricalDataRequest:
def test_valid_request(self):
request = HistoricalDataRequest(
ticker="AAPL",
start_date=date(2024, 1, 1),
end_date=date(2024, 6, 30),
)
assert request.ticker == "AAPL"
assert request.include_indicators is True
def test_invalid_date_range(self):
with pytest.raises(ValueError):
HistoricalDataRequest(
ticker="AAPL",
start_date=date(2024, 6, 30),
end_date=date(2024, 1, 1),
)

View File

@ -0,0 +1,207 @@
from datetime import datetime
from decimal import Decimal
from uuid import uuid4
import pytest
from tradingagents.models.portfolio import (
PortfolioConfig,
PortfolioSnapshot,
CashTransaction,
TransactionType,
)
from tradingagents.models.trading import OrderSide, Fill, Position
class TestPortfolioConfig:
def test_default_config(self):
config = PortfolioConfig()
assert config.initial_cash == Decimal("100000")
assert config.commission_per_share == Decimal("0")
assert config.slippage_percent == Decimal("0")
def test_custom_config(self):
config = PortfolioConfig(
initial_cash=Decimal("50000"),
commission_per_trade=Decimal("5.00"),
slippage_percent=Decimal("0.1"),
)
assert config.initial_cash == Decimal("50000")
assert config.commission_per_trade == Decimal("5.00")
def test_calculate_commission_flat(self):
config = PortfolioConfig(commission_per_trade=Decimal("5.00"))
commission = config.calculate_commission(100, Decimal("150.00"))
assert commission == Decimal("5.00")
def test_calculate_commission_per_share(self):
config = PortfolioConfig(commission_per_share=Decimal("0.01"))
commission = config.calculate_commission(100, Decimal("150.00"))
assert commission == Decimal("1.00")
def test_calculate_commission_percent(self):
config = PortfolioConfig(commission_percent=Decimal("0.1"))
commission = config.calculate_commission(100, Decimal("100.00"))
assert commission == Decimal("10.00")
def test_calculate_commission_minimum(self):
config = PortfolioConfig(
commission_per_trade=Decimal("1.00"),
min_commission=Decimal("5.00"),
)
commission = config.calculate_commission(10, Decimal("10.00"))
assert commission == Decimal("5.00")
def test_calculate_commission_maximum(self):
config = PortfolioConfig(
commission_percent=Decimal("1"),
max_commission=Decimal("50.00"),
)
commission = config.calculate_commission(1000, Decimal("100.00"))
assert commission == Decimal("50.00")
def test_calculate_slippage_buy(self):
config = PortfolioConfig(slippage_percent=Decimal("0.1"))
price = config.calculate_slippage(Decimal("100.00"), OrderSide.BUY)
assert price == Decimal("100.10")
def test_calculate_slippage_sell(self):
config = PortfolioConfig(slippage_percent=Decimal("0.1"))
price = config.calculate_slippage(Decimal("100.00"), OrderSide.SELL)
assert price == Decimal("99.90")
class TestPortfolioSnapshot:
def test_new_portfolio(self):
portfolio = PortfolioSnapshot(cash=Decimal("100000"))
assert portfolio.cash == Decimal("100000")
assert portfolio.position_count == 0
assert len(portfolio.positions) == 0
def test_get_position_creates_new(self):
portfolio = PortfolioSnapshot(cash=Decimal("100000"))
position = portfolio.get_position("AAPL")
assert position.ticker == "AAPL"
assert position.quantity == 0
assert "AAPL" in portfolio.positions
def test_positions_value(self):
portfolio = PortfolioSnapshot(
cash=Decimal("50000"),
positions={
"AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")),
"GOOGL": Position(ticker="GOOGL", quantity=50, avg_cost=Decimal("100")),
},
)
prices = {"AAPL": Decimal("160"), "GOOGL": Decimal("110")}
assert portfolio.positions_value(prices) == Decimal("21500")
def test_total_equity(self):
portfolio = PortfolioSnapshot(
cash=Decimal("50000"),
positions={
"AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")),
},
)
prices = {"AAPL": Decimal("160")}
assert portfolio.total_equity(prices) == Decimal("66000")
def test_total_unrealized_pnl(self):
portfolio = PortfolioSnapshot(
cash=Decimal("50000"),
positions={
"AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")),
},
)
prices = {"AAPL": Decimal("160")}
assert portfolio.total_unrealized_pnl(prices) == Decimal("1000")
def test_apply_buy_fill(self):
portfolio = PortfolioSnapshot(cash=Decimal("100000"))
fill = Fill(
order_id=uuid4(),
ticker="AAPL",
side=OrderSide.BUY,
quantity=100,
price=Decimal("150.00"),
commission=Decimal("5.00"),
)
portfolio.apply_fill(fill)
assert portfolio.cash == Decimal("84995.00")
assert portfolio.positions["AAPL"].quantity == 100
assert portfolio.total_commission_paid == Decimal("5.00")
def test_apply_sell_fill_with_profit(self):
portfolio = PortfolioSnapshot(
cash=Decimal("50000"),
positions={
"AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")),
},
)
fill = Fill(
order_id=uuid4(),
ticker="AAPL",
side=OrderSide.SELL,
quantity=100,
price=Decimal("160.00"),
commission=Decimal("5.00"),
)
portfolio.apply_fill(fill)
assert portfolio.cash == Decimal("65995.00")
assert portfolio.positions["AAPL"].quantity == 0
assert portfolio.total_realized_pnl == Decimal("1000.00")
def test_add_deposit(self):
portfolio = PortfolioSnapshot(cash=Decimal("100000"))
transaction = CashTransaction(
transaction_type=TransactionType.DEPOSIT,
amount=Decimal("10000"),
)
portfolio.add_cash_transaction(transaction)
assert portfolio.cash == Decimal("110000")
assert len(portfolio.cash_transactions) == 1
def test_add_withdrawal(self):
portfolio = PortfolioSnapshot(cash=Decimal("100000"))
transaction = CashTransaction(
transaction_type=TransactionType.WITHDRAWAL,
amount=Decimal("10000"),
)
portfolio.add_cash_transaction(transaction)
assert portfolio.cash == Decimal("90000")
def test_can_afford(self):
portfolio = PortfolioSnapshot(cash=Decimal("10000"))
config = PortfolioConfig(commission_per_trade=Decimal("5"))
assert portfolio.can_afford("AAPL", 10, Decimal("100"), config)
assert not portfolio.can_afford("AAPL", 100, Decimal("100"), config)
def test_max_shares_affordable(self):
portfolio = PortfolioSnapshot(cash=Decimal("10000"))
config = PortfolioConfig(commission_per_trade=Decimal("0"))
max_shares = portfolio.max_shares_affordable("AAPL", Decimal("100"), config)
assert max_shares == 100
def test_max_shares_affordable_with_commission(self):
portfolio = PortfolioSnapshot(cash=Decimal("10050"))
config = PortfolioConfig(commission_per_trade=Decimal("50"))
max_shares = portfolio.max_shares_affordable("AAPL", Decimal("100"), config)
assert max_shares == 100
def test_to_dict(self):
portfolio = PortfolioSnapshot(
cash=Decimal("50000"),
positions={
"AAPL": Position(ticker="AAPL", quantity=100, avg_cost=Decimal("150")),
},
)
prices = {"AAPL": Decimal("160")}
result = portfolio.to_dict(prices)
assert result["cash"] == 50000.0
assert result["positions_value"] == 16000.0
assert result["total_equity"] == 66000.0
assert result["position_count"] == 1

View File

@ -0,0 +1,297 @@
from datetime import datetime
from decimal import Decimal
from uuid import uuid4
import pytest
from tradingagents.models.trading import (
OrderSide,
OrderType,
OrderStatus,
PositionSide,
Order,
Fill,
Position,
Trade,
)
class TestOrder:
def test_market_order_creation(self):
order = Order(
ticker="AAPL",
side=OrderSide.BUY,
quantity=100,
)
assert order.ticker == "AAPL"
assert order.side == OrderSide.BUY
assert order.order_type == OrderType.MARKET
assert order.quantity == 100
assert order.status == OrderStatus.PENDING
assert order.remaining_quantity == 100
assert not order.is_complete
def test_limit_order_creation(self):
order = Order(
ticker="AAPL",
side=OrderSide.SELL,
order_type=OrderType.LIMIT,
quantity=50,
limit_price=Decimal("150.00"),
)
assert order.order_type == OrderType.LIMIT
assert order.limit_price == Decimal("150.00")
def test_order_partial_fill(self):
order = Order(
ticker="AAPL",
side=OrderSide.BUY,
quantity=100,
status=OrderStatus.PARTIAL,
filled_quantity=30,
)
assert order.remaining_quantity == 70
assert not order.is_complete
def test_order_complete_states(self):
for status in [OrderStatus.FILLED, OrderStatus.CANCELLED, OrderStatus.REJECTED]:
order = Order(
ticker="AAPL",
side=OrderSide.BUY,
quantity=100,
status=status,
)
assert order.is_complete
def test_invalid_quantity(self):
with pytest.raises(ValueError):
Order(ticker="AAPL", side=OrderSide.BUY, quantity=0)
def test_invalid_limit_price(self):
with pytest.raises(ValueError):
Order(
ticker="AAPL",
side=OrderSide.BUY,
order_type=OrderType.LIMIT,
quantity=100,
limit_price=Decimal("-10"),
)
class TestFill:
def test_buy_fill(self):
order_id = uuid4()
fill = Fill(
order_id=order_id,
ticker="AAPL",
side=OrderSide.BUY,
quantity=100,
price=Decimal("150.00"),
commission=Decimal("1.00"),
)
assert fill.total_value == Decimal("15000.00")
assert fill.total_cost == Decimal("15001.00")
def test_sell_fill(self):
order_id = uuid4()
fill = Fill(
order_id=order_id,
ticker="AAPL",
side=OrderSide.SELL,
quantity=100,
price=Decimal("150.00"),
commission=Decimal("1.00"),
)
assert fill.total_value == Decimal("15000.00")
assert fill.total_cost == Decimal("14999.00")
class TestPosition:
def test_new_position(self):
position = Position(ticker="AAPL")
assert position.quantity == 0
assert position.side == PositionSide.FLAT
assert position.cost_basis == Decimal("0")
def test_long_position(self):
position = Position(
ticker="AAPL",
quantity=100,
avg_cost=Decimal("150.00"),
)
assert position.side == PositionSide.LONG
assert position.cost_basis == Decimal("15000.00")
def test_short_position(self):
position = Position(
ticker="AAPL",
quantity=-100,
avg_cost=Decimal("150.00"),
)
assert position.side == PositionSide.SHORT
assert position.cost_basis == Decimal("15000.00")
def test_unrealized_pnl_long(self):
position = Position(
ticker="AAPL",
quantity=100,
avg_cost=Decimal("150.00"),
)
pnl = position.unrealized_pnl(Decimal("160.00"))
assert pnl == Decimal("1000.00")
def test_unrealized_pnl_short(self):
position = Position(
ticker="AAPL",
quantity=-100,
avg_cost=Decimal("150.00"),
)
pnl = position.unrealized_pnl(Decimal("140.00"))
assert pnl == Decimal("1000.00")
def test_update_from_buy_fill_new_position(self):
position = Position(ticker="AAPL")
fill = Fill(
order_id=uuid4(),
ticker="AAPL",
side=OrderSide.BUY,
quantity=100,
price=Decimal("150.00"),
)
position.update_from_fill(fill)
assert position.quantity == 100
assert position.avg_cost == Decimal("150.00")
def test_update_from_buy_fill_add_to_position(self):
position = Position(
ticker="AAPL",
quantity=100,
avg_cost=Decimal("150.00"),
)
fill = Fill(
order_id=uuid4(),
ticker="AAPL",
side=OrderSide.BUY,
quantity=100,
price=Decimal("160.00"),
)
position.update_from_fill(fill)
assert position.quantity == 200
assert position.avg_cost == Decimal("155.00")
def test_update_from_sell_fill_close_position(self):
position = Position(
ticker="AAPL",
quantity=100,
avg_cost=Decimal("150.00"),
)
fill = Fill(
order_id=uuid4(),
ticker="AAPL",
side=OrderSide.SELL,
quantity=100,
price=Decimal("160.00"),
)
position.update_from_fill(fill)
assert position.quantity == 0
assert position.realized_pnl == Decimal("1000.00")
def test_update_from_sell_fill_partial_close(self):
position = Position(
ticker="AAPL",
quantity=100,
avg_cost=Decimal("150.00"),
)
fill = Fill(
order_id=uuid4(),
ticker="AAPL",
side=OrderSide.SELL,
quantity=50,
price=Decimal("160.00"),
)
position.update_from_fill(fill)
assert position.quantity == 50
assert position.realized_pnl == Decimal("500.00")
assert position.avg_cost == Decimal("150.00")
class TestTrade:
def test_open_trade(self):
trade = Trade(
ticker="AAPL",
side=OrderSide.BUY,
entry_price=Decimal("150.00"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15, 9, 30),
)
assert not trade.is_closed
assert trade.pnl is None
assert trade.holding_period is None
def test_closed_trade_profit(self):
trade = Trade(
ticker="AAPL",
side=OrderSide.BUY,
entry_price=Decimal("150.00"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15),
exit_price=Decimal("160.00"),
exit_quantity=100,
exit_time=datetime(2024, 1, 25),
)
assert trade.is_closed
assert trade.pnl == Decimal("1000.00")
assert trade.holding_period == 10
def test_closed_trade_loss(self):
trade = Trade(
ticker="AAPL",
side=OrderSide.BUY,
entry_price=Decimal("150.00"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15),
exit_price=Decimal("140.00"),
exit_quantity=100,
exit_time=datetime(2024, 1, 20),
)
assert trade.pnl == Decimal("-1000.00")
def test_trade_with_commission(self):
trade = Trade(
ticker="AAPL",
side=OrderSide.BUY,
entry_price=Decimal("150.00"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15),
exit_price=Decimal("160.00"),
exit_quantity=100,
exit_time=datetime(2024, 1, 20),
commission=Decimal("10.00"),
)
assert trade.pnl == Decimal("990.00")
def test_short_trade_profit(self):
trade = Trade(
ticker="AAPL",
side=OrderSide.SELL,
entry_price=Decimal("150.00"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15),
exit_price=Decimal("140.00"),
exit_quantity=100,
exit_time=datetime(2024, 1, 20),
)
assert trade.pnl == Decimal("1000.00")
def test_pnl_percent(self):
trade = Trade(
ticker="AAPL",
side=OrderSide.BUY,
entry_price=Decimal("100.00"),
entry_quantity=100,
entry_time=datetime(2024, 1, 15),
exit_price=Decimal("110.00"),
exit_quantity=100,
exit_time=datetime(2024, 1, 20),
)
assert trade.pnl_percent == Decimal("10")

View File

@ -10,7 +10,7 @@ from .analysts.social_media_analyst import create_social_media_analyst
from .researchers.bear_researcher import create_bear_researcher
from .researchers.bull_researcher import create_bull_researcher
from .risk_mgmt.aggresive_debator import create_risky_debator
from .risk_mgmt.aggressive_debator import create_risky_debator
from .risk_mgmt.conservative_debator import create_safe_debator
from .risk_mgmt.neutral_debator import create_neutral_debator

View File

@ -0,0 +1,13 @@
from .data_loader import DataLoader
from .engine import BacktestEngine, SimpleBacktestEngine
from .metrics import MetricsCalculator
from .agent_integration import AgentBacktestEngine, run_agent_backtest
__all__ = [
"DataLoader",
"BacktestEngine",
"SimpleBacktestEngine",
"MetricsCalculator",
"AgentBacktestEngine",
"run_agent_backtest",
]

View File

@ -0,0 +1,249 @@
import logging
import re
from datetime import date, datetime
from decimal import Decimal
from typing import Optional, Dict, Any
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.models.backtest import BacktestConfig, BacktestResult
from tradingagents.models.decisions import (
SignalType,
TradingDecision,
AnalystReport,
AnalystType,
)
from tradingagents.models.portfolio import PortfolioSnapshot
from .engine import BacktestEngine
from .data_loader import DataLoader
logger = logging.getLogger(__name__)
class AgentBacktestEngine(BacktestEngine):
def __init__(
self,
config: BacktestConfig,
agent_config: Optional[Dict[str, Any]] = None,
):
super().__init__(config)
self.agent_config = agent_config or config.agent_config
self.trading_graph: Optional[TradingAgentsGraph] = None
self._decision_cache: Dict[str, TradingDecision] = {}
def _initialize(self):
super()._initialize()
graph_config = {
**self.agent_config,
}
self.trading_graph = TradingAgentsGraph(
selected_analysts=self.agent_config.get(
"selected_analysts",
["market", "social", "news", "fundamentals"],
),
debug=self.agent_config.get("debug", False),
config=graph_config if graph_config else None,
)
def _get_decision(
self,
ticker: str,
trading_date: date,
day_index: int,
) -> Optional[TradingDecision]:
cache_key = f"{ticker}_{trading_date}"
if cache_key in self._decision_cache:
return self._decision_cache[cache_key]
try:
final_state, signal_info = self.trading_graph.propagate(
ticker, trading_date
)
decision = self._parse_agent_decision(
ticker, trading_date, final_state, signal_info
)
self._decision_cache[cache_key] = decision
return decision
except (ValueError, KeyError, RuntimeError, ConnectionError, TimeoutError) as e:
logger.error(
"Agent decision failed for %s on %s: %s",
ticker, trading_date, e
)
return None
def _parse_agent_decision(
self,
ticker: str,
trading_date: date,
final_state: Dict[str, Any],
signal_info: Dict[str, Any],
) -> TradingDecision:
signal = self._extract_signal(signal_info)
confidence = self._extract_confidence(signal_info)
analyst_reports = []
if final_state.get("market_report"):
analyst_reports.append(
AnalystReport(
analyst_type=AnalystType.MARKET,
ticker=ticker,
report_date=datetime.combine(trading_date, datetime.min.time()),
summary=final_state["market_report"][:500],
raw_content=final_state["market_report"],
)
)
if final_state.get("sentiment_report"):
analyst_reports.append(
AnalystReport(
analyst_type=AnalystType.SENTIMENT,
ticker=ticker,
report_date=datetime.combine(trading_date, datetime.min.time()),
summary=final_state["sentiment_report"][:500],
raw_content=final_state["sentiment_report"],
)
)
if final_state.get("news_report"):
analyst_reports.append(
AnalystReport(
analyst_type=AnalystType.NEWS,
ticker=ticker,
report_date=datetime.combine(trading_date, datetime.min.time()),
summary=final_state["news_report"][:500],
raw_content=final_state["news_report"],
)
)
if final_state.get("fundamentals_report"):
analyst_reports.append(
AnalystReport(
analyst_type=AnalystType.FUNDAMENTALS,
ticker=ticker,
report_date=datetime.combine(trading_date, datetime.min.time()),
summary=final_state["fundamentals_report"][:500],
raw_content=final_state["fundamentals_report"],
)
)
debate_state = final_state.get("investment_debate_state", {})
bull_argument = None
bear_argument = None
if debate_state.get("bull_history"):
bull_argument = debate_state["bull_history"][-1] if debate_state["bull_history"] else None
if debate_state.get("bear_history"):
bear_argument = debate_state["bear_history"][-1] if debate_state["bear_history"] else None
risk_state = final_state.get("risk_debate_state", {})
risk_approved = self._extract_risk_approval(risk_state)
final_decision_text = final_state.get("final_trade_decision", "")
recommended_action = self._extract_action(signal_info, final_decision_text)
return TradingDecision(
ticker=ticker,
timestamp=datetime.now(),
decision_date=datetime.combine(trading_date, datetime.min.time()),
signal=signal,
confidence=confidence,
recommended_action=recommended_action,
analyst_reports=analyst_reports,
bull_argument=bull_argument,
bear_argument=bear_argument,
debate_rounds=debate_state.get("count", 0),
risk_manager_approved=risk_approved,
final_decision=recommended_action,
rationale=final_decision_text[:1000] if final_decision_text else "",
)
def _extract_signal(self, signal_info: Dict[str, Any]) -> SignalType:
action = signal_info.get("action", "").upper()
direction = signal_info.get("direction", "").upper()
if action == "BUY" or direction == "BULLISH":
confidence = signal_info.get("confidence", 0.5)
if confidence > 0.8:
return SignalType.STRONG_BUY
return SignalType.BUY
elif action == "SELL" or direction == "BEARISH":
confidence = signal_info.get("confidence", 0.5)
if confidence > 0.8:
return SignalType.STRONG_SELL
return SignalType.SELL
return SignalType.HOLD
def _extract_confidence(self, signal_info: Dict[str, Any]) -> Decimal:
confidence = signal_info.get("confidence", 0.5)
if isinstance(confidence, str):
try:
confidence = float(confidence.replace("%", "")) / 100
except ValueError:
confidence = 0.5
return Decimal(str(min(max(float(confidence), 0.0), 1.0)))
def _extract_action(
self,
signal_info: Dict[str, Any],
final_decision_text: str,
) -> str:
action = signal_info.get("action", "")
if action:
return action.upper()
text_upper = final_decision_text.upper()
if "BUY" in text_upper and "DON'T BUY" not in text_upper:
return "BUY"
elif "SELL" in text_upper:
return "SELL"
return "HOLD"
def _extract_risk_approval(self, risk_state: Dict[str, Any]) -> Optional[bool]:
judge_decision = risk_state.get("judge_decision", "")
if not judge_decision:
return None
text_upper = judge_decision.upper()
if "APPROVE" in text_upper or "ACCEPT" in text_upper:
return True
elif "REJECT" in text_upper or "DENY" in text_upper:
return False
return None
def run_agent_backtest(
tickers: list[str],
start_date: date,
end_date: date,
initial_cash: Decimal = Decimal("100000"),
agent_config: Optional[Dict[str, Any]] = None,
) -> BacktestResult:
from tradingagents.models.portfolio import PortfolioConfig
config = BacktestConfig(
name=f"Agent Backtest - {', '.join(tickers)}",
tickers=tickers,
start_date=start_date,
end_date=end_date,
portfolio_config=PortfolioConfig(
initial_cash=initial_cash,
commission_per_trade=Decimal("1"),
slippage_percent=Decimal("0.05"),
),
warmup_period=5,
agent_config=agent_config or {},
)
engine = AgentBacktestEngine(config, agent_config)
return engine.run()

View File

@ -0,0 +1,244 @@
import logging
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Optional
import pandas as pd
import yfinance as yf
from stockstats import wrap
from tradingagents.models.market_data import (
OHLCV,
OHLCVBar,
TechnicalIndicators,
HistoricalDataRequest,
HistoricalDataResponse,
)
logger = logging.getLogger(__name__)
class DataLoader:
def __init__(self, cache_dir: Optional[str] = None):
self.cache_dir = cache_dir
self._cache: dict[str, pd.DataFrame] = {}
def load_ohlcv(
self,
ticker: str,
start_date: date,
end_date: date,
interval: str = "1d",
) -> OHLCV:
ticker = ticker.upper()
cache_key = f"{ticker}_{start_date}_{end_date}_{interval}"
if cache_key in self._cache:
df = self._cache[cache_key]
else:
df = self._fetch_from_yfinance(ticker, start_date, end_date, interval)
self._cache[cache_key] = df
bars = self._dataframe_to_bars(df)
return OHLCV(ticker=ticker, bars=bars, interval=interval)
def load_historical_data(
self,
request: HistoricalDataRequest,
) -> HistoricalDataResponse:
ohlcv = self.load_ohlcv(
request.ticker,
request.start_date,
request.end_date,
request.interval,
)
indicators = []
if request.include_indicators and ohlcv.bars:
indicators = self._calculate_indicators(
request.ticker,
request.start_date,
request.end_date,
)
return HistoricalDataResponse(
request=request,
ohlcv=ohlcv,
indicators=indicators,
source="yfinance",
)
def _fetch_from_yfinance(
self,
ticker: str,
start_date: date,
end_date: date,
interval: str,
) -> pd.DataFrame:
start_str = start_date.strftime("%Y-%m-%d")
end_str = (end_date + timedelta(days=1)).strftime("%Y-%m-%d")
df = yf.download(
ticker,
start=start_str,
end=end_str,
interval=interval,
multi_level_index=False,
progress=False,
auto_adjust=False,
)
if df.empty:
logger.warning("No data returned for %s from %s to %s", ticker, start_date, end_date)
return pd.DataFrame()
df = df.reset_index()
return df
def _dataframe_to_bars(self, df: pd.DataFrame) -> list[OHLCVBar]:
if df.empty:
return []
bars = []
for _, row in df.iterrows():
timestamp = row.get("Date") or row.get("Datetime")
if isinstance(timestamp, str):
timestamp = pd.to_datetime(timestamp)
if hasattr(timestamp, "to_pydatetime"):
timestamp = timestamp.to_pydatetime()
if timestamp.tzinfo is not None:
timestamp = timestamp.replace(tzinfo=None)
bar = OHLCVBar(
timestamp=timestamp,
open=Decimal(str(round(row["Open"], 4))),
high=Decimal(str(round(row["High"], 4))),
low=Decimal(str(round(row["Low"], 4))),
close=Decimal(str(round(row["Close"], 4))),
volume=int(row["Volume"]),
adjusted_close=Decimal(str(round(row["Adj Close"], 4))) if "Adj Close" in row else None,
)
bars.append(bar)
return bars
def _calculate_indicators(
self,
ticker: str,
start_date: date,
end_date: date,
) -> list[TechnicalIndicators]:
lookback_start = start_date - timedelta(days=250)
cache_key = f"{ticker}_{lookback_start}_{end_date}_1d"
if cache_key in self._cache:
df = self._cache[cache_key]
else:
df = self._fetch_from_yfinance(ticker, lookback_start, end_date, "1d")
self._cache[cache_key] = df
if df.empty:
return []
stock = wrap(df.copy())
stock["close_20_sma"]
stock["close_50_sma"]
stock["close_200_sma"]
stock["close_10_ema"]
stock["close_20_ema"]
stock["rsi_14"]
stock["macd"]
stock["macds"]
stock["macdh"]
stock["boll"]
stock["boll_ub"]
stock["boll_lb"]
stock["atr_14"]
stock["mfi_14"]
indicators = []
for _, row in stock.iterrows():
timestamp = row.get("Date") or row.get("Datetime")
if isinstance(timestamp, str):
timestamp = pd.to_datetime(timestamp)
if hasattr(timestamp, "to_pydatetime"):
timestamp = timestamp.to_pydatetime()
if timestamp.tzinfo is not None:
timestamp = timestamp.replace(tzinfo=None)
if timestamp.date() < start_date or timestamp.date() > end_date:
continue
ind = TechnicalIndicators(
timestamp=timestamp,
ticker=ticker,
sma_20=self._safe_decimal(row.get("close_20_sma")),
sma_50=self._safe_decimal(row.get("close_50_sma")),
sma_200=self._safe_decimal(row.get("close_200_sma")),
ema_10=self._safe_decimal(row.get("close_10_ema")),
ema_20=self._safe_decimal(row.get("close_20_ema")),
rsi_14=self._safe_decimal(row.get("rsi_14")),
macd=self._safe_decimal(row.get("macd")),
macd_signal=self._safe_decimal(row.get("macds")),
macd_histogram=self._safe_decimal(row.get("macdh")),
bollinger_middle=self._safe_decimal(row.get("boll")),
bollinger_upper=self._safe_decimal(row.get("boll_ub")),
bollinger_lower=self._safe_decimal(row.get("boll_lb")),
atr_14=self._safe_decimal(row.get("atr_14")),
mfi_14=self._safe_decimal(row.get("mfi_14")),
)
indicators.append(ind)
return indicators
@staticmethod
def _safe_decimal(value) -> Optional[Decimal]:
if value is None or pd.isna(value):
return None
return Decimal(str(round(float(value), 4)))
def get_price_on_date(
self,
ticker: str,
target_date: date,
ohlcv: Optional[OHLCV] = None,
) -> Optional[Decimal]:
if ohlcv is None:
ohlcv = self.load_ohlcv(ticker, target_date - timedelta(days=5), target_date)
target_datetime = datetime.combine(target_date, datetime.min.time())
bar = ohlcv.get_bar(target_datetime)
if bar:
return bar.close
for b in reversed(ohlcv.bars):
if b.timestamp.date() <= target_date:
return b.close
return None
def get_prices_dict(
self,
tickers: list[str],
target_date: date,
) -> dict[str, Decimal]:
prices = {}
for ticker in tickers:
price = self.get_price_on_date(ticker, target_date)
if price is not None:
prices[ticker] = price
return prices
def get_trading_days(
self,
ticker: str,
start_date: date,
end_date: date,
) -> list[date]:
ohlcv = self.load_ohlcv(ticker, start_date, end_date)
return [bar.timestamp.date() for bar in ohlcv.bars]
def clear_cache(self):
self._cache.clear()

View File

@ -0,0 +1,366 @@
import logging
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Optional, Callable
from uuid import uuid4
from tradingagents.models.backtest import (
BacktestConfig,
BacktestResult,
EquityCurvePoint,
TradeLog,
)
from tradingagents.models.decisions import SignalType, TradingDecision
from tradingagents.models.portfolio import PortfolioSnapshot
from tradingagents.models.trading import Order, OrderSide, OrderStatus, Fill, Trade
from .data_loader import DataLoader
from .metrics import MetricsCalculator
logger = logging.getLogger(__name__)
class BacktestEngine:
def __init__(
self,
config: BacktestConfig,
decision_callback: Optional[Callable[[str, date, dict], TradingDecision]] = None,
):
self.config = config
self.decision_callback = decision_callback
self.data_loader = DataLoader()
self.metrics_calculator = MetricsCalculator(config.risk_free_rate)
self.portfolio: Optional[PortfolioSnapshot] = None
self.trade_log: Optional[TradeLog] = None
self.equity_curve: list[EquityCurvePoint] = []
self.daily_returns: list[Decimal] = []
self.decisions: list[TradingDecision] = []
self.open_trades: dict[str, Trade] = {}
def run(self) -> BacktestResult:
started_at = datetime.now()
try:
self._initialize()
self._preload_data()
trading_days = self._get_trading_days()
for i, trading_date in enumerate(trading_days):
if i < self.config.warmup_period:
continue
self._process_day(trading_date, i)
self._close_all_positions(trading_days[-1] if trading_days else self.config.end_date)
metrics = self.metrics_calculator.calculate_metrics(
self.equity_curve,
self.trade_log,
)
completed_at = datetime.now()
return BacktestResult(
config=self.config,
metrics=metrics,
trade_log=self.trade_log,
equity_curve=self.equity_curve,
daily_returns=self.daily_returns,
started_at=started_at,
completed_at=completed_at,
status="completed",
)
except (ValueError, KeyError, RuntimeError, FileNotFoundError, OSError) as e:
logger.exception("Backtest failed: %s", e)
completed_at = datetime.now()
return BacktestResult(
config=self.config,
metrics=self._empty_metrics(),
trade_log=self.trade_log or TradeLog(),
equity_curve=self.equity_curve,
daily_returns=self.daily_returns,
started_at=started_at,
completed_at=completed_at,
status="failed",
error_message=str(e),
)
def _initialize(self):
self.portfolio = PortfolioSnapshot(
cash=self.config.portfolio_config.initial_cash,
)
self.trade_log = TradeLog()
self.equity_curve = []
self.daily_returns = []
self.decisions = []
self.open_trades = {}
def _preload_data(self):
logger.info("Preloading data for %s tickers", len(self.config.tickers))
for ticker in self.config.tickers:
self.data_loader.load_ohlcv(
ticker,
self.config.start_date - timedelta(days=self.config.warmup_period + 10),
self.config.end_date,
)
def _get_trading_days(self) -> list[date]:
primary_ticker = self.config.tickers[0]
return self.data_loader.get_trading_days(
primary_ticker,
self.config.start_date,
self.config.end_date,
)
def _process_day(self, trading_date: date, day_index: int):
prices = self.data_loader.get_prices_dict(self.config.tickers, trading_date)
if not prices:
logger.debug("No prices available for %s", trading_date)
return
for ticker in self.config.tickers:
if ticker not in prices:
continue
decision = self._get_decision(ticker, trading_date, day_index)
if decision:
self.decisions.append(decision)
self._execute_decision(decision, prices[ticker], trading_date)
self._record_equity(trading_date, prices)
def _get_decision(
self,
ticker: str,
trading_date: date,
day_index: int,
) -> Optional[TradingDecision]:
if self.decision_callback:
context = {
"day_index": day_index,
"portfolio": self.portfolio,
"open_trade": self.open_trades.get(ticker),
}
return self.decision_callback(ticker, trading_date, context)
return self._simple_strategy(ticker, trading_date)
def _simple_strategy(
self,
ticker: str,
trading_date: date,
) -> Optional[TradingDecision]:
return None
def _execute_decision(
self,
decision: TradingDecision,
price: Decimal,
trading_date: date,
):
ticker = decision.ticker
config = self.config.portfolio_config
position = self.portfolio.get_position(ticker)
if decision.is_buy and position.quantity == 0:
execution_price = config.calculate_slippage(price, OrderSide.BUY)
if decision.recommended_quantity:
quantity = decision.recommended_quantity
else:
max_position_value = self.portfolio.cash * (config.max_position_size_percent / 100)
quantity = int(max_position_value / execution_price)
if quantity <= 0:
return
if not self.portfolio.can_afford(ticker, quantity, execution_price, config):
quantity = self.portfolio.max_shares_affordable(ticker, execution_price, config)
if quantity <= 0:
return
commission = config.calculate_commission(quantity, execution_price)
order = Order(
ticker=ticker,
side=OrderSide.BUY,
quantity=quantity,
status=OrderStatus.FILLED,
filled_quantity=quantity,
filled_avg_price=execution_price,
filled_at=datetime.combine(trading_date, datetime.min.time()),
commission=commission,
)
fill = Fill(
order_id=order.id,
ticker=ticker,
side=OrderSide.BUY,
quantity=quantity,
price=execution_price,
commission=commission,
timestamp=datetime.combine(trading_date, datetime.min.time()),
)
self.portfolio.apply_fill(fill)
trade = Trade(
ticker=ticker,
side=OrderSide.BUY,
entry_price=execution_price,
entry_quantity=quantity,
entry_time=datetime.combine(trading_date, datetime.min.time()),
entry_order_id=order.id,
)
self.open_trades[ticker] = trade
logger.debug(
"BUY %s: %d shares @ $%.2f on %s",
ticker, quantity, execution_price, trading_date
)
elif decision.is_sell and position.quantity > 0:
execution_price = config.calculate_slippage(price, OrderSide.SELL)
quantity = position.quantity
commission = config.calculate_commission(quantity, execution_price)
order = Order(
ticker=ticker,
side=OrderSide.SELL,
quantity=quantity,
status=OrderStatus.FILLED,
filled_quantity=quantity,
filled_avg_price=execution_price,
filled_at=datetime.combine(trading_date, datetime.min.time()),
commission=commission,
)
fill = Fill(
order_id=order.id,
ticker=ticker,
side=OrderSide.SELL,
quantity=quantity,
price=execution_price,
commission=commission,
timestamp=datetime.combine(trading_date, datetime.min.time()),
)
self.portfolio.apply_fill(fill)
if ticker in self.open_trades:
trade = self.open_trades[ticker]
trade.exit_price = execution_price
trade.exit_quantity = quantity
trade.exit_time = datetime.combine(trading_date, datetime.min.time())
trade.exit_order_id = order.id
trade.commission = (
config.calculate_commission(trade.entry_quantity, trade.entry_price) +
commission
)
self.trade_log.add_trade(trade)
del self.open_trades[ticker]
logger.debug(
"SELL %s: %d shares @ $%.2f on %s",
ticker, quantity, execution_price, trading_date
)
def _record_equity(self, trading_date: date, prices: dict[str, Decimal]):
equity = self.portfolio.total_equity(prices)
positions_value = self.portfolio.positions_value(prices)
point = EquityCurvePoint(
timestamp=datetime.combine(trading_date, datetime.min.time()),
equity=equity,
cash=self.portfolio.cash,
positions_value=positions_value,
)
self.equity_curve.append(point)
if len(self.equity_curve) > 1:
prev_equity = self.equity_curve[-2].equity
if prev_equity > 0:
daily_return = (equity - prev_equity) / prev_equity
self.daily_returns.append(daily_return)
def _close_all_positions(self, final_date: date):
prices = self.data_loader.get_prices_dict(self.config.tickers, final_date)
for ticker, trade in list(self.open_trades.items()):
if ticker in prices:
decision = TradingDecision(
ticker=ticker,
timestamp=datetime.now(),
decision_date=datetime.combine(final_date, datetime.min.time()),
signal=SignalType.SELL,
confidence=Decimal("1.0"),
recommended_action="SELL",
final_decision="SELL - End of backtest",
rationale="Closing position at end of backtest period",
)
self._execute_decision(decision, prices[ticker], final_date)
def _empty_metrics(self):
from tradingagents.models.backtest import BacktestMetrics
return BacktestMetrics(
start_equity=self.config.portfolio_config.initial_cash,
end_equity=self.portfolio.cash if self.portfolio else self.config.portfolio_config.initial_cash,
)
class SimpleBacktestEngine(BacktestEngine):
def __init__(
self,
config: BacktestConfig,
buy_signal: Callable[[str, date, dict], bool] = None,
sell_signal: Callable[[str, date, dict], bool] = None,
):
super().__init__(config)
self.buy_signal = buy_signal
self.sell_signal = sell_signal
def _get_decision(
self,
ticker: str,
trading_date: date,
day_index: int,
) -> Optional[TradingDecision]:
context = {
"day_index": day_index,
"portfolio": self.portfolio,
"data_loader": self.data_loader,
"open_trade": self.open_trades.get(ticker),
}
position = self.portfolio.get_position(ticker)
if position.quantity == 0 and self.buy_signal and self.buy_signal(ticker, trading_date, context):
return TradingDecision(
ticker=ticker,
timestamp=datetime.now(),
decision_date=datetime.combine(trading_date, datetime.min.time()),
signal=SignalType.BUY,
confidence=Decimal("0.7"),
recommended_action="BUY",
final_decision="BUY",
rationale="Buy signal triggered",
)
if position.quantity > 0 and self.sell_signal and self.sell_signal(ticker, trading_date, context):
return TradingDecision(
ticker=ticker,
timestamp=datetime.now(),
decision_date=datetime.combine(trading_date, datetime.min.time()),
signal=SignalType.SELL,
confidence=Decimal("0.7"),
recommended_action="SELL",
final_decision="SELL",
rationale="Sell signal triggered",
)
return None

View File

@ -0,0 +1,281 @@
import math
from decimal import Decimal
from typing import Optional
from tradingagents.models.backtest import BacktestMetrics, EquityCurvePoint, TradeLog
class MetricsCalculator:
TRADING_DAYS_PER_YEAR = 252
def __init__(self, risk_free_rate: Decimal = Decimal("0.05")):
self.risk_free_rate = risk_free_rate
def calculate_metrics(
self,
equity_curve: list[EquityCurvePoint],
trade_log: TradeLog,
benchmark_curve: Optional[list[EquityCurvePoint]] = None,
) -> BacktestMetrics:
if not equity_curve:
raise ValueError("Equity curve cannot be empty")
start_equity = equity_curve[0].equity
end_equity = equity_curve[-1].equity
trading_days = len(equity_curve)
total_return = end_equity - start_equity
total_return_percent = (total_return / start_equity) * 100
years = Decimal(trading_days) / Decimal(self.TRADING_DAYS_PER_YEAR)
if years > 0:
annualized_return = ((end_equity / start_equity) ** (1 / years) - 1) * 100
else:
annualized_return = Decimal("0")
daily_returns = self._calculate_daily_returns(equity_curve)
volatility = self._calculate_volatility(daily_returns)
annualized_volatility = volatility * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR))
downside_returns = [r for r in daily_returns if r < 0]
downside_volatility = self._calculate_volatility(downside_returns) if downside_returns else Decimal("0")
annualized_downside_vol = downside_volatility * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR))
max_dd, max_dd_pct, max_dd_duration, avg_dd = self._calculate_drawdown_metrics(equity_curve)
sharpe = self._calculate_sharpe_ratio(annualized_return, annualized_volatility)
sortino = self._calculate_sortino_ratio(annualized_return, annualized_downside_vol)
calmar = self._calculate_calmar_ratio(annualized_return, max_dd_pct)
benchmark_return = None
benchmark_return_percent = None
alpha = None
beta = None
information_ratio = None
if benchmark_curve and len(benchmark_curve) == len(equity_curve):
benchmark_return = benchmark_curve[-1].equity - benchmark_curve[0].equity
benchmark_return_percent = (benchmark_return / benchmark_curve[0].equity) * 100
benchmark_daily = self._calculate_daily_returns(benchmark_curve)
alpha, beta = self._calculate_alpha_beta(daily_returns, benchmark_daily)
information_ratio = self._calculate_information_ratio(
daily_returns, benchmark_daily
)
all_pnls = [t.pnl for t in trade_log.trades if t.is_closed and t.pnl is not None]
avg_trade_pnl = sum(all_pnls) / len(all_pnls) if all_pnls else None
largest_win = max((p for p in all_pnls if p > 0), default=None)
largest_loss = min((p for p in all_pnls if p < 0), default=None)
return BacktestMetrics(
total_return=total_return,
total_return_percent=total_return_percent,
annualized_return=annualized_return,
benchmark_return=benchmark_return,
benchmark_return_percent=benchmark_return_percent,
alpha=alpha,
beta=beta,
volatility=volatility * 100,
annualized_volatility=annualized_volatility * 100,
downside_volatility=annualized_downside_vol * 100,
sharpe_ratio=sharpe,
sortino_ratio=sortino,
calmar_ratio=calmar,
information_ratio=information_ratio,
max_drawdown=max_dd,
max_drawdown_percent=max_dd_pct,
max_drawdown_duration=max_dd_duration,
avg_drawdown=avg_dd,
total_trades=trade_log.total_trades,
win_rate=trade_log.win_rate,
profit_factor=trade_log.profit_factor,
avg_trade_pnl=avg_trade_pnl,
avg_win=trade_log.avg_win,
avg_loss=trade_log.avg_loss,
largest_win=largest_win,
largest_loss=largest_loss,
avg_holding_period_days=trade_log.avg_holding_period,
trading_days=trading_days,
start_equity=start_equity,
end_equity=end_equity,
)
def _calculate_daily_returns(
self,
equity_curve: list[EquityCurvePoint],
) -> list[Decimal]:
returns = []
for i in range(1, len(equity_curve)):
prev_equity = equity_curve[i - 1].equity
curr_equity = equity_curve[i].equity
if prev_equity > 0:
daily_return = (curr_equity - prev_equity) / prev_equity
returns.append(daily_return)
return returns
def _calculate_volatility(self, returns: list[Decimal]) -> Decimal:
if len(returns) < 2:
return Decimal("0")
mean = sum(returns) / len(returns)
variance = sum((r - mean) ** 2 for r in returns) / (len(returns) - 1)
return Decimal(str(math.sqrt(float(variance))))
def _calculate_drawdown_metrics(
self,
equity_curve: list[EquityCurvePoint],
) -> tuple[Decimal, Decimal, Optional[int], Decimal]:
if not equity_curve:
return Decimal("0"), Decimal("0"), None, Decimal("0")
peak = equity_curve[0].equity
max_drawdown = Decimal("0")
max_drawdown_percent = Decimal("0")
drawdown_start = 0
max_drawdown_duration = 0
current_drawdown_start = 0
in_drawdown = False
drawdowns = []
for i, point in enumerate(equity_curve):
equity = point.equity
if equity > peak:
if in_drawdown:
duration = i - current_drawdown_start
max_drawdown_duration = max(max_drawdown_duration, duration)
in_drawdown = False
peak = equity
current_drawdown_start = i
else:
if not in_drawdown:
in_drawdown = True
current_drawdown_start = i
drawdown = peak - equity
drawdown_pct = (drawdown / peak) * 100 if peak > 0 else Decimal("0")
drawdowns.append(drawdown_pct)
if drawdown > max_drawdown:
max_drawdown = drawdown
max_drawdown_percent = drawdown_pct
point.drawdown = drawdown
point.drawdown_percent = drawdown_pct
if in_drawdown:
duration = len(equity_curve) - current_drawdown_start
max_drawdown_duration = max(max_drawdown_duration, duration)
avg_drawdown = sum(drawdowns) / len(drawdowns) if drawdowns else Decimal("0")
return max_drawdown, max_drawdown_percent, max_drawdown_duration or None, avg_drawdown
def _calculate_sharpe_ratio(
self,
annualized_return: Decimal,
annualized_volatility: Decimal,
) -> Optional[Decimal]:
if annualized_volatility == 0:
return None
excess_return = annualized_return - (self.risk_free_rate * 100)
return excess_return / annualized_volatility
def _calculate_sortino_ratio(
self,
annualized_return: Decimal,
annualized_downside_vol: Decimal,
) -> Optional[Decimal]:
if annualized_downside_vol == 0:
return None
excess_return = annualized_return - (self.risk_free_rate * 100)
return excess_return / annualized_downside_vol
def _calculate_calmar_ratio(
self,
annualized_return: Decimal,
max_drawdown_percent: Decimal,
) -> Optional[Decimal]:
if max_drawdown_percent == 0:
return None
return annualized_return / max_drawdown_percent
def _calculate_alpha_beta(
self,
returns: list[Decimal],
benchmark_returns: list[Decimal],
) -> tuple[Optional[Decimal], Optional[Decimal]]:
if len(returns) != len(benchmark_returns) or len(returns) < 2:
return None, None
n = len(returns)
sum_x = sum(benchmark_returns)
sum_y = sum(returns)
sum_xy = sum(r * b for r, b in zip(returns, benchmark_returns))
sum_xx = sum(b * b for b in benchmark_returns)
denominator = n * sum_xx - sum_x * sum_x
if denominator == 0:
return None, None
beta = (n * sum_xy - sum_x * sum_y) / denominator
alpha = (sum_y - beta * sum_x) / n
alpha_annualized = alpha * self.TRADING_DAYS_PER_YEAR
return alpha_annualized, beta
def _calculate_information_ratio(
self,
returns: list[Decimal],
benchmark_returns: list[Decimal],
) -> Optional[Decimal]:
if len(returns) != len(benchmark_returns) or len(returns) < 2:
return None
excess_returns = [r - b for r, b in zip(returns, benchmark_returns)]
mean_excess = sum(excess_returns) / len(excess_returns)
tracking_error = self._calculate_volatility(excess_returns)
if tracking_error == 0:
return None
annualized_tracking_error = tracking_error * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR))
annualized_excess = mean_excess * self.TRADING_DAYS_PER_YEAR
return annualized_excess / annualized_tracking_error
def calculate_rolling_metrics(
self,
equity_curve: list[EquityCurvePoint],
window: int = 20,
) -> dict[str, list[Decimal]]:
if len(equity_curve) < window:
return {"rolling_sharpe": [], "rolling_volatility": []}
rolling_sharpe = []
rolling_volatility = []
daily_returns = self._calculate_daily_returns(equity_curve)
for i in range(window - 1, len(daily_returns)):
window_returns = daily_returns[i - window + 1:i + 1]
vol = self._calculate_volatility(window_returns)
annualized_vol = vol * Decimal(math.sqrt(self.TRADING_DAYS_PER_YEAR))
mean_return = sum(window_returns) / len(window_returns)
annualized_return = mean_return * self.TRADING_DAYS_PER_YEAR * 100
sharpe = self._calculate_sharpe_ratio(annualized_return, annualized_vol * 100)
rolling_sharpe.append(sharpe if sharpe else Decimal("0"))
rolling_volatility.append(annualized_vol * 100)
return {
"rolling_sharpe": rolling_sharpe,
"rolling_volatility": rolling_volatility,
}

View File

@ -53,7 +53,7 @@ def _make_api_request(function_name: str, params: dict) -> dict | str:
elif "entitlement" in api_params:
api_params.pop("entitlement", None)
response = requests.get(API_BASE_URL, params=api_params)
response = requests.get(API_BASE_URL, params=api_params, timeout=30)
response.raise_for_status()
response_text = response.text
@ -88,6 +88,6 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) ->
return filtered_df.to_csv(index=False)
except Exception as e:
except (pd.errors.ParserError, KeyError, ValueError) as e:
logger.warning("Failed to filter CSV data by date range: %s", e)
return csv_data

View File

@ -1,4 +1,5 @@
import logging
import requests
from .alpha_vantage_common import _make_api_request
logger = logging.getLogger(__name__)
@ -193,6 +194,6 @@ def get_indicator(
return result_str
except Exception as e:
except (ValueError, KeyError, IndexError, requests.RequestException) as e:
logger.error("Error getting Alpha Vantage indicator data for %s: %s", indicator, e)
return f"Error retrieving {indicator} data: {str(e)}"

View File

@ -77,7 +77,8 @@ def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]:
"content_snippet": item.get("summary", "")[:500],
}
articles.append(article)
except Exception:
except (KeyError, TypeError, AttributeError) as e:
logger.debug("Error parsing news article: %s", e)
continue
return articles

View File

@ -120,7 +120,7 @@ def get_bulk_news_brave(lookback_hours: int) -> List[Dict[str, Any]]:
except requests.exceptions.RequestException as e:
logger.debug("Brave search request failed for '%s': %s", query, e)
continue
except Exception as e:
except (KeyError, TypeError, ValueError) as e:
logger.debug("Brave search failed for query '%s': %s", query, e)
continue

View File

@ -1,10 +1,14 @@
import logging
import re
import requests
from typing import Annotated, List, Dict, Any
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from dateutil import parser as dateutil_parser
from .googlenews_utils import getNewsData
logger = logging.getLogger(__name__)
def _parse_google_news_date(date_str: str) -> datetime:
if not date_str:
@ -108,7 +112,8 @@ def get_bulk_news_google(lookback_hours: int) -> List[Dict[str, Any]]:
}
all_articles.append(article)
except Exception:
except (TypeError, KeyError, AttributeError, requests.RequestException) as e:
logger.debug("Google News search failed for query '%s': %s", query, e)
continue
return all_articles

View File

@ -27,7 +27,7 @@ def is_rate_limited(response):
)
def make_request(url, headers):
time.sleep(random.uniform(2, 6))
response = requests.get(url, headers=headers)
response = requests.get(url, headers=headers, timeout=30)
return response
@ -81,7 +81,7 @@ def getNewsData(query, start_date, end_date):
"source": source,
}
)
except Exception as e:
except (TypeError, AttributeError, KeyError) as e:
logger.debug("Error processing result: %s", e)
continue
@ -91,7 +91,7 @@ def getNewsData(query, start_date, end_date):
page += 1
except Exception as e:
except (requests.RequestException, ConnectionError, TimeoutError) as e:
logger.debug("Failed after multiple retries: %s", e)
break

View File

@ -191,7 +191,8 @@ def _convert_to_news_articles(raw_articles: List[Dict[str, Any]]) -> List[NewsAr
ticker_mentions=[],
)
articles.append(article)
except Exception:
except (KeyError, TypeError, ValueError) as e:
logger.debug("Error converting article to NewsArticle: %s", e)
continue
return articles
@ -218,7 +219,7 @@ def _fetch_bulk_news_from_vendor(lookback_period: str) -> List[Dict[str, Any]]:
except AlphaVantageRateLimitError as e:
logger.warning("Alpha Vantage rate limit exceeded: %s", e)
continue
except Exception as e:
except (RuntimeError, ConnectionError, TimeoutError, ValueError, OSError) as e:
logger.error("Vendor '%s' failed: %s", vendor, e)
continue
@ -316,7 +317,7 @@ def route_to_vendor(method: str, *args, **kwargs):
logger.warning("Alpha Vantage rate limit exceeded, falling back to next available vendor")
logger.debug("Rate limit details: %s", e)
continue
except Exception as e:
except (RuntimeError, ConnectionError, TimeoutError, ValueError, KeyError, OSError) as e:
logger.error("%s from vendor '%s' failed: %s", impl_func.__name__, vendor_name, e)
continue

View File

@ -170,8 +170,8 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json"
)
data = open(data_path, "r")
data = json.load(data)
with open(data_path, "r") as f:
data = json.load(f)
filtered_data = {}
for key, value in data.items():

View File

@ -54,7 +54,7 @@ def get_stock_news_openai(query, start_date, end_date):
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
store=False,
)
return _extract_response_text(response) or ""
@ -89,7 +89,7 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5):
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
store=False,
)
return _extract_response_text(response) or ""
@ -124,7 +124,7 @@ def get_fundamentals_openai(ticker, curr_date):
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
store=False,
)
return _extract_response_text(response) or ""
@ -187,7 +187,7 @@ Return ONLY the JSON array, no additional text."""
temperature=0.5,
max_output_tokens=8192,
top_p=1,
store=True,
store=False,
)
try:

View File

@ -36,7 +36,7 @@ def _search_with_retry(client, query: str, search_depth: str, topic: str, time_r
max_results=max_results,
)
return response
except Exception as e:
except (RuntimeError, ConnectionError, TimeoutError, OSError) as e:
error_str = str(e).lower()
if "rate" in error_str or "limit" in error_str or "429" in error_str:
wait_time = RETRY_BACKOFF * (attempt + 1) * 2
@ -123,7 +123,7 @@ def get_bulk_news_tavily(lookback_hours: int) -> List[Dict[str, Any]]:
}
all_articles.append(article)
except Exception as e:
except (RuntimeError, ConnectionError, TimeoutError, OSError, ValueError) as e:
logger.debug("Tavily search failed for query '%s': %s", query, e)
continue

View File

@ -261,7 +261,7 @@ def classify_sector(ticker: str) -> str:
_sector_cache[ticker_upper] = sector
logger.info("Classified %s as %s via LLM", ticker, sector)
return sector
except Exception as e:
except (KeyError, ValueError, RuntimeError, ConnectionError, TimeoutError) as e:
logger.error("LLM sector classification failed for %s: %s", ticker, str(e))
_sector_cache[ticker_upper] = "other"
return "other"

View File

@ -462,7 +462,7 @@ def _search_yfinance_ticker(company_name: str) -> Optional[str]:
info = search_result.info
if info and "symbol" in info:
return info["symbol"]
except Exception as e:
except (KeyError, ValueError, AttributeError, RuntimeError) as e:
logger.debug("yfinance search failed for %s: %s", company_name, str(e))
try:
@ -471,7 +471,7 @@ def _search_yfinance_ticker(company_name: str) -> Optional[str]:
for quote in search.quotes:
if "symbol" in quote:
return quote["symbol"]
except Exception as e:
except (KeyError, ValueError, AttributeError, RuntimeError) as e:
logger.debug("yfinance Search failed for %s: %s", company_name, str(e))
return None
@ -495,7 +495,7 @@ def validate_us_ticker(ticker: str) -> bool:
logger.warning("Validation failed for %s: exchange %s is not a US exchange", ticker, exchange)
return False
except Exception as e:
except (KeyError, ValueError, AttributeError, RuntimeError) as e:
logger.warning("Validation failed for %s: %s", ticker, str(e))
return False

View File

@ -149,7 +149,7 @@ def get_stock_stats_indicators_window(
for date_str, value in date_values:
ind_string += f"{date_str}: {value}\n"
except Exception as e:
except (KeyError, ValueError, FileNotFoundError) as e:
logger.error("Error getting bulk stockstats data: %s", e)
ind_string = ""
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
@ -260,7 +260,7 @@ def get_stockstats_indicator(
indicator,
curr_date,
)
except Exception as e:
except (KeyError, ValueError, IndexError) as e:
logger.error(
"Error getting stockstats indicator data for indicator %s on %s: %s",
indicator, curr_date, e
@ -293,7 +293,7 @@ def get_balance_sheet(
return header + csv_string
except Exception as e:
except (ValueError, KeyError, AttributeError) as e:
return f"Error retrieving balance sheet for {ticker}: {str(e)}"
@ -320,7 +320,7 @@ def get_cashflow(
return header + csv_string
except Exception as e:
except (ValueError, KeyError, AttributeError) as e:
return f"Error retrieving cash flow for {ticker}: {str(e)}"
@ -347,7 +347,7 @@ def get_income_statement(
return header + csv_string
except Exception as e:
except (ValueError, KeyError, AttributeError) as e:
return f"Error retrieving income statement for {ticker}: {str(e)}"
@ -368,5 +368,5 @@ def get_insider_transactions(
return header + csv_string
except Exception as e:
except (ValueError, KeyError, AttributeError) as e:
return f"Error retrieving insider transactions for {ticker}: {str(e)}"

View File

@ -3,7 +3,7 @@ import os
DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
"data_dir": os.getenv("TRADINGAGENTS_DATA_DIR", "./data"),
"data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",

View File

@ -280,7 +280,7 @@ class TradingAgentsGraph:
)
discovery_result["stocks"] = trending_stocks
except Exception as e:
except (ValueError, KeyError, RuntimeError, ConnectionError, TimeoutError) as e:
discovery_result["error"] = str(e)
discovery_thread = threading.Thread(target=run_discovery)

View File

@ -0,0 +1,69 @@
from .market_data import (
OHLCV,
OHLCVBar,
TechnicalIndicators,
MarketSnapshot,
HistoricalDataRequest,
HistoricalDataResponse,
)
from .trading import (
OrderSide,
OrderType,
OrderStatus,
PositionSide,
Order,
Fill,
Position,
Trade,
)
from .portfolio import (
PortfolioSnapshot,
PortfolioConfig,
CashTransaction,
TransactionType,
)
from .backtest import (
BacktestConfig,
BacktestResult,
BacktestMetrics,
EquityCurvePoint,
TradeLog,
)
from .decisions import (
SignalType,
TradingSignal,
TradingDecision,
RiskAssessment,
AnalystReport,
)
__all__ = [
"OHLCV",
"OHLCVBar",
"TechnicalIndicators",
"MarketSnapshot",
"HistoricalDataRequest",
"HistoricalDataResponse",
"OrderSide",
"OrderType",
"OrderStatus",
"PositionSide",
"Order",
"Fill",
"Position",
"Trade",
"PortfolioSnapshot",
"PortfolioConfig",
"CashTransaction",
"TransactionType",
"BacktestConfig",
"BacktestResult",
"BacktestMetrics",
"EquityCurvePoint",
"TradeLog",
"SignalType",
"TradingSignal",
"TradingDecision",
"RiskAssessment",
"AnalystReport",
]

View File

@ -0,0 +1,242 @@
from datetime import date, datetime, timedelta
from decimal import Decimal
from enum import Enum
from typing import Optional
from uuid import UUID, uuid4
from pydantic import BaseModel, Field, computed_field, field_validator
from .portfolio import PortfolioConfig
from .trading import Trade
class BacktestConfig(BaseModel):
id: UUID = Field(default_factory=uuid4)
name: str = Field(default="Backtest")
description: Optional[str] = None
tickers: list[str] = Field(min_length=1)
start_date: date
end_date: date
interval: str = Field(default="1d")
portfolio_config: PortfolioConfig = Field(default_factory=PortfolioConfig)
warmup_period: int = Field(default=20, ge=0)
rebalance_frequency: Optional[str] = Field(default=None)
use_agent_pipeline: bool = Field(default=True)
agent_config: dict = Field(default_factory=dict)
benchmark_ticker: Optional[str] = Field(default="SPY")
risk_free_rate: Decimal = Field(default=Decimal("0.05"), ge=0)
created_at: datetime = Field(default_factory=datetime.now)
@field_validator("end_date")
@classmethod
def end_after_start(cls, v: date, info) -> date:
if "start_date" in info.data and v <= info.data["start_date"]:
raise ValueError("end_date must be > start_date")
return v
@field_validator("tickers")
@classmethod
def validate_tickers(cls, v: list[str]) -> list[str]:
return [t.upper().strip() for t in v]
@computed_field
@property
def trading_days_estimate(self) -> int:
delta = self.end_date - self.start_date
return int(delta.days * 252 / 365)
class EquityCurvePoint(BaseModel):
timestamp: datetime
equity: Decimal
cash: Decimal
positions_value: Decimal
benchmark_value: Optional[Decimal] = None
drawdown: Decimal = Field(default=Decimal("0"))
drawdown_percent: Decimal = Field(default=Decimal("0"))
class TradeLog(BaseModel):
trades: list[Trade] = Field(default_factory=list)
total_trades: int = Field(default=0, ge=0)
winning_trades: int = Field(default=0, ge=0)
losing_trades: int = Field(default=0, ge=0)
break_even_trades: int = Field(default=0, ge=0)
@computed_field
@property
def win_rate(self) -> Optional[Decimal]:
if self.total_trades == 0:
return None
return Decimal(self.winning_trades) / Decimal(self.total_trades) * 100
@computed_field
@property
def loss_rate(self) -> Optional[Decimal]:
if self.total_trades == 0:
return None
return Decimal(self.losing_trades) / Decimal(self.total_trades) * 100
def add_trade(self, trade: Trade) -> None:
self.trades.append(trade)
self.total_trades += 1
if trade.is_closed and trade.pnl is not None:
if trade.pnl > 0:
self.winning_trades += 1
elif trade.pnl < 0:
self.losing_trades += 1
else:
self.break_even_trades += 1
@property
def gross_profit(self) -> Decimal:
return sum(
t.pnl for t in self.trades if t.is_closed and t.pnl and t.pnl > 0
) or Decimal("0")
@property
def gross_loss(self) -> Decimal:
return abs(
sum(t.pnl for t in self.trades if t.is_closed and t.pnl and t.pnl < 0)
or Decimal("0")
)
@property
def profit_factor(self) -> Optional[Decimal]:
if self.gross_loss == 0:
return None
return self.gross_profit / self.gross_loss
@property
def avg_win(self) -> Optional[Decimal]:
wins = [t.pnl for t in self.trades if t.is_closed and t.pnl and t.pnl > 0]
if not wins:
return None
return sum(wins) / len(wins)
@property
def avg_loss(self) -> Optional[Decimal]:
losses = [t.pnl for t in self.trades if t.is_closed and t.pnl and t.pnl < 0]
if not losses:
return None
return sum(losses) / len(losses)
@property
def avg_holding_period(self) -> Optional[float]:
periods = [t.holding_period for t in self.trades if t.holding_period is not None]
if not periods:
return None
return sum(periods) / len(periods)
class BacktestMetrics(BaseModel):
total_return: Decimal = Field(default=Decimal("0"))
total_return_percent: Decimal = Field(default=Decimal("0"))
annualized_return: Decimal = Field(default=Decimal("0"))
benchmark_return: Optional[Decimal] = None
benchmark_return_percent: Optional[Decimal] = None
alpha: Optional[Decimal] = None
beta: Optional[Decimal] = None
volatility: Decimal = Field(default=Decimal("0"), ge=0)
annualized_volatility: Decimal = Field(default=Decimal("0"), ge=0)
downside_volatility: Decimal = Field(default=Decimal("0"), ge=0)
sharpe_ratio: Optional[Decimal] = None
sortino_ratio: Optional[Decimal] = None
calmar_ratio: Optional[Decimal] = None
information_ratio: Optional[Decimal] = None
max_drawdown: Decimal = Field(default=Decimal("0"), ge=0)
max_drawdown_percent: Decimal = Field(default=Decimal("0"), ge=0, le=100)
max_drawdown_duration: Optional[int] = None
avg_drawdown: Decimal = Field(default=Decimal("0"), ge=0)
total_trades: int = Field(default=0, ge=0)
win_rate: Optional[Decimal] = Field(default=None, ge=0, le=100)
profit_factor: Optional[Decimal] = None
avg_trade_pnl: Optional[Decimal] = None
avg_win: Optional[Decimal] = None
avg_loss: Optional[Decimal] = None
largest_win: Optional[Decimal] = None
largest_loss: Optional[Decimal] = None
avg_holding_period_days: Optional[float] = None
total_commission: Decimal = Field(default=Decimal("0"), ge=0)
total_slippage: Decimal = Field(default=Decimal("0"), ge=0)
trading_days: int = Field(default=0, ge=0)
start_equity: Decimal = Field(gt=0)
end_equity: Decimal = Field(gt=0)
def to_summary_dict(self) -> dict:
return {
"Performance": {
"Total Return": f"{self.total_return_percent:.2f}%",
"Annualized Return": f"{self.annualized_return:.2f}%",
"Sharpe Ratio": f"{self.sharpe_ratio:.2f}" if self.sharpe_ratio else "N/A",
"Sortino Ratio": f"{self.sortino_ratio:.2f}" if self.sortino_ratio else "N/A",
"Max Drawdown": f"{self.max_drawdown_percent:.2f}%",
},
"Risk": {
"Volatility (Ann.)": f"{self.annualized_volatility:.2f}%",
"Calmar Ratio": f"{self.calmar_ratio:.2f}" if self.calmar_ratio else "N/A",
"Beta": f"{self.beta:.2f}" if self.beta else "N/A",
},
"Trading": {
"Total Trades": self.total_trades,
"Win Rate": f"{self.win_rate:.1f}%" if self.win_rate else "N/A",
"Profit Factor": f"{self.profit_factor:.2f}" if self.profit_factor else "N/A",
"Avg Holding Period": f"{self.avg_holding_period_days:.1f} days" if self.avg_holding_period_days else "N/A",
},
"Costs": {
"Total Commission": f"${self.total_commission:.2f}",
"Total Slippage": f"${self.total_slippage:.2f}",
},
}
class BacktestResult(BaseModel):
id: UUID = Field(default_factory=uuid4)
config: BacktestConfig
metrics: BacktestMetrics
trade_log: TradeLog
equity_curve: list[EquityCurvePoint] = Field(default_factory=list)
daily_returns: list[Decimal] = Field(default_factory=list)
started_at: datetime
completed_at: datetime
status: str = Field(default="completed")
error_message: Optional[str] = None
@computed_field
@property
def duration_seconds(self) -> float:
return (self.completed_at - self.started_at).total_seconds()
def to_dict(self) -> dict:
return {
"id": str(self.id),
"config": {
"name": self.config.name,
"tickers": self.config.tickers,
"start_date": self.config.start_date.isoformat(),
"end_date": self.config.end_date.isoformat(),
"initial_cash": float(self.config.portfolio_config.initial_cash),
},
"metrics": self.metrics.to_summary_dict(),
"trade_summary": {
"total_trades": self.trade_log.total_trades,
"winning_trades": self.trade_log.winning_trades,
"losing_trades": self.trade_log.losing_trades,
"win_rate": float(self.trade_log.win_rate) if self.trade_log.win_rate else None,
},
"duration_seconds": self.duration_seconds,
"status": self.status,
}

View File

@ -0,0 +1,157 @@
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import Optional
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
class SignalType(str, Enum):
STRONG_BUY = "strong_buy"
BUY = "buy"
HOLD = "hold"
SELL = "sell"
STRONG_SELL = "strong_sell"
class AnalystType(str, Enum):
MARKET = "market"
SENTIMENT = "sentiment"
NEWS = "news"
FUNDAMENTALS = "fundamentals"
class AnalystReport(BaseModel):
id: UUID = Field(default_factory=uuid4)
analyst_type: AnalystType
ticker: str
report_date: datetime
signal: Optional[SignalType] = None
confidence: Decimal = Field(default=Decimal("0.5"), ge=0, le=1)
summary: str
key_findings: list[str] = Field(default_factory=list)
raw_content: Optional[str] = None
data_sources: list[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=datetime.now)
class TradingSignal(BaseModel):
id: UUID = Field(default_factory=uuid4)
ticker: str
timestamp: datetime
signal: SignalType
strength: Decimal = Field(ge=0, le=1)
source: str
timeframe: str = Field(default="1d")
price_at_signal: Optional[Decimal] = None
target_price: Optional[Decimal] = None
stop_loss: Optional[Decimal] = None
expiry: Optional[datetime] = None
metadata: dict = Field(default_factory=dict)
class RiskAssessment(BaseModel):
id: UUID = Field(default_factory=uuid4)
ticker: str
timestamp: datetime
overall_risk_score: Decimal = Field(ge=0, le=1)
market_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1)
liquidity_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1)
volatility_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1)
concentration_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1)
event_risk: Decimal = Field(default=Decimal("0.5"), ge=0, le=1)
max_position_size: Optional[Decimal] = None
recommended_stop_loss: Optional[Decimal] = None
var_95: Optional[Decimal] = None
expected_shortfall: Optional[Decimal] = None
risk_factors: list[str] = Field(default_factory=list)
mitigations: list[str] = Field(default_factory=list)
notes: Optional[str] = None
class TradingDecision(BaseModel):
id: UUID = Field(default_factory=uuid4)
ticker: str
timestamp: datetime
decision_date: datetime
signal: SignalType
confidence: Decimal = Field(ge=0, le=1)
recommended_action: str
recommended_quantity: Optional[int] = None
recommended_price: Optional[Decimal] = None
stop_loss: Optional[Decimal] = None
take_profit: Optional[Decimal] = None
analyst_reports: list[AnalystReport] = Field(default_factory=list)
signals: list[TradingSignal] = Field(default_factory=list)
risk_assessment: Optional[RiskAssessment] = None
bull_argument: Optional[str] = None
bear_argument: Optional[str] = None
debate_rounds: int = Field(default=0, ge=0)
debate_winner: Optional[str] = None
risk_manager_approved: Optional[bool] = None
risk_manager_notes: Optional[str] = None
final_decision: str
rationale: str
execution_price: Optional[Decimal] = None
executed_at: Optional[datetime] = None
execution_notes: Optional[str] = None
created_at: datetime = Field(default_factory=datetime.now)
@property
def is_buy(self) -> bool:
return self.signal in (SignalType.BUY, SignalType.STRONG_BUY)
@property
def is_sell(self) -> bool:
return self.signal in (SignalType.SELL, SignalType.STRONG_SELL)
@property
def is_hold(self) -> bool:
return self.signal == SignalType.HOLD
def get_analyst_report(self, analyst_type: AnalystType) -> Optional[AnalystReport]:
for report in self.analyst_reports:
if report.analyst_type == analyst_type:
return report
return None
def to_summary(self) -> dict:
return {
"ticker": self.ticker,
"date": self.decision_date.isoformat(),
"signal": self.signal.value,
"confidence": float(self.confidence),
"final_decision": self.final_decision,
"risk_approved": self.risk_manager_approved,
"debate_rounds": self.debate_rounds,
"analyst_consensus": self._calculate_consensus(),
}
def _calculate_consensus(self) -> Optional[str]:
if not self.analyst_reports:
return None
signals = [r.signal for r in self.analyst_reports if r.signal]
if not signals:
return None
buy_count = sum(1 for s in signals if s in (SignalType.BUY, SignalType.STRONG_BUY))
sell_count = sum(1 for s in signals if s in (SignalType.SELL, SignalType.STRONG_SELL))
if buy_count > sell_count:
return "bullish"
elif sell_count > buy_count:
return "bearish"
return "neutral"

View File

@ -0,0 +1,144 @@
from datetime import date, datetime
from decimal import Decimal
from typing import Optional
from pydantic import BaseModel, Field, field_validator
class OHLCVBar(BaseModel):
timestamp: datetime
open: Decimal = Field(gt=0)
high: Decimal = Field(gt=0)
low: Decimal = Field(gt=0)
close: Decimal = Field(gt=0)
volume: int = Field(ge=0)
adjusted_close: Optional[Decimal] = Field(default=None, gt=0)
@field_validator("high")
@classmethod
def high_gte_low(cls, v: Decimal, info) -> Decimal:
if "low" in info.data and v < info.data["low"]:
raise ValueError("high must be >= low")
return v
@field_validator("high")
@classmethod
def high_gte_open_close(cls, v: Decimal, info) -> Decimal:
if "open" in info.data and v < info.data["open"]:
raise ValueError("high must be >= open")
if "close" in info.data and v < info.data["close"]:
raise ValueError("high must be >= close")
return v
@field_validator("low")
@classmethod
def low_lte_open_close(cls, v: Decimal, info) -> Decimal:
if "open" in info.data and v > info.data["open"]:
raise ValueError("low must be <= open")
if "close" in info.data and v > info.data["close"]:
raise ValueError("low must be <= close")
return v
class OHLCV(BaseModel):
ticker: str = Field(min_length=1, max_length=10)
bars: list[OHLCVBar] = Field(default_factory=list)
interval: str = Field(default="1d")
currency: str = Field(default="USD")
@property
def start_date(self) -> Optional[datetime]:
return self.bars[0].timestamp if self.bars else None
@property
def end_date(self) -> Optional[datetime]:
return self.bars[-1].timestamp if self.bars else None
def get_bar(self, dt: datetime) -> Optional[OHLCVBar]:
for bar in self.bars:
if bar.timestamp.date() == dt.date():
return bar
return None
def slice(self, start: datetime, end: datetime) -> "OHLCV":
filtered = [b for b in self.bars if start <= b.timestamp <= end]
return OHLCV(
ticker=self.ticker,
bars=filtered,
interval=self.interval,
currency=self.currency,
)
class TechnicalIndicators(BaseModel):
timestamp: datetime
ticker: str
sma_20: Optional[Decimal] = None
sma_50: Optional[Decimal] = None
sma_200: Optional[Decimal] = None
ema_10: Optional[Decimal] = None
ema_20: Optional[Decimal] = None
rsi_14: Optional[Decimal] = Field(default=None, ge=0, le=100)
macd: Optional[Decimal] = None
macd_signal: Optional[Decimal] = None
macd_histogram: Optional[Decimal] = None
bollinger_upper: Optional[Decimal] = None
bollinger_middle: Optional[Decimal] = None
bollinger_lower: Optional[Decimal] = None
atr_14: Optional[Decimal] = Field(default=None, ge=0)
mfi_14: Optional[Decimal] = Field(default=None, ge=0, le=100)
vwap: Optional[Decimal] = None
obv: Optional[int] = None
class MarketSnapshot(BaseModel):
ticker: str
timestamp: datetime
bar: OHLCVBar
indicators: Optional[TechnicalIndicators] = None
prev_close: Optional[Decimal] = None
@property
def change(self) -> Optional[Decimal]:
if self.prev_close:
return self.bar.close - self.prev_close
return None
@property
def change_percent(self) -> Optional[Decimal]:
if self.prev_close and self.prev_close > 0:
return ((self.bar.close - self.prev_close) / self.prev_close) * 100
return None
class HistoricalDataRequest(BaseModel):
ticker: str = Field(min_length=1, max_length=10)
start_date: date
end_date: date
interval: str = Field(default="1d")
include_indicators: bool = Field(default=True)
adjusted: bool = Field(default=True)
@field_validator("end_date")
@classmethod
def end_after_start(cls, v: date, info) -> date:
if "start_date" in info.data and v < info.data["start_date"]:
raise ValueError("end_date must be >= start_date")
return v
class HistoricalDataResponse(BaseModel):
request: HistoricalDataRequest
ohlcv: OHLCV
indicators: list[TechnicalIndicators] = Field(default_factory=list)
fetched_at: datetime = Field(default_factory=datetime.now)
source: str = Field(default="unknown")

View File

@ -0,0 +1,172 @@
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import Optional
from uuid import UUID, uuid4
from pydantic import BaseModel, Field, computed_field
from .trading import Position, Fill, OrderSide
class TransactionType(str, Enum):
DEPOSIT = "deposit"
WITHDRAWAL = "withdrawal"
DIVIDEND = "dividend"
INTEREST = "interest"
FEE = "fee"
TRANSFER_IN = "transfer_in"
TRANSFER_OUT = "transfer_out"
class CashTransaction(BaseModel):
id: UUID = Field(default_factory=uuid4)
transaction_type: TransactionType
amount: Decimal
timestamp: datetime = Field(default_factory=datetime.now)
description: Optional[str] = None
reference_id: Optional[UUID] = None
class PortfolioConfig(BaseModel):
initial_cash: Decimal = Field(default=Decimal("100000"), gt=0)
commission_per_share: Decimal = Field(default=Decimal("0"), ge=0)
commission_per_trade: Decimal = Field(default=Decimal("0"), ge=0)
commission_percent: Decimal = Field(default=Decimal("0"), ge=0, le=100)
min_commission: Decimal = Field(default=Decimal("0"), ge=0)
max_commission: Optional[Decimal] = Field(default=None, ge=0)
slippage_percent: Decimal = Field(default=Decimal("0"), ge=0, le=100)
margin_enabled: bool = Field(default=False)
margin_rate: Decimal = Field(default=Decimal("0.05"), ge=0)
max_position_size_percent: Decimal = Field(default=Decimal("100"), gt=0, le=100)
allow_fractional_shares: bool = Field(default=False)
def calculate_commission(self, quantity: int, price: Decimal) -> Decimal:
trade_value = quantity * price
commission = Decimal("0")
commission += self.commission_per_trade
commission += self.commission_per_share * quantity
commission += trade_value * (self.commission_percent / 100)
if commission < self.min_commission:
commission = self.min_commission
if self.max_commission and commission > self.max_commission:
commission = self.max_commission
return commission
def calculate_slippage(self, price: Decimal, side: OrderSide) -> Decimal:
slippage = price * (self.slippage_percent / 100)
if side == OrderSide.BUY:
return price + slippage
return price - slippage
class PortfolioSnapshot(BaseModel):
timestamp: datetime = Field(default_factory=datetime.now)
cash: Decimal = Field(default=Decimal("0"))
positions: dict[str, Position] = Field(default_factory=dict)
pending_orders: int = Field(default=0, ge=0)
total_commission_paid: Decimal = Field(default=Decimal("0"), ge=0)
total_realized_pnl: Decimal = Field(default=Decimal("0"))
cash_transactions: list[CashTransaction] = Field(default_factory=list)
@computed_field
@property
def position_count(self) -> int:
return len([p for p in self.positions.values() if p.quantity != 0])
def positions_value(self, prices: dict[str, Decimal]) -> Decimal:
total = Decimal("0")
for ticker, position in self.positions.items():
if ticker in prices and position.quantity != 0:
total += position.market_value(prices[ticker])
return total
def total_equity(self, prices: dict[str, Decimal]) -> Decimal:
return self.cash + self.positions_value(prices)
def total_unrealized_pnl(self, prices: dict[str, Decimal]) -> Decimal:
total = Decimal("0")
for ticker, position in self.positions.items():
if ticker in prices:
total += position.unrealized_pnl(prices[ticker])
return total
def get_position(self, ticker: str) -> Position:
if ticker not in self.positions:
self.positions[ticker] = Position(ticker=ticker)
return self.positions[ticker]
def apply_fill(self, fill: Fill) -> None:
position = self.get_position(fill.ticker)
old_realized = position.realized_pnl
position.update_from_fill(fill)
pnl_change = position.realized_pnl - old_realized
self.total_realized_pnl += pnl_change
self.total_commission_paid += fill.commission
if fill.side == OrderSide.BUY:
self.cash -= fill.total_cost
else:
self.cash += fill.total_value - fill.commission
def add_cash_transaction(self, transaction: CashTransaction) -> None:
self.cash_transactions.append(transaction)
if transaction.transaction_type in (
TransactionType.DEPOSIT,
TransactionType.DIVIDEND,
TransactionType.INTEREST,
TransactionType.TRANSFER_IN,
):
self.cash += transaction.amount
else:
self.cash -= abs(transaction.amount)
def can_afford(
self, ticker: str, quantity: int, price: Decimal, config: PortfolioConfig
) -> bool:
execution_price = config.calculate_slippage(price, OrderSide.BUY)
commission = config.calculate_commission(quantity, execution_price)
total_cost = (quantity * execution_price) + commission
return self.cash >= total_cost
def max_shares_affordable(
self, ticker: str, price: Decimal, config: PortfolioConfig
) -> int:
if price <= 0:
return 0
execution_price = config.calculate_slippage(price, OrderSide.BUY)
available = self.cash
low, high = 0, int(available / execution_price) + 1
result = 0
while low <= high:
mid = (low + high) // 2
commission = config.calculate_commission(mid, execution_price)
total_cost = (mid * execution_price) + commission
if total_cost <= available:
result = mid
low = mid + 1
else:
high = mid - 1
return result
def to_dict(self, prices: dict[str, Decimal]) -> dict:
return {
"timestamp": self.timestamp.isoformat(),
"cash": float(self.cash),
"positions_value": float(self.positions_value(prices)),
"total_equity": float(self.total_equity(prices)),
"position_count": self.position_count,
"total_realized_pnl": float(self.total_realized_pnl),
"total_unrealized_pnl": float(self.total_unrealized_pnl(prices)),
"total_commission_paid": float(self.total_commission_paid),
}

View File

@ -0,0 +1,201 @@
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import Optional
from uuid import UUID, uuid4
from pydantic import BaseModel, Field, computed_field
class OrderSide(str, Enum):
BUY = "buy"
SELL = "sell"
class OrderType(str, Enum):
MARKET = "market"
LIMIT = "limit"
STOP = "stop"
STOP_LIMIT = "stop_limit"
class OrderStatus(str, Enum):
PENDING = "pending"
SUBMITTED = "submitted"
PARTIAL = "partial"
FILLED = "filled"
CANCELLED = "cancelled"
REJECTED = "rejected"
EXPIRED = "expired"
class PositionSide(str, Enum):
LONG = "long"
SHORT = "short"
FLAT = "flat"
class Order(BaseModel):
id: UUID = Field(default_factory=uuid4)
ticker: str = Field(min_length=1, max_length=10)
side: OrderSide
order_type: OrderType = Field(default=OrderType.MARKET)
quantity: int = Field(gt=0)
limit_price: Optional[Decimal] = Field(default=None, gt=0)
stop_price: Optional[Decimal] = Field(default=None, gt=0)
status: OrderStatus = Field(default=OrderStatus.PENDING)
created_at: datetime = Field(default_factory=datetime.now)
submitted_at: Optional[datetime] = None
filled_at: Optional[datetime] = None
filled_quantity: int = Field(default=0, ge=0)
filled_avg_price: Optional[Decimal] = None
commission: Decimal = Field(default=Decimal("0"))
notes: Optional[str] = None
@computed_field
@property
def remaining_quantity(self) -> int:
return self.quantity - self.filled_quantity
@computed_field
@property
def is_complete(self) -> bool:
return self.status in (
OrderStatus.FILLED,
OrderStatus.CANCELLED,
OrderStatus.REJECTED,
OrderStatus.EXPIRED,
)
class Fill(BaseModel):
id: UUID = Field(default_factory=uuid4)
order_id: UUID
ticker: str
side: OrderSide
quantity: int = Field(gt=0)
price: Decimal = Field(gt=0)
commission: Decimal = Field(default=Decimal("0"), ge=0)
timestamp: datetime = Field(default_factory=datetime.now)
@computed_field
@property
def total_value(self) -> Decimal:
return self.price * self.quantity
@computed_field
@property
def total_cost(self) -> Decimal:
if self.side == OrderSide.BUY:
return self.total_value + self.commission
return self.total_value - self.commission
class Position(BaseModel):
ticker: str = Field(min_length=1, max_length=10)
quantity: int = Field(default=0)
avg_cost: Decimal = Field(default=Decimal("0"), ge=0)
realized_pnl: Decimal = Field(default=Decimal("0"))
opened_at: Optional[datetime] = None
last_updated: datetime = Field(default_factory=datetime.now)
@computed_field
@property
def side(self) -> PositionSide:
if self.quantity > 0:
return PositionSide.LONG
elif self.quantity < 0:
return PositionSide.SHORT
return PositionSide.FLAT
@computed_field
@property
def cost_basis(self) -> Decimal:
return abs(self.quantity) * self.avg_cost
def unrealized_pnl(self, current_price: Decimal) -> Decimal:
if self.quantity == 0:
return Decimal("0")
market_value = self.quantity * current_price
return market_value - (self.quantity * self.avg_cost)
def market_value(self, current_price: Decimal) -> Decimal:
return abs(self.quantity) * current_price
def update_from_fill(self, fill: Fill) -> None:
if fill.side == OrderSide.BUY:
if self.quantity >= 0:
total_cost = (self.quantity * self.avg_cost) + fill.total_value
self.quantity += fill.quantity
self.avg_cost = total_cost / self.quantity if self.quantity else Decimal("0")
else:
close_qty = min(fill.quantity, abs(self.quantity))
pnl = close_qty * (self.avg_cost - fill.price)
self.realized_pnl += pnl
self.quantity += fill.quantity
if self.quantity > 0:
self.avg_cost = fill.price
else:
if self.quantity <= 0:
total_cost = (abs(self.quantity) * self.avg_cost) + fill.total_value
self.quantity -= fill.quantity
self.avg_cost = total_cost / abs(self.quantity) if self.quantity else Decimal("0")
else:
close_qty = min(fill.quantity, self.quantity)
pnl = close_qty * (fill.price - self.avg_cost)
self.realized_pnl += pnl
self.quantity -= fill.quantity
if self.quantity < 0:
self.avg_cost = fill.price
if self.quantity != 0 and self.opened_at is None:
self.opened_at = fill.timestamp
elif self.quantity == 0:
self.opened_at = None
self.last_updated = fill.timestamp
class Trade(BaseModel):
id: UUID = Field(default_factory=uuid4)
ticker: str
side: OrderSide
entry_price: Decimal = Field(gt=0)
entry_quantity: int = Field(gt=0)
entry_time: datetime
exit_price: Optional[Decimal] = Field(default=None, gt=0)
exit_quantity: Optional[int] = Field(default=None, gt=0)
exit_time: Optional[datetime] = None
commission: Decimal = Field(default=Decimal("0"), ge=0)
entry_order_id: Optional[UUID] = None
exit_order_id: Optional[UUID] = None
notes: Optional[str] = None
tags: list[str] = Field(default_factory=list)
@computed_field
@property
def is_closed(self) -> bool:
return self.exit_price is not None and self.exit_quantity is not None
@computed_field
@property
def pnl(self) -> Optional[Decimal]:
if not self.is_closed:
return None
if self.side == OrderSide.BUY:
return (self.exit_price - self.entry_price) * self.exit_quantity - self.commission
return (self.entry_price - self.exit_price) * self.exit_quantity - self.commission
@computed_field
@property
def pnl_percent(self) -> Optional[Decimal]:
if not self.is_closed or self.entry_price == 0:
return None
return (self.pnl / (self.entry_price * self.entry_quantity)) * 100
@computed_field
@property
def holding_period(self) -> Optional[int]:
if not self.exit_time:
return None
return (self.exit_time - self.entry_time).days