diff --git a/tests/unit/execution/test_paper_broker.py b/tests/unit/execution/test_paper_broker.py new file mode 100644 index 00000000..3bf2470f --- /dev/null +++ b/tests/unit/execution/test_paper_broker.py @@ -0,0 +1,859 @@ +"""Tests for Paper Broker implementation. + +Issue #26: [EXEC-25] Paper broker - simulation mode +""" + +from decimal import Decimal +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch +import pytest + +from tradingagents.execution import ( + PaperBroker, + OrderRequest, + OrderSide, + OrderType, + OrderStatus, + TimeInForce, + AssetClass, + PositionSide, + ConnectionError, + OrderError, + InvalidOrderError, + InsufficientFundsError, +) + + +class TestPaperBrokerInit: + """Test PaperBroker initialization.""" + + def test_default_initialization(self): + """Test default broker initialization.""" + broker = PaperBroker() + assert broker.name == "Paper" + assert broker.initial_cash == Decimal("100000") + assert broker.cash == Decimal("100000") + assert broker.is_paper_trading is True + assert broker.is_connected is False + + def test_custom_initial_cash(self): + """Test initialization with custom initial cash.""" + broker = PaperBroker(initial_cash=Decimal("50000")) + assert broker.initial_cash == Decimal("50000") + assert broker.cash == Decimal("50000") + + def test_custom_slippage(self): + """Test initialization with custom slippage.""" + broker = PaperBroker(slippage_percent=Decimal("0.1")) + assert broker._slippage_percent == Decimal("0.1") + + def test_custom_fill_probability(self): + """Test initialization with custom fill probability.""" + broker = PaperBroker(fill_probability=0.5) + assert broker._fill_probability == 0.5 + + def test_market_closed_initialization(self): + """Test initialization with market closed.""" + broker = PaperBroker(market_open=False) + assert broker._market_open is False + + def test_supported_asset_classes(self): + """Test supported asset classes include all types.""" + broker = PaperBroker() + assert AssetClass.EQUITY in broker.supported_asset_classes + assert AssetClass.ETF in broker.supported_asset_classes + assert AssetClass.CRYPTO in broker.supported_asset_classes + assert AssetClass.FUTURE in broker.supported_asset_classes + assert AssetClass.OPTION in broker.supported_asset_classes + assert AssetClass.FOREX in broker.supported_asset_classes + + def test_custom_price_provider(self): + """Test initialization with custom price provider.""" + def price_provider(symbol: str) -> Decimal: + return Decimal("123.45") + + broker = PaperBroker(price_provider=price_provider) + assert broker.get_simulated_price("ANY") == Decimal("123.45") + + +class TestPaperBrokerConnection: + """Test PaperBroker connection methods.""" + + @pytest.mark.asyncio + async def test_connect_succeeds(self): + """Test connect always succeeds.""" + broker = PaperBroker() + result = await broker.connect() + assert result is True + assert broker.is_connected is True + + @pytest.mark.asyncio + async def test_disconnect(self): + """Test disconnect.""" + broker = PaperBroker() + await broker.connect() + await broker.disconnect() + assert broker.is_connected is False + + @pytest.mark.asyncio + async def test_multiple_connects(self): + """Test multiple connects work.""" + broker = PaperBroker() + await broker.connect() + await broker.connect() + assert broker.is_connected is True + + +class TestPaperBrokerMarketStatus: + """Test PaperBroker market status.""" + + @pytest.mark.asyncio + async def test_market_open_by_default(self): + """Test market is open by default.""" + broker = PaperBroker() + await broker.connect() + assert await broker.is_market_open() is True + + @pytest.mark.asyncio + async def test_market_closed(self): + """Test market closed simulation.""" + broker = PaperBroker(market_open=False) + await broker.connect() + assert await broker.is_market_open() is False + + @pytest.mark.asyncio + async def test_set_market_open(self): + """Test changing market open status.""" + broker = PaperBroker(market_open=False) + broker.set_market_open(True) + await broker.connect() + assert await broker.is_market_open() is True + + +class TestPaperBrokerPrices: + """Test PaperBroker price simulation.""" + + def test_set_and_get_price(self): + """Test setting and getting prices.""" + broker = PaperBroker() + broker.set_price("TEST", Decimal("99.99")) + assert broker.get_simulated_price("TEST") == Decimal("99.99") + + def test_default_prices(self): + """Test default prices for common symbols.""" + broker = PaperBroker() + assert broker.get_simulated_price("AAPL") == Decimal("175.00") + assert broker.get_simulated_price("MSFT") == Decimal("380.00") + assert broker.get_simulated_price("SPY") == Decimal("470.00") + + def test_crypto_default_prices(self): + """Test default crypto prices.""" + broker = PaperBroker() + assert broker.get_simulated_price("BTCUSD") == Decimal("45000.00") + assert broker.get_simulated_price("ETHUSD") == Decimal("2500.00") + + def test_futures_default_prices(self): + """Test default futures prices.""" + broker = PaperBroker() + assert broker.get_simulated_price("ES") == Decimal("4700.00") + assert broker.get_simulated_price("NQ") == Decimal("16500.00") + + def test_unknown_symbol_generates_price(self): + """Test unknown symbols generate random prices.""" + broker = PaperBroker() + price = broker.get_simulated_price("UNKNOWN") + # Should be around 100 +/- 10 + assert Decimal("80") < price < Decimal("120") + + +class TestPaperBrokerAccount: + """Test PaperBroker account methods.""" + + @pytest.mark.asyncio + async def test_get_account_requires_connection(self): + """Test get_account requires connection.""" + broker = PaperBroker() + with pytest.raises(ConnectionError, match="Not connected"): + await broker.get_account() + + @pytest.mark.asyncio + async def test_get_account_basic(self): + """Test basic account information.""" + broker = PaperBroker(initial_cash=Decimal("50000")) + await broker.connect() + + account = await broker.get_account() + assert account.account_type == "paper" + assert account.status == "active" + assert account.cash == Decimal("50000") + assert account.portfolio_value == Decimal("50000") + assert account.buying_power == Decimal("50000") + assert account.account_id.startswith("PAPER-") + + @pytest.mark.asyncio + async def test_get_account_with_positions(self): + """Test account includes position values.""" + broker = PaperBroker(initial_cash=Decimal("100000")) + broker.set_price("AAPL", Decimal("150")) + await broker.connect() + + # Buy some shares + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + account = await broker.get_account() + # Cash reduced by purchase + assert account.cash < Decimal("100000") + # Portfolio includes position + assert account.portfolio_value > account.cash + + +class TestPaperBrokerOrders: + """Test PaperBroker order methods.""" + + @pytest.mark.asyncio + async def test_submit_market_buy_order(self): + """Test submitting market buy order.""" + broker = PaperBroker(initial_cash=Decimal("100000")) + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + order = await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + assert order.symbol == "AAPL" + assert order.side == OrderSide.BUY + assert order.quantity == Decimal("10") + assert order.status == OrderStatus.FILLED + assert order.filled_quantity == Decimal("10") + assert order.broker_order_id.startswith("PAPER-") + + @pytest.mark.asyncio + async def test_submit_market_sell_order(self): + """Test submitting market sell order.""" + broker = PaperBroker(initial_cash=Decimal("100000")) + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + # Buy first + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + # Then sell + order = await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.SELL, Decimal("5")) + ) + + assert order.status == OrderStatus.FILLED + assert order.filled_quantity == Decimal("5") + + @pytest.mark.asyncio + async def test_submit_limit_order_fills(self): + """Test limit order that should fill.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + # Limit above market price - should fill + order = await broker.submit_order( + OrderRequest.limit("AAPL", OrderSide.BUY, Decimal("10"), Decimal("110")) + ) + + assert order.status == OrderStatus.FILLED + assert order.filled_avg_price == Decimal("110") + + @pytest.mark.asyncio + async def test_submit_limit_order_no_fill(self): + """Test limit order that shouldn't fill.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + # Limit below market price - shouldn't fill + order = await broker.submit_order( + OrderRequest.limit("AAPL", OrderSide.BUY, Decimal("10"), Decimal("90")) + ) + + assert order.status == OrderStatus.NEW + assert order.filled_quantity == Decimal("0") + + @pytest.mark.asyncio + async def test_order_requires_connection(self): + """Test order submission requires connection.""" + broker = PaperBroker() + with pytest.raises(ConnectionError, match="Not connected"): + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + @pytest.mark.asyncio + async def test_invalid_order_quantity(self): + """Test invalid order quantity.""" + broker = PaperBroker() + await broker.connect() + + with pytest.raises(InvalidOrderError, match="quantity must be positive"): + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("-10")) + ) + + @pytest.mark.asyncio + async def test_insufficient_funds(self): + """Test insufficient funds error.""" + broker = PaperBroker(initial_cash=Decimal("100")) + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + with pytest.raises(InsufficientFundsError, match="Insufficient funds"): + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + @pytest.mark.asyncio + async def test_slippage_on_buy(self): + """Test slippage applied to buy orders.""" + broker = PaperBroker(slippage_percent=Decimal("1.0")) + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + order = await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1")) + ) + + # 1% slippage on $100 = $101 + assert order.filled_avg_price == Decimal("101.00") + + @pytest.mark.asyncio + async def test_slippage_on_sell(self): + """Test slippage applied to sell orders.""" + broker = PaperBroker(slippage_percent=Decimal("1.0")) + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + # Buy first + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + # Sell with slippage + order = await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.SELL, Decimal("5")) + ) + + # 1% slippage on $100 = $99 + assert order.filled_avg_price == Decimal("99.00") + + +class TestPaperBrokerFillProbability: + """Test PaperBroker fill probability.""" + + @pytest.mark.asyncio + async def test_zero_fill_probability(self): + """Test orders don't fill with 0% probability.""" + broker = PaperBroker(fill_probability=0.0) + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + order = await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + assert order.status == OrderStatus.NEW + assert order.filled_quantity == Decimal("0") + + @pytest.mark.asyncio + async def test_full_fill_probability(self): + """Test orders always fill with 100% probability.""" + broker = PaperBroker(fill_probability=1.0) + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + order = await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + assert order.status == OrderStatus.FILLED + + +class TestPaperBrokerCancelOrder: + """Test PaperBroker cancel order.""" + + @pytest.mark.asyncio + async def test_cancel_unfilled_order(self): + """Test cancelling unfilled order.""" + broker = PaperBroker(fill_probability=0.0) + await broker.connect() + + order = await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + cancelled = await broker.cancel_order(order.broker_order_id) + assert cancelled.status == OrderStatus.CANCELLED + + @pytest.mark.asyncio + async def test_cancel_filled_order_fails(self): + """Test cannot cancel filled order.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("10")) + await broker.connect() + + order = await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + with pytest.raises(OrderError, match="Cannot cancel filled order"): + await broker.cancel_order(order.broker_order_id) + + @pytest.mark.asyncio + async def test_cancel_nonexistent_order(self): + """Test cancelling nonexistent order.""" + broker = PaperBroker() + await broker.connect() + + with pytest.raises(OrderError, match="not found"): + await broker.cancel_order("INVALID-123") + + +class TestPaperBrokerReplaceOrder: + """Test PaperBroker replace order.""" + + @pytest.mark.asyncio + async def test_replace_order(self): + """Test replacing an order.""" + broker = PaperBroker(fill_probability=0.0) + await broker.connect() + + order = await broker.submit_order( + OrderRequest.limit("AAPL", OrderSide.BUY, Decimal("10"), Decimal("100")) + ) + + # Replace with new quantity + new_order = await broker.replace_order( + order.broker_order_id, + quantity=Decimal("20"), + ) + + assert new_order.quantity == Decimal("20") + assert new_order.broker_order_id != order.broker_order_id + + @pytest.mark.asyncio + async def test_replace_order_marks_old_replaced(self): + """Test old order marked as replaced.""" + broker = PaperBroker(fill_probability=0.0) + await broker.connect() + + order = await broker.submit_order( + OrderRequest.limit("AAPL", OrderSide.BUY, Decimal("10"), Decimal("100")) + ) + old_id = order.broker_order_id + + await broker.replace_order(old_id, quantity=Decimal("20")) + + old_order = await broker.get_order(old_id) + assert old_order.status == OrderStatus.REPLACED + + +class TestPaperBrokerGetOrders: + """Test PaperBroker get orders.""" + + @pytest.mark.asyncio + async def test_get_order(self): + """Test getting single order.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("10")) + await broker.connect() + + order = await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + retrieved = await broker.get_order(order.broker_order_id) + assert retrieved.broker_order_id == order.broker_order_id + + @pytest.mark.asyncio + async def test_get_order_not_found(self): + """Test getting nonexistent order.""" + broker = PaperBroker() + await broker.connect() + + with pytest.raises(OrderError, match="not found"): + await broker.get_order("INVALID-123") + + @pytest.mark.asyncio + async def test_get_orders_all(self): + """Test getting all orders.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("10")) + await broker.connect() + + await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))) + await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, Decimal("20"))) + + orders = await broker.get_orders() + assert len(orders) == 2 + + @pytest.mark.asyncio + async def test_get_orders_filter_by_status(self): + """Test filtering orders by status.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("10")) + await broker.connect() + + # Create filled order + await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))) + + # Create unfilled order + broker._fill_probability = 0.0 + await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))) + + filled_orders = await broker.get_orders(status=OrderStatus.FILLED) + assert len(filled_orders) == 1 + assert filled_orders[0].status == OrderStatus.FILLED + + @pytest.mark.asyncio + async def test_get_orders_filter_by_symbols(self): + """Test filtering orders by symbols.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("10")) + broker.set_price("MSFT", Decimal("10")) + await broker.connect() + + await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))) + await broker.submit_order(OrderRequest.market("MSFT", OrderSide.BUY, Decimal("10"))) + + aapl_orders = await broker.get_orders(symbols=["AAPL"]) + assert len(aapl_orders) == 1 + assert aapl_orders[0].symbol == "AAPL" + + @pytest.mark.asyncio + async def test_get_orders_with_limit(self): + """Test getting limited number of orders.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("10")) + await broker.connect() + + for _ in range(5): + await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1"))) + + orders = await broker.get_orders(limit=3) + assert len(orders) == 3 + + +class TestPaperBrokerPositions: + """Test PaperBroker position methods.""" + + @pytest.mark.asyncio + async def test_get_positions_empty(self): + """Test getting positions when empty.""" + broker = PaperBroker() + await broker.connect() + + positions = await broker.get_positions() + assert positions == [] + + @pytest.mark.asyncio + async def test_get_positions_after_buy(self): + """Test position created after buy.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + positions = await broker.get_positions() + assert len(positions) == 1 + assert positions[0].symbol == "AAPL" + assert positions[0].quantity == Decimal("10") + assert positions[0].side == PositionSide.LONG + + @pytest.mark.asyncio + async def test_get_position_single(self): + """Test getting single position.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + position = await broker.get_position("AAPL") + assert position is not None + assert position.symbol == "AAPL" + + @pytest.mark.asyncio + async def test_get_position_not_found(self): + """Test getting nonexistent position.""" + broker = PaperBroker() + await broker.connect() + + position = await broker.get_position("AAPL") + assert position is None + + @pytest.mark.asyncio + async def test_position_pnl_calculation(self): + """Test position P&L calculation.""" + broker = PaperBroker(slippage_percent=Decimal("0")) + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + # Buy at 100 + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + # Price goes up + broker.set_price("AAPL", Decimal("110")) + + position = await broker.get_position("AAPL") + assert position.unrealized_pnl == Decimal("100") # 10 shares * $10 gain + + @pytest.mark.asyncio + async def test_position_closed_on_sell(self): + """Test position closed when fully sold.""" + broker = PaperBroker(slippage_percent=Decimal("0")) + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + # Buy + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + # Sell all + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.SELL, Decimal("10")) + ) + + position = await broker.get_position("AAPL") + assert position is None + + @pytest.mark.asyncio + async def test_position_partial_sell(self): + """Test position reduced on partial sell.""" + broker = PaperBroker(slippage_percent=Decimal("0")) + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + # Buy + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + # Partial sell + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.SELL, Decimal("3")) + ) + + position = await broker.get_position("AAPL") + assert position.quantity == Decimal("7") + + +class TestPaperBrokerQuotes: + """Test PaperBroker quote methods.""" + + @pytest.mark.asyncio + async def test_get_quote(self): + """Test getting quote.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + quote = await broker.get_quote("AAPL") + assert quote.symbol == "AAPL" + assert quote.last_price == Decimal("100") + assert quote.bid_price is not None + assert quote.ask_price is not None + assert quote.bid_price < quote.ask_price + + @pytest.mark.asyncio + async def test_quote_spread(self): + """Test quote has bid/ask spread.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + quote = await broker.get_quote("AAPL") + spread = quote.ask_price - quote.bid_price + assert spread > Decimal("0") + + @pytest.mark.asyncio + async def test_quote_requires_connection(self): + """Test quote requires connection.""" + broker = PaperBroker() + with pytest.raises(ConnectionError): + await broker.get_quote("AAPL") + + +class TestPaperBrokerAssets: + """Test PaperBroker asset methods.""" + + @pytest.mark.asyncio + async def test_get_asset_equity(self): + """Test getting equity asset info.""" + broker = PaperBroker() + await broker.connect() + + asset = await broker.get_asset("AAPL") + assert asset.symbol == "AAPL" + assert asset.asset_class == AssetClass.EQUITY + assert asset.tradable is True + + @pytest.mark.asyncio + async def test_get_asset_crypto(self): + """Test getting crypto asset info.""" + broker = PaperBroker() + await broker.connect() + + asset = await broker.get_asset("BTCUSD") + assert asset.asset_class == AssetClass.CRYPTO + + @pytest.mark.asyncio + async def test_get_asset_etf(self): + """Test getting ETF asset info.""" + broker = PaperBroker() + await broker.connect() + + asset = await broker.get_asset("SPY") + assert asset.asset_class == AssetClass.ETF + + @pytest.mark.asyncio + async def test_get_asset_future(self): + """Test getting future asset info.""" + broker = PaperBroker() + await broker.connect() + + asset = await broker.get_asset("ES") + assert asset.asset_class == AssetClass.FUTURE + + +class TestPaperBrokerReset: + """Test PaperBroker reset functionality.""" + + @pytest.mark.asyncio + async def test_reset_clears_positions(self): + """Test reset clears all positions.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("10")) + await broker.connect() + + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + broker.reset() + + positions = await broker.get_positions() + assert positions == [] + + @pytest.mark.asyncio + async def test_reset_clears_orders(self): + """Test reset clears all orders.""" + broker = PaperBroker() + broker.set_price("AAPL", Decimal("10")) + await broker.connect() + + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + broker.reset() + + orders = await broker.get_orders() + assert orders == [] + + @pytest.mark.asyncio + async def test_reset_restores_cash(self): + """Test reset restores initial cash.""" + broker = PaperBroker(initial_cash=Decimal("100000")) + broker.set_price("AAPL", Decimal("1000")) + await broker.connect() + + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + assert broker.cash < Decimal("100000") + + broker.reset() + + assert broker.cash == Decimal("100000") + + +class TestPaperBrokerCashManagement: + """Test PaperBroker cash management.""" + + @pytest.mark.asyncio + async def test_buy_reduces_cash(self): + """Test buying reduces cash.""" + broker = PaperBroker( + initial_cash=Decimal("100000"), + slippage_percent=Decimal("0"), + ) + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + # 10 shares at $100 = $1000 + assert broker.cash == Decimal("99000") + + @pytest.mark.asyncio + async def test_sell_increases_cash(self): + """Test selling increases cash.""" + broker = PaperBroker( + initial_cash=Decimal("100000"), + slippage_percent=Decimal("0"), + ) + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + # Buy first + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + # Then sell at higher price + broker.set_price("AAPL", Decimal("110")) + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.SELL, Decimal("10")) + ) + + # Should have initial + profit + # Buy: -$1000 (100*10), Sell: +$1100 (110*10) = +$100 profit + assert broker.cash == Decimal("100100") + + +class TestPaperBrokerAveragePriceCalculation: + """Test average price calculation for positions.""" + + @pytest.mark.asyncio + async def test_average_price_multiple_buys(self): + """Test average price calculation with multiple buys.""" + broker = PaperBroker(slippage_percent=Decimal("0")) + await broker.connect() + + # Buy 10 at $100 + broker.set_price("AAPL", Decimal("100")) + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + # Buy 10 more at $120 + broker.set_price("AAPL", Decimal("120")) + await broker.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + ) + + position = await broker.get_position("AAPL") + # Average: (10*100 + 10*120) / 20 = $110 + assert position.avg_entry_price == Decimal("110") + assert position.quantity == Decimal("20") diff --git a/tradingagents/execution/__init__.py b/tradingagents/execution/__init__.py index 8b4fa40c..1027e755 100644 --- a/tradingagents/execution/__init__.py +++ b/tradingagents/execution/__init__.py @@ -112,6 +112,8 @@ from .ibkr_broker import ( FUTURES_SPECS, ) +from .paper_broker import PaperBroker + __all__ = [ # Enums "AssetClass", @@ -154,4 +156,6 @@ __all__ = [ "IBKRBroker", "IB_INSYNC_AVAILABLE", "FUTURES_SPECS", + # Paper Broker + "PaperBroker", ] diff --git a/tradingagents/execution/paper_broker.py b/tradingagents/execution/paper_broker.py new file mode 100644 index 00000000..a9a427e8 --- /dev/null +++ b/tradingagents/execution/paper_broker.py @@ -0,0 +1,687 @@ +"""Paper Broker implementation for simulation trading. + +Issue #26: [EXEC-25] Paper broker - simulation mode + +This module provides a simulated broker for paper trading, backtesting, +and testing without real API connections. It simulates order execution, +position tracking, and account management. + +Features: + - Simulated order execution with configurable fill behavior + - Position tracking with P&L calculations + - Account balance management + - Slippage simulation + - Market data simulation + - No external dependencies required + +Example: + >>> from tradingagents.execution import PaperBroker, OrderRequest, OrderSide + >>> + >>> broker = PaperBroker( + ... initial_cash=100000, + ... slippage_percent=0.05, + ... ) + >>> + >>> await broker.connect() + >>> order = await broker.submit_order( + ... OrderRequest.market("AAPL", OrderSide.BUY, 100) + ... ) + >>> print(f"Order filled at {order.avg_fill_price}") +""" + +from __future__ import annotations + +import asyncio +import random +import uuid +from datetime import datetime, timezone +from decimal import Decimal +from typing import Any, Dict, List, Optional, Callable + +from .broker_base import ( + AccountInfo, + AssetClass, + AssetInfo, + AuthenticationError, + BrokerBase, + BrokerError, + ConnectionError, + InsufficientFundsError, + InvalidOrderError, + Order, + OrderError, + OrderRequest, + OrderSide, + OrderStatus, + OrderType, + Position, + PositionError, + PositionSide, + Quote, + RateLimitError, + TimeInForce, +) + + +class PaperBroker(BrokerBase): + """Simulated paper trading broker. + + Provides a fully simulated trading environment for testing, + backtesting, and paper trading without any real API connections. + + Attributes: + initial_cash: Starting cash balance + slippage_percent: Slippage applied to fills (0.05 = 0.05%) + fill_probability: Probability of order fills (0.0-1.0) + market_open: Whether market is simulated as open + + Example: + >>> broker = PaperBroker(initial_cash=100000) + >>> await broker.connect() + >>> order = await broker.submit_order( + ... OrderRequest.market("AAPL", OrderSide.BUY, 10) + ... ) + >>> positions = await broker.get_positions() + """ + + def __init__( + self, + initial_cash: Decimal = Decimal("100000"), + slippage_percent: Decimal = Decimal("0.05"), + fill_probability: float = 1.0, + market_open: bool = True, + price_provider: Optional[Callable[[str], Decimal]] = None, + **kwargs: Any, + ) -> None: + """Initialize paper broker. + + Args: + initial_cash: Starting cash balance (default: 100,000) + slippage_percent: Slippage as percentage (default: 0.05%) + fill_probability: Probability of fills (default: 1.0 = 100%) + market_open: Whether to simulate market as open + price_provider: Optional function to get prices for symbols + **kwargs: Additional arguments passed to BrokerBase. + """ + super().__init__( + name="Paper", + supported_asset_classes=[ + AssetClass.EQUITY, + AssetClass.ETF, + AssetClass.CRYPTO, + AssetClass.FUTURE, + AssetClass.OPTION, + AssetClass.FOREX, + ], + paper_trading=True, + **kwargs, + ) + + self._initial_cash = Decimal(str(initial_cash)) + self._cash = self._initial_cash + self._slippage_percent = Decimal(str(slippage_percent)) + self._fill_probability = fill_probability + self._market_open = market_open + self._price_provider = price_provider + + # Internal state + self._orders: Dict[str, Order] = {} + self._positions: Dict[str, Position] = {} + self._order_counter = 0 + + # Simulated price cache + self._prices: Dict[str, Decimal] = {} + + # Default prices for common symbols + self._default_prices = { + "AAPL": Decimal("175.00"), + "MSFT": Decimal("380.00"), + "GOOGL": Decimal("140.00"), + "AMZN": Decimal("155.00"), + "NVDA": Decimal("480.00"), + "META": Decimal("360.00"), + "TSLA": Decimal("250.00"), + "SPY": Decimal("470.00"), + "QQQ": Decimal("400.00"), + "IWM": Decimal("200.00"), + "BTCUSD": Decimal("45000.00"), + "ETHUSD": Decimal("2500.00"), + "ES": Decimal("4700.00"), + "NQ": Decimal("16500.00"), + } + + @property + def cash(self) -> Decimal: + """Get current cash balance.""" + return self._cash + + @property + def initial_cash(self) -> Decimal: + """Get initial cash balance.""" + return self._initial_cash + + def set_price(self, symbol: str, price: Decimal) -> None: + """Set simulated price for a symbol. + + Args: + symbol: Symbol to set price for + price: Price to set + """ + self._prices[symbol] = Decimal(str(price)) + + def get_simulated_price(self, symbol: str) -> Decimal: + """Get simulated price for a symbol. + + Args: + symbol: Symbol to get price for + + Returns: + Simulated price + """ + # Check custom price provider first + if self._price_provider: + return self._price_provider(symbol) + + # Check cached prices + if symbol in self._prices: + return self._prices[symbol] + + # Check default prices + if symbol in self._default_prices: + return self._default_prices[symbol] + + # Generate random price for unknown symbols + return Decimal("100.00") + Decimal(str(random.uniform(-10, 10))) + + def _require_connection(self) -> None: + """Require broker to be connected.""" + if not self.is_connected: + raise ConnectionError("Not connected to Paper broker. Call connect() first.") + + async def connect(self) -> bool: + """Connect to paper broker (always succeeds). + + Returns: + True always + """ + self._connected = True + return True + + async def disconnect(self) -> None: + """Disconnect from paper broker.""" + self._connected = False + + async def is_market_open(self) -> bool: + """Check if simulated market is open. + + Returns: + The configured market_open value + """ + return self._market_open + + def set_market_open(self, is_open: bool) -> None: + """Set market open status. + + Args: + is_open: Whether market should be open + """ + self._market_open = is_open + + async def get_account(self) -> AccountInfo: + """Get simulated account information. + + Returns: + AccountInfo with current simulated account state. + """ + self._require_connection() + + # Calculate portfolio value + portfolio_value = self._cash + for position in self._positions.values(): + portfolio_value += position.market_value + + return AccountInfo( + account_id="PAPER-" + str(uuid.uuid4())[:8].upper(), + account_type="paper", + status="active", + cash=self._cash, + portfolio_value=portfolio_value, + buying_power=self._cash, # Simplified: no margin + equity=portfolio_value, + ) + + def _generate_order_id(self) -> str: + """Generate unique order ID.""" + self._order_counter += 1 + return f"PAPER-{self._order_counter}-{uuid.uuid4().hex[:8]}" + + def _calculate_fill_price( + self, + symbol: str, + side: OrderSide, + order_type: OrderType, + limit_price: Optional[Decimal] = None, + ) -> Optional[Decimal]: + """Calculate fill price with slippage. + + Args: + symbol: Symbol being traded + side: Order side + order_type: Order type + limit_price: Limit price if applicable + + Returns: + Fill price or None if order shouldn't fill + """ + base_price = self.get_simulated_price(symbol) + + if order_type == OrderType.LIMIT: + if limit_price is None: + return None + # For limit orders, check if price is favorable + if side == OrderSide.BUY: + if base_price > limit_price: + return None # Market price above limit + return limit_price + else: + if base_price < limit_price: + return None # Market price below limit + return limit_price + + # Apply slippage for market orders + slippage_factor = self._slippage_percent / Decimal("100") + if side == OrderSide.BUY: + # Slippage increases price for buys + fill_price = base_price * (Decimal("1") + slippage_factor) + else: + # Slippage decreases price for sells + fill_price = base_price * (Decimal("1") - slippage_factor) + + return fill_price.quantize(Decimal("0.01")) + + def _should_fill(self) -> bool: + """Determine if order should fill based on fill probability.""" + return random.random() < self._fill_probability + + def _update_position( + self, + symbol: str, + side: OrderSide, + quantity: Decimal, + fill_price: Decimal, + ) -> None: + """Update position after fill. + + Args: + symbol: Symbol traded + side: Order side + quantity: Quantity filled + fill_price: Fill price + """ + if symbol in self._positions: + position = self._positions[symbol] + + if side == OrderSide.BUY: + # Add to position + new_quantity = position.quantity + quantity + total_cost = (position.avg_entry_price * position.quantity) + (fill_price * quantity) + new_avg_price = total_cost / new_quantity if new_quantity > 0 else fill_price + position.quantity = new_quantity + position.avg_entry_price = new_avg_price + else: + # Reduce position + position.quantity -= quantity + if position.quantity <= 0: + del self._positions[symbol] + return + + # Update market value and P&L + current_price = self.get_simulated_price(symbol) + position.current_price = current_price + position.market_value = position.quantity * current_price + position.cost_basis = position.quantity * position.avg_entry_price + position.unrealized_pnl = position.market_value - position.cost_basis + if position.cost_basis > 0: + position.unrealized_pnl_percent = ( + position.unrealized_pnl / position.cost_basis * Decimal("100") + ) + else: + if side == OrderSide.BUY: + # Create new long position + current_price = self.get_simulated_price(symbol) + self._positions[symbol] = Position( + symbol=symbol, + quantity=quantity, + side=PositionSide.LONG, + avg_entry_price=fill_price, + current_price=current_price, + market_value=quantity * current_price, + cost_basis=quantity * fill_price, + unrealized_pnl=quantity * (current_price - fill_price), + unrealized_pnl_percent=( + (current_price - fill_price) / fill_price * Decimal("100") + if fill_price > 0 else Decimal("0") + ), + ) + # For sells without existing position, we'd need short selling logic + # For simplicity, ignore sells without positions + + async def submit_order(self, request: OrderRequest) -> Order: + """Submit a simulated order. + + Args: + request: Order request details. + + Returns: + Order with execution details. + + Raises: + InvalidOrderError: If order parameters are invalid. + InsufficientFundsError: If insufficient funds. + """ + self._require_connection() + + # Validate order + if request.quantity <= 0: + raise InvalidOrderError("Order quantity must be positive") + + # Generate order ID + order_id = self._generate_order_id() + + # Calculate fill price + fill_price = self._calculate_fill_price( + request.symbol, + request.side, + request.order_type, + request.limit_price, + ) + + # Determine if order should fill + should_fill = self._should_fill() and fill_price is not None + + if should_fill: + # Check funds for buys + if request.side == OrderSide.BUY: + required_funds = request.quantity * fill_price + if required_funds > self._cash: + raise InsufficientFundsError( + f"Insufficient funds: need ${required_funds}, have ${self._cash}" + ) + # Deduct cash + self._cash -= required_funds + + # For sells, add cash back + else: + proceeds = request.quantity * fill_price + self._cash += proceeds + + # Update position + self._update_position( + request.symbol, + request.side, + request.quantity, + fill_price, + ) + + status = OrderStatus.FILLED + filled_qty = request.quantity + avg_fill = fill_price + filled_at = datetime.now(timezone.utc) + else: + status = OrderStatus.NEW + filled_qty = Decimal("0") + avg_fill = None + filled_at = None + + # Create order + order = Order( + broker_order_id=order_id, + client_order_id=request.client_order_id or "", + symbol=request.symbol, + side=request.side, + quantity=request.quantity, + order_type=request.order_type, + status=status, + limit_price=request.limit_price, + stop_price=request.stop_price, + time_in_force=request.time_in_force, + filled_quantity=filled_qty, + filled_avg_price=avg_fill, + created_at=datetime.now(timezone.utc), + filled_at=filled_at, + ) + + self._orders[order_id] = order + return order + + async def cancel_order(self, order_id: str) -> Order: + """Cancel a simulated order. + + Args: + order_id: Order ID to cancel. + + Returns: + Cancelled order. + + Raises: + OrderError: If order not found or cannot be cancelled. + """ + self._require_connection() + + if order_id not in self._orders: + raise OrderError(f"Order {order_id} not found") + + order = self._orders[order_id] + + # Can only cancel unfilled orders + if order.status == OrderStatus.FILLED: + raise OrderError("Cannot cancel filled order") + + order.status = OrderStatus.CANCELLED + order.cancelled_at = datetime.now(timezone.utc) + + return order + + async def replace_order( + self, + order_id: str, + quantity: Optional[Decimal] = None, + limit_price: Optional[Decimal] = None, + stop_price: Optional[Decimal] = None, + time_in_force: Optional[TimeInForce] = None, + ) -> Order: + """Replace a simulated order. + + Creates a new order with updated parameters. + """ + self._require_connection() + + if order_id not in self._orders: + raise OrderError(f"Order {order_id} not found") + + old_order = self._orders[order_id] + + # Cancel old order + if old_order.status != OrderStatus.FILLED: + old_order.status = OrderStatus.REPLACED + + # Create new order request + request = OrderRequest( + symbol=old_order.symbol, + side=old_order.side, + quantity=quantity or old_order.quantity, + order_type=old_order.order_type, + time_in_force=time_in_force or old_order.time_in_force, + limit_price=limit_price or old_order.limit_price, + stop_price=stop_price or old_order.stop_price, + ) + + return await self.submit_order(request) + + async def get_order(self, order_id: str) -> Order: + """Get order by ID. + + Args: + order_id: Order ID. + + Returns: + Order details. + + Raises: + OrderError: If order not found. + """ + self._require_connection() + + if order_id not in self._orders: + raise OrderError(f"Order {order_id} not found") + + return self._orders[order_id] + + async def get_orders( + self, + status: Optional[OrderStatus] = None, + limit: int = 100, + symbols: Optional[List[str]] = None, + ) -> List[Order]: + """Get orders with optional filters. + + Args: + status: Filter by status. + limit: Maximum number to return. + symbols: Filter by symbols. + + Returns: + List of matching orders. + """ + self._require_connection() + + orders = list(self._orders.values()) + + # Apply filters + if status: + orders = [o for o in orders if o.status == status] + if symbols: + orders = [o for o in orders if o.symbol in symbols] + + # Sort by creation time, most recent first + orders.sort(key=lambda o: o.created_at or datetime.min, reverse=True) + + return orders[:limit] + + async def get_positions(self) -> List[Position]: + """Get all positions. + + Returns: + List of current positions. + """ + self._require_connection() + + # Update current prices + for symbol, position in self._positions.items(): + current_price = self.get_simulated_price(symbol) + position.current_price = current_price + position.market_value = position.quantity * current_price + position.unrealized_pnl = position.market_value - position.cost_basis + if position.cost_basis > 0: + position.unrealized_pnl_percent = ( + position.unrealized_pnl / position.cost_basis * Decimal("100") + ) + + return list(self._positions.values()) + + async def get_position(self, symbol: str) -> Optional[Position]: + """Get position for a specific symbol. + + Args: + symbol: Symbol to get position for. + + Returns: + Position if exists, None otherwise. + """ + self._require_connection() + + if symbol in self._positions: + position = self._positions[symbol] + # Update current price + current_price = self.get_simulated_price(symbol) + position.current_price = current_price + position.market_value = position.quantity * current_price + position.unrealized_pnl = position.market_value - position.cost_basis + return position + + return None + + async def get_quote(self, symbol: str) -> Quote: + """Get simulated quote. + + Args: + symbol: Symbol to get quote for. + + Returns: + Simulated quote data. + """ + self._require_connection() + + base_price = self.get_simulated_price(symbol) + + # Simulate bid/ask spread + spread = base_price * Decimal("0.001") # 0.1% spread + bid = base_price - spread / 2 + ask = base_price + spread / 2 + + return Quote( + symbol=symbol, + bid_price=bid.quantize(Decimal("0.01")), + ask_price=ask.quantize(Decimal("0.01")), + last_price=base_price, + bid_size=random.randint(100, 1000), + ask_size=random.randint(100, 1000), + volume=random.randint(100000, 10000000), + timestamp=datetime.now(timezone.utc), + ) + + async def get_asset(self, symbol: str) -> AssetInfo: + """Get simulated asset information. + + Args: + symbol: Symbol to get info for. + + Returns: + Simulated asset information. + """ + self._require_connection() + + # Determine asset class based on symbol patterns + if symbol.endswith("USD"): + asset_class = AssetClass.CRYPTO + elif symbol in ["ES", "NQ", "CL", "GC"]: + asset_class = AssetClass.FUTURE + elif symbol in ["SPY", "QQQ", "IWM", "VTI"]: + asset_class = AssetClass.ETF + else: + asset_class = AssetClass.EQUITY + + return AssetInfo( + symbol=symbol, + name=f"{symbol} (Paper)", + asset_class=asset_class, + exchange="PAPER", + tradable=True, + shortable=True, + marginable=True, + fractionable=False, + ) + + def reset(self) -> None: + """Reset broker to initial state. + + Clears all positions and orders, resets cash to initial amount. + """ + self._cash = self._initial_cash + self._orders.clear() + self._positions.clear() + self._order_counter = 0 + + +# Export +__all__ = ["PaperBroker"]