TradingAgents/tests/portfolio/test_trade_executor.py

250 lines
8.2 KiB
Python

"""Tests for tradingagents/portfolio/trade_executor.py.
Uses MagicMock for PortfolioRepository — no DB connection required.
Run::
pytest tests/portfolio/test_trade_executor.py -v
"""
from __future__ import annotations
from unittest.mock import MagicMock, call
import pytest
from tradingagents.portfolio.models import Holding, Portfolio, PortfolioSnapshot
from tradingagents.portfolio.exceptions import (
InsufficientCashError,
InsufficientSharesError,
)
from tradingagents.portfolio.trade_executor import TradeExecutor
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_holding(ticker, shares=10.0, avg_cost=100.0, sector="Technology"):
    """Build a ``Holding`` for *ticker* inside the fixture portfolio ``"p1"``.

    The holding id is derived from the ticker (``"h-" + ticker``), so every
    fixture holding gets a distinct, readable id.
    """
    fields = {
        "holding_id": "h-" + ticker,
        "portfolio_id": "p1",
        "ticker": ticker,
        "shares": shares,
        "avg_cost": avg_cost,
        "sector": sector,
    }
    return Holding(**fields)
def _make_portfolio(cash=50_000.0, total_value=60_000.0):
    """Build the ``"p1"`` ``Portfolio`` fixture with valuation fields pre-set.

    ``total_value``, ``equity_value`` (``total_value - cash``) and
    ``cash_pct`` are assigned directly on the instance rather than derived
    from holdings. When *total_value* is falsy, ``cash_pct`` falls back to
    ``1.0`` instead of dividing by zero.
    """
    portfolio = Portfolio(
        portfolio_id="p1", name="Test", cash=cash, initial_cash=100_000.0
    )
    portfolio.total_value = total_value
    portfolio.equity_value = total_value - cash
    if total_value:
        portfolio.cash_pct = cash / total_value
    else:
        portfolio.cash_pct = 1.0
    return portfolio
def _make_snapshot():
    """Return a canned ``PortfolioSnapshot`` with id ``"snap-1"``.

    Its figures (60k total, 50k cash, 10k equity) match the defaults of
    ``_make_portfolio``.
    """
    return PortfolioSnapshot(
        snapshot_id="snap-1",
        snapshot_date="2026-01-01T00:00:00Z",
        portfolio_id="p1",
        cash=50_000.0,
        equity_value=10_000.0,
        total_value=60_000.0,
        num_positions=1,
        holdings_snapshot=[],
    )
def _make_repo(portfolio=None, holdings=None, snapshot=None):
    """Return a ``MagicMock`` standing in for ``PortfolioRepository``.

    Args:
        portfolio: Returned by ``get_portfolio_with_holdings`` together with
            *holdings*; defaults to ``_make_portfolio()``.
        holdings: Holdings list returned alongside *portfolio*; defaults to ``[]``.
        snapshot: Returned by ``take_snapshot``; defaults to ``_make_snapshot()``.

    ``batch_remove_holdings`` defaults to ``([], [])``, an empty
    (executed, failed) pair. Without this default it would return a bare
    ``MagicMock``, which iterates as empty and cannot be unpacked into two
    values. Tests that exercise SELLs override this return value.
    """
    repo = MagicMock()
    # Explicit ``is None`` checks: a caller-supplied object must never be
    # swapped for the default just because it happens to be falsy.
    if portfolio is None:
        portfolio = _make_portfolio()
    if snapshot is None:
        snapshot = _make_snapshot()
    repo.get_portfolio_with_holdings.return_value = (
        portfolio,
        [] if holdings is None else holdings,
    )
    repo.take_snapshot.return_value = snapshot
    repo.batch_remove_holdings.return_value = ([], [])
    return repo
# Executor limits shared by every test. The *_pct values are fractions
# (0.15 == 15%); presumably of total portfolio value. Confirm against
# TradeExecutor.
_DEFAULT_CONFIG = dict(
    max_positions=15,
    max_position_pct=0.15,
    max_sector_pct=0.35,
    min_cash_pct=0.05,
)
# Quoted prices used by most tests; tickers absent here are treated as unpriced.
PRICES = dict(AAPL=150.0, MSFT=300.0)
# ---------------------------------------------------------------------------
# SELL tests
# ---------------------------------------------------------------------------
def test_execute_sell_success():
    """A SELL the repository accepts is forwarded once and reported as executed."""
    repo = _make_repo()
    # batch_remove_holdings reports back an (executed, failed) pair.
    repo.batch_remove_holdings.return_value = (
        [{"action": "SELL", "ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Stop loss"}],
        [],
    )
    executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
    sell_order = {"ticker": "AAPL", "shares": 5.0, "rationale": "Stop loss"}

    result = executor.execute_decisions(
        "p1", {"sells": [sell_order], "buys": []}, PRICES
    )

    repo.batch_remove_holdings.assert_called_once()
    forwarded = repo.batch_remove_holdings.call_args.args
    assert forwarded[0] == "p1"
    # The order reaches the repository enriched with the quoted price.
    assert forwarded[1] == [{"ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Stop loss"}]
    executed = result["executed_trades"]
    assert len(executed) == 1
    assert executed[0]["action"] == "SELL"
    assert executed[0]["ticker"] == "AAPL"
    assert len(result["failed_trades"]) == 0
def test_execute_sell_missing_price():
    """SELL with no price in prices dict → failed_trade, never sent to the repository."""
    repo = _make_repo()
    executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
    decisions = {
        "sells": [{"ticker": "NVDA", "shares": 5.0, "rationale": "Stop loss"}],
        "buys": [],
    }
    result = executor.execute_decisions("p1", decisions, PRICES)
    repo.remove_holding.assert_not_called()
    # SELLs are forwarded through batch_remove_holdings(portfolio_id, sells)
    # (see test_execute_sell_success), so checking remove_holding alone passes
    # vacuously. Make sure the unpriced ticker never reaches the batch call.
    for forwarded in repo.batch_remove_holdings.call_args_list:
        assert all(sell["ticker"] != "NVDA" for sell in forwarded.args[1])
    assert len(result["failed_trades"]) == 1
    assert result["failed_trades"][0]["ticker"] == "NVDA"
def test_execute_sell_insufficient_shares():
    """A SELL rejected inside batch_remove_holdings surfaces as a failed_trade."""
    rejection = {
        "action": "SELL",
        "ticker": "AAPL",
        "reason": "Hold 10.0 shares of AAPL, cannot sell 999.0",
    }
    repo = _make_repo()
    # Nothing executed; the repository reports a single failure.
    repo.batch_remove_holdings.return_value = ([], [rejection])
    executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
    oversized_sell = {"ticker": "AAPL", "shares": 999.0, "rationale": "Exit"}

    result = executor.execute_decisions(
        "p1", {"sells": [oversized_sell], "buys": []}, PRICES
    )

    failures = result["failed_trades"]
    assert len(failures) == 1
    assert "cannot sell 999.0" in failures[0]["reason"]
# ---------------------------------------------------------------------------
# BUY tests
# ---------------------------------------------------------------------------
def test_execute_buy_success():
    """A BUY within cash and config limits is written via add_holding and reported."""
    repo = _make_repo(portfolio=_make_portfolio(cash=50_000.0, total_value=60_000.0))
    executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
    buy = {"ticker": "MSFT", "shares": 10.0, "sector": "Technology", "rationale": "Growth"}

    result = executor.execute_decisions("p1", {"sells": [], "buys": [buy]}, PRICES)

    # The fill price is the quoted MSFT price from PRICES (300.0).
    repo.add_holding.assert_called_once_with("p1", "MSFT", 10.0, 300.0, sector="Technology")
    executed = result["executed_trades"]
    assert len(executed) == 1
    assert executed[0]["action"] == "BUY"
def test_execute_buy_missing_price():
    """A BUY for a ticker absent from the price map fails and never reaches add_holding."""
    repo = _make_repo()
    executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
    # TSLA is deliberately not quoted in PRICES.
    unpriced = {"ticker": "TSLA", "shares": 5.0, "sector": "Automotive", "rationale": "EV"}

    result = executor.execute_decisions("p1", {"sells": [], "buys": [unpriced]}, PRICES)

    repo.add_holding.assert_not_called()
    failures = result["failed_trades"]
    assert len(failures) == 1
    assert failures[0]["ticker"] == "TSLA"
def test_execute_buy_constraint_violation():
    """BUY exceeding max_positions → failed_trade with constraint violation."""
    # Fill the portfolio to exactly max_positions (15) holdings of $1,000 each.
    max_positions = _DEFAULT_CONFIG["max_positions"]
    holdings = [
        _make_holding(f"T{n}", shares=10, avg_cost=100, sector="Technology")
        for n in range(max_positions)
    ]
    repo = _make_repo(
        portfolio=_make_portfolio(cash=5_000.0, total_value=20_000.0),
        holdings=holdings,
    )
    executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
    # A small ($250) buy in an unrepresented sector, so the position count is
    # the limit this test targets.
    new_buy = {"ticker": "NEWT", "shares": 5.0, "sector": "Healthcare", "rationale": "New"}
    prices = {**PRICES, "NEWT": 50.0}

    result = executor.execute_decisions("p1", {"sells": [], "buys": [new_buy]}, prices)

    repo.add_holding.assert_not_called()
    failures = result["failed_trades"]
    assert len(failures) == 1
    assert failures[0]["reason"] == "Constraint violation"
# ---------------------------------------------------------------------------
# Ordering and snapshot
# ---------------------------------------------------------------------------
def test_execute_decisions_sells_before_buys():
    """SELLs are always executed before BUYs."""
    portfolio = _make_portfolio(cash=50_000.0, total_value=60_000.0)
    repo = _make_repo(portfolio=portfolio)
    repo.batch_remove_holdings.return_value = (
        [{"action": "SELL", "ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Exit"}],
        [],
    )
    executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
    decisions = {
        "sells": [{"ticker": "AAPL", "shares": 5.0, "rationale": "Exit"}],
        "buys": [{"ticker": "MSFT", "shares": 3.0, "sector": "Technology", "rationale": "Add"}],
    }
    executor.execute_decisions("p1", decisions, PRICES)
    # repo.method_calls records child-method calls in invocation order, so the
    # relative position of the two names reflects execution order.
    watched = ("batch_remove_holdings", "add_holding")
    call_order = [c[0] for c in repo.method_calls if c[0] in watched]
    # Assert presence first: list.index would otherwise raise a bare
    # ValueError and hide which repository call never happened.
    assert "batch_remove_holdings" in call_order, f"no SELL batch call: {call_order}"
    assert "add_holding" in call_order, f"no BUY call: {call_order}"
    assert call_order.index("batch_remove_holdings") < call_order.index("add_holding")
def test_execute_decisions_takes_snapshot():
    """Even an empty decision set ends with a snapshot that is returned in the result."""
    repo = _make_repo()
    executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)

    result = executor.execute_decisions("p1", {"sells": [], "buys": []}, PRICES)

    repo.take_snapshot.assert_called_once_with("p1", PRICES)
    assert "snapshot" in result
    # The snapshot is exposed as a mapping carrying the id set by _make_snapshot.
    assert result["snapshot"]["snapshot_id"] == "snap-1"