feat(execution): add Paper broker for simulation mode - Issue #26 (63 tests)
This commit is contained in:
parent
1e32c0e965
commit
834d18fb51
|
|
@ -0,0 +1,859 @@
|
|||
"""Tests for Paper Broker implementation.
|
||||
|
||||
Issue #26: [EXEC-25] Paper broker - simulation mode
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from tradingagents.execution import (
|
||||
PaperBroker,
|
||||
OrderRequest,
|
||||
OrderSide,
|
||||
OrderType,
|
||||
OrderStatus,
|
||||
TimeInForce,
|
||||
AssetClass,
|
||||
PositionSide,
|
||||
ConnectionError,
|
||||
OrderError,
|
||||
InvalidOrderError,
|
||||
InsufficientFundsError,
|
||||
)
|
||||
|
||||
|
||||
class TestPaperBrokerInit:
|
||||
"""Test PaperBroker initialization."""
|
||||
|
||||
def test_default_initialization(self):
|
||||
"""Test default broker initialization."""
|
||||
broker = PaperBroker()
|
||||
assert broker.name == "Paper"
|
||||
assert broker.initial_cash == Decimal("100000")
|
||||
assert broker.cash == Decimal("100000")
|
||||
assert broker.is_paper_trading is True
|
||||
assert broker.is_connected is False
|
||||
|
||||
def test_custom_initial_cash(self):
|
||||
"""Test initialization with custom initial cash."""
|
||||
broker = PaperBroker(initial_cash=Decimal("50000"))
|
||||
assert broker.initial_cash == Decimal("50000")
|
||||
assert broker.cash == Decimal("50000")
|
||||
|
||||
def test_custom_slippage(self):
|
||||
"""Test initialization with custom slippage."""
|
||||
broker = PaperBroker(slippage_percent=Decimal("0.1"))
|
||||
assert broker._slippage_percent == Decimal("0.1")
|
||||
|
||||
def test_custom_fill_probability(self):
|
||||
"""Test initialization with custom fill probability."""
|
||||
broker = PaperBroker(fill_probability=0.5)
|
||||
assert broker._fill_probability == 0.5
|
||||
|
||||
def test_market_closed_initialization(self):
|
||||
"""Test initialization with market closed."""
|
||||
broker = PaperBroker(market_open=False)
|
||||
assert broker._market_open is False
|
||||
|
||||
def test_supported_asset_classes(self):
|
||||
"""Test supported asset classes include all types."""
|
||||
broker = PaperBroker()
|
||||
assert AssetClass.EQUITY in broker.supported_asset_classes
|
||||
assert AssetClass.ETF in broker.supported_asset_classes
|
||||
assert AssetClass.CRYPTO in broker.supported_asset_classes
|
||||
assert AssetClass.FUTURE in broker.supported_asset_classes
|
||||
assert AssetClass.OPTION in broker.supported_asset_classes
|
||||
assert AssetClass.FOREX in broker.supported_asset_classes
|
||||
|
||||
def test_custom_price_provider(self):
|
||||
"""Test initialization with custom price provider."""
|
||||
def price_provider(symbol: str) -> Decimal:
|
||||
return Decimal("123.45")
|
||||
|
||||
broker = PaperBroker(price_provider=price_provider)
|
||||
assert broker.get_simulated_price("ANY") == Decimal("123.45")
|
||||
|
||||
|
||||
class TestPaperBrokerConnection:
|
||||
"""Test PaperBroker connection methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_succeeds(self):
|
||||
"""Test connect always succeeds."""
|
||||
broker = PaperBroker()
|
||||
result = await broker.connect()
|
||||
assert result is True
|
||||
assert broker.is_connected is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect(self):
|
||||
"""Test disconnect."""
|
||||
broker = PaperBroker()
|
||||
await broker.connect()
|
||||
await broker.disconnect()
|
||||
assert broker.is_connected is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_connects(self):
|
||||
"""Test multiple connects work."""
|
||||
broker = PaperBroker()
|
||||
await broker.connect()
|
||||
await broker.connect()
|
||||
assert broker.is_connected is True
|
||||
|
||||
|
||||
class TestPaperBrokerMarketStatus:
|
||||
"""Test PaperBroker market status."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_open_by_default(self):
|
||||
"""Test market is open by default."""
|
||||
broker = PaperBroker()
|
||||
await broker.connect()
|
||||
assert await broker.is_market_open() is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_closed(self):
|
||||
"""Test market closed simulation."""
|
||||
broker = PaperBroker(market_open=False)
|
||||
await broker.connect()
|
||||
assert await broker.is_market_open() is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_market_open(self):
|
||||
"""Test changing market open status."""
|
||||
broker = PaperBroker(market_open=False)
|
||||
broker.set_market_open(True)
|
||||
await broker.connect()
|
||||
assert await broker.is_market_open() is True
|
||||
|
||||
|
||||
class TestPaperBrokerPrices:
|
||||
"""Test PaperBroker price simulation."""
|
||||
|
||||
def test_set_and_get_price(self):
|
||||
"""Test setting and getting prices."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("TEST", Decimal("99.99"))
|
||||
assert broker.get_simulated_price("TEST") == Decimal("99.99")
|
||||
|
||||
def test_default_prices(self):
|
||||
"""Test default prices for common symbols."""
|
||||
broker = PaperBroker()
|
||||
assert broker.get_simulated_price("AAPL") == Decimal("175.00")
|
||||
assert broker.get_simulated_price("MSFT") == Decimal("380.00")
|
||||
assert broker.get_simulated_price("SPY") == Decimal("470.00")
|
||||
|
||||
def test_crypto_default_prices(self):
|
||||
"""Test default crypto prices."""
|
||||
broker = PaperBroker()
|
||||
assert broker.get_simulated_price("BTCUSD") == Decimal("45000.00")
|
||||
assert broker.get_simulated_price("ETHUSD") == Decimal("2500.00")
|
||||
|
||||
def test_futures_default_prices(self):
|
||||
"""Test default futures prices."""
|
||||
broker = PaperBroker()
|
||||
assert broker.get_simulated_price("ES") == Decimal("4700.00")
|
||||
assert broker.get_simulated_price("NQ") == Decimal("16500.00")
|
||||
|
||||
def test_unknown_symbol_generates_price(self):
|
||||
"""Test unknown symbols generate random prices."""
|
||||
broker = PaperBroker()
|
||||
price = broker.get_simulated_price("UNKNOWN")
|
||||
# Should be around 100 +/- 10
|
||||
assert Decimal("80") < price < Decimal("120")
|
||||
|
||||
|
||||
class TestPaperBrokerAccount:
|
||||
"""Test PaperBroker account methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_account_requires_connection(self):
|
||||
"""Test get_account requires connection."""
|
||||
broker = PaperBroker()
|
||||
with pytest.raises(ConnectionError, match="Not connected"):
|
||||
await broker.get_account()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_account_basic(self):
|
||||
"""Test basic account information."""
|
||||
broker = PaperBroker(initial_cash=Decimal("50000"))
|
||||
await broker.connect()
|
||||
|
||||
account = await broker.get_account()
|
||||
assert account.account_type == "paper"
|
||||
assert account.status == "active"
|
||||
assert account.cash == Decimal("50000")
|
||||
assert account.portfolio_value == Decimal("50000")
|
||||
assert account.buying_power == Decimal("50000")
|
||||
assert account.account_id.startswith("PAPER-")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_account_with_positions(self):
|
||||
"""Test account includes position values."""
|
||||
broker = PaperBroker(initial_cash=Decimal("100000"))
|
||||
broker.set_price("AAPL", Decimal("150"))
|
||||
await broker.connect()
|
||||
|
||||
# Buy some shares
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
account = await broker.get_account()
|
||||
# Cash reduced by purchase
|
||||
assert account.cash < Decimal("100000")
|
||||
# Portfolio includes position
|
||||
assert account.portfolio_value > account.cash
|
||||
|
||||
|
||||
class TestPaperBrokerOrders:
|
||||
"""Test PaperBroker order methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_market_buy_order(self):
|
||||
"""Test submitting market buy order."""
|
||||
broker = PaperBroker(initial_cash=Decimal("100000"))
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
assert order.symbol == "AAPL"
|
||||
assert order.side == OrderSide.BUY
|
||||
assert order.quantity == Decimal("10")
|
||||
assert order.status == OrderStatus.FILLED
|
||||
assert order.filled_quantity == Decimal("10")
|
||||
assert order.broker_order_id.startswith("PAPER-")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_market_sell_order(self):
|
||||
"""Test submitting market sell order."""
|
||||
broker = PaperBroker(initial_cash=Decimal("100000"))
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
# Buy first
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
# Then sell
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.SELL, Decimal("5"))
|
||||
)
|
||||
|
||||
assert order.status == OrderStatus.FILLED
|
||||
assert order.filled_quantity == Decimal("5")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_limit_order_fills(self):
|
||||
"""Test limit order that should fill."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
# Limit above market price - should fill
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.limit("AAPL", OrderSide.BUY, Decimal("10"), Decimal("110"))
|
||||
)
|
||||
|
||||
assert order.status == OrderStatus.FILLED
|
||||
assert order.filled_avg_price == Decimal("110")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_limit_order_no_fill(self):
|
||||
"""Test limit order that shouldn't fill."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
# Limit below market price - shouldn't fill
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.limit("AAPL", OrderSide.BUY, Decimal("10"), Decimal("90"))
|
||||
)
|
||||
|
||||
assert order.status == OrderStatus.NEW
|
||||
assert order.filled_quantity == Decimal("0")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_order_requires_connection(self):
|
||||
"""Test order submission requires connection."""
|
||||
broker = PaperBroker()
|
||||
with pytest.raises(ConnectionError, match="Not connected"):
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_order_quantity(self):
|
||||
"""Test invalid order quantity."""
|
||||
broker = PaperBroker()
|
||||
await broker.connect()
|
||||
|
||||
with pytest.raises(InvalidOrderError, match="quantity must be positive"):
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("-10"))
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insufficient_funds(self):
|
||||
"""Test insufficient funds error."""
|
||||
broker = PaperBroker(initial_cash=Decimal("100"))
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
with pytest.raises(InsufficientFundsError, match="Insufficient funds"):
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slippage_on_buy(self):
|
||||
"""Test slippage applied to buy orders."""
|
||||
broker = PaperBroker(slippage_percent=Decimal("1.0"))
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1"))
|
||||
)
|
||||
|
||||
# 1% slippage on $100 = $101
|
||||
assert order.filled_avg_price == Decimal("101.00")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slippage_on_sell(self):
|
||||
"""Test slippage applied to sell orders."""
|
||||
broker = PaperBroker(slippage_percent=Decimal("1.0"))
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
# Buy first
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
# Sell with slippage
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.SELL, Decimal("5"))
|
||||
)
|
||||
|
||||
# 1% slippage on $100 = $99
|
||||
assert order.filled_avg_price == Decimal("99.00")
|
||||
|
||||
|
||||
class TestPaperBrokerFillProbability:
|
||||
"""Test PaperBroker fill probability."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_fill_probability(self):
|
||||
"""Test orders don't fill with 0% probability."""
|
||||
broker = PaperBroker(fill_probability=0.0)
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
assert order.status == OrderStatus.NEW
|
||||
assert order.filled_quantity == Decimal("0")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_fill_probability(self):
|
||||
"""Test orders always fill with 100% probability."""
|
||||
broker = PaperBroker(fill_probability=1.0)
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
assert order.status == OrderStatus.FILLED
|
||||
|
||||
|
||||
class TestPaperBrokerCancelOrder:
|
||||
"""Test PaperBroker cancel order."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_unfilled_order(self):
|
||||
"""Test cancelling unfilled order."""
|
||||
broker = PaperBroker(fill_probability=0.0)
|
||||
await broker.connect()
|
||||
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
cancelled = await broker.cancel_order(order.broker_order_id)
|
||||
assert cancelled.status == OrderStatus.CANCELLED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_filled_order_fails(self):
|
||||
"""Test cannot cancel filled order."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("10"))
|
||||
await broker.connect()
|
||||
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
with pytest.raises(OrderError, match="Cannot cancel filled order"):
|
||||
await broker.cancel_order(order.broker_order_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_nonexistent_order(self):
|
||||
"""Test cancelling nonexistent order."""
|
||||
broker = PaperBroker()
|
||||
await broker.connect()
|
||||
|
||||
with pytest.raises(OrderError, match="not found"):
|
||||
await broker.cancel_order("INVALID-123")
|
||||
|
||||
|
||||
class TestPaperBrokerReplaceOrder:
|
||||
"""Test PaperBroker replace order."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replace_order(self):
|
||||
"""Test replacing an order."""
|
||||
broker = PaperBroker(fill_probability=0.0)
|
||||
await broker.connect()
|
||||
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.limit("AAPL", OrderSide.BUY, Decimal("10"), Decimal("100"))
|
||||
)
|
||||
|
||||
# Replace with new quantity
|
||||
new_order = await broker.replace_order(
|
||||
order.broker_order_id,
|
||||
quantity=Decimal("20"),
|
||||
)
|
||||
|
||||
assert new_order.quantity == Decimal("20")
|
||||
assert new_order.broker_order_id != order.broker_order_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replace_order_marks_old_replaced(self):
|
||||
"""Test old order marked as replaced."""
|
||||
broker = PaperBroker(fill_probability=0.0)
|
||||
await broker.connect()
|
||||
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.limit("AAPL", OrderSide.BUY, Decimal("10"), Decimal("100"))
|
||||
)
|
||||
old_id = order.broker_order_id
|
||||
|
||||
await broker.replace_order(old_id, quantity=Decimal("20"))
|
||||
|
||||
old_order = await broker.get_order(old_id)
|
||||
assert old_order.status == OrderStatus.REPLACED
|
||||
|
||||
|
||||
class TestPaperBrokerGetOrders:
|
||||
"""Test PaperBroker get orders."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_order(self):
|
||||
"""Test getting single order."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("10"))
|
||||
await broker.connect()
|
||||
|
||||
order = await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
retrieved = await broker.get_order(order.broker_order_id)
|
||||
assert retrieved.broker_order_id == order.broker_order_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_order_not_found(self):
|
||||
"""Test getting nonexistent order."""
|
||||
broker = PaperBroker()
|
||||
await broker.connect()
|
||||
|
||||
with pytest.raises(OrderError, match="not found"):
|
||||
await broker.get_order("INVALID-123")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_orders_all(self):
|
||||
"""Test getting all orders."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("10"))
|
||||
await broker.connect()
|
||||
|
||||
await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")))
|
||||
await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, Decimal("20")))
|
||||
|
||||
orders = await broker.get_orders()
|
||||
assert len(orders) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_orders_filter_by_status(self):
|
||||
"""Test filtering orders by status."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("10"))
|
||||
await broker.connect()
|
||||
|
||||
# Create filled order
|
||||
await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")))
|
||||
|
||||
# Create unfilled order
|
||||
broker._fill_probability = 0.0
|
||||
await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")))
|
||||
|
||||
filled_orders = await broker.get_orders(status=OrderStatus.FILLED)
|
||||
assert len(filled_orders) == 1
|
||||
assert filled_orders[0].status == OrderStatus.FILLED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_orders_filter_by_symbols(self):
|
||||
"""Test filtering orders by symbols."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("10"))
|
||||
broker.set_price("MSFT", Decimal("10"))
|
||||
await broker.connect()
|
||||
|
||||
await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")))
|
||||
await broker.submit_order(OrderRequest.market("MSFT", OrderSide.BUY, Decimal("10")))
|
||||
|
||||
aapl_orders = await broker.get_orders(symbols=["AAPL"])
|
||||
assert len(aapl_orders) == 1
|
||||
assert aapl_orders[0].symbol == "AAPL"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_orders_with_limit(self):
|
||||
"""Test getting limited number of orders."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("10"))
|
||||
await broker.connect()
|
||||
|
||||
for _ in range(5):
|
||||
await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1")))
|
||||
|
||||
orders = await broker.get_orders(limit=3)
|
||||
assert len(orders) == 3
|
||||
|
||||
|
||||
class TestPaperBrokerPositions:
|
||||
"""Test PaperBroker position methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_positions_empty(self):
|
||||
"""Test getting positions when empty."""
|
||||
broker = PaperBroker()
|
||||
await broker.connect()
|
||||
|
||||
positions = await broker.get_positions()
|
||||
assert positions == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_positions_after_buy(self):
|
||||
"""Test position created after buy."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
positions = await broker.get_positions()
|
||||
assert len(positions) == 1
|
||||
assert positions[0].symbol == "AAPL"
|
||||
assert positions[0].quantity == Decimal("10")
|
||||
assert positions[0].side == PositionSide.LONG
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_position_single(self):
|
||||
"""Test getting single position."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
position = await broker.get_position("AAPL")
|
||||
assert position is not None
|
||||
assert position.symbol == "AAPL"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_position_not_found(self):
|
||||
"""Test getting nonexistent position."""
|
||||
broker = PaperBroker()
|
||||
await broker.connect()
|
||||
|
||||
position = await broker.get_position("AAPL")
|
||||
assert position is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_position_pnl_calculation(self):
|
||||
"""Test position P&L calculation."""
|
||||
broker = PaperBroker(slippage_percent=Decimal("0"))
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
# Buy at 100
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
# Price goes up
|
||||
broker.set_price("AAPL", Decimal("110"))
|
||||
|
||||
position = await broker.get_position("AAPL")
|
||||
assert position.unrealized_pnl == Decimal("100") # 10 shares * $10 gain
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_position_closed_on_sell(self):
|
||||
"""Test position closed when fully sold."""
|
||||
broker = PaperBroker(slippage_percent=Decimal("0"))
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
# Buy
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
# Sell all
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.SELL, Decimal("10"))
|
||||
)
|
||||
|
||||
position = await broker.get_position("AAPL")
|
||||
assert position is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_position_partial_sell(self):
|
||||
"""Test position reduced on partial sell."""
|
||||
broker = PaperBroker(slippage_percent=Decimal("0"))
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
# Buy
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
# Partial sell
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.SELL, Decimal("3"))
|
||||
)
|
||||
|
||||
position = await broker.get_position("AAPL")
|
||||
assert position.quantity == Decimal("7")
|
||||
|
||||
|
||||
class TestPaperBrokerQuotes:
|
||||
"""Test PaperBroker quote methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_quote(self):
|
||||
"""Test getting quote."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
quote = await broker.get_quote("AAPL")
|
||||
assert quote.symbol == "AAPL"
|
||||
assert quote.last_price == Decimal("100")
|
||||
assert quote.bid_price is not None
|
||||
assert quote.ask_price is not None
|
||||
assert quote.bid_price < quote.ask_price
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quote_spread(self):
|
||||
"""Test quote has bid/ask spread."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
quote = await broker.get_quote("AAPL")
|
||||
spread = quote.ask_price - quote.bid_price
|
||||
assert spread > Decimal("0")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quote_requires_connection(self):
|
||||
"""Test quote requires connection."""
|
||||
broker = PaperBroker()
|
||||
with pytest.raises(ConnectionError):
|
||||
await broker.get_quote("AAPL")
|
||||
|
||||
|
||||
class TestPaperBrokerAssets:
|
||||
"""Test PaperBroker asset methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_asset_equity(self):
|
||||
"""Test getting equity asset info."""
|
||||
broker = PaperBroker()
|
||||
await broker.connect()
|
||||
|
||||
asset = await broker.get_asset("AAPL")
|
||||
assert asset.symbol == "AAPL"
|
||||
assert asset.asset_class == AssetClass.EQUITY
|
||||
assert asset.tradable is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_asset_crypto(self):
|
||||
"""Test getting crypto asset info."""
|
||||
broker = PaperBroker()
|
||||
await broker.connect()
|
||||
|
||||
asset = await broker.get_asset("BTCUSD")
|
||||
assert asset.asset_class == AssetClass.CRYPTO
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_asset_etf(self):
|
||||
"""Test getting ETF asset info."""
|
||||
broker = PaperBroker()
|
||||
await broker.connect()
|
||||
|
||||
asset = await broker.get_asset("SPY")
|
||||
assert asset.asset_class == AssetClass.ETF
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_asset_future(self):
|
||||
"""Test getting future asset info."""
|
||||
broker = PaperBroker()
|
||||
await broker.connect()
|
||||
|
||||
asset = await broker.get_asset("ES")
|
||||
assert asset.asset_class == AssetClass.FUTURE
|
||||
|
||||
|
||||
class TestPaperBrokerReset:
|
||||
"""Test PaperBroker reset functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_clears_positions(self):
|
||||
"""Test reset clears all positions."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("10"))
|
||||
await broker.connect()
|
||||
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
broker.reset()
|
||||
|
||||
positions = await broker.get_positions()
|
||||
assert positions == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_clears_orders(self):
|
||||
"""Test reset clears all orders."""
|
||||
broker = PaperBroker()
|
||||
broker.set_price("AAPL", Decimal("10"))
|
||||
await broker.connect()
|
||||
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
broker.reset()
|
||||
|
||||
orders = await broker.get_orders()
|
||||
assert orders == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_restores_cash(self):
|
||||
"""Test reset restores initial cash."""
|
||||
broker = PaperBroker(initial_cash=Decimal("100000"))
|
||||
broker.set_price("AAPL", Decimal("1000"))
|
||||
await broker.connect()
|
||||
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
assert broker.cash < Decimal("100000")
|
||||
|
||||
broker.reset()
|
||||
|
||||
assert broker.cash == Decimal("100000")
|
||||
|
||||
|
||||
class TestPaperBrokerCashManagement:
|
||||
"""Test PaperBroker cash management."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buy_reduces_cash(self):
|
||||
"""Test buying reduces cash."""
|
||||
broker = PaperBroker(
|
||||
initial_cash=Decimal("100000"),
|
||||
slippage_percent=Decimal("0"),
|
||||
)
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
# 10 shares at $100 = $1000
|
||||
assert broker.cash == Decimal("99000")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sell_increases_cash(self):
|
||||
"""Test selling increases cash."""
|
||||
broker = PaperBroker(
|
||||
initial_cash=Decimal("100000"),
|
||||
slippage_percent=Decimal("0"),
|
||||
)
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.connect()
|
||||
|
||||
# Buy first
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
# Then sell at higher price
|
||||
broker.set_price("AAPL", Decimal("110"))
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.SELL, Decimal("10"))
|
||||
)
|
||||
|
||||
# Should have initial + profit
|
||||
# Buy: -$1000 (100*10), Sell: +$1100 (110*10) = +$100 profit
|
||||
assert broker.cash == Decimal("100100")
|
||||
|
||||
|
||||
class TestPaperBrokerAveragePriceCalculation:
|
||||
"""Test average price calculation for positions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_average_price_multiple_buys(self):
|
||||
"""Test average price calculation with multiple buys."""
|
||||
broker = PaperBroker(slippage_percent=Decimal("0"))
|
||||
await broker.connect()
|
||||
|
||||
# Buy 10 at $100
|
||||
broker.set_price("AAPL", Decimal("100"))
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
# Buy 10 more at $120
|
||||
broker.set_price("AAPL", Decimal("120"))
|
||||
await broker.submit_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
)
|
||||
|
||||
position = await broker.get_position("AAPL")
|
||||
# Average: (10*100 + 10*120) / 20 = $110
|
||||
assert position.avg_entry_price == Decimal("110")
|
||||
assert position.quantity == Decimal("20")
|
||||
|
|
@ -112,6 +112,8 @@ from .ibkr_broker import (
|
|||
FUTURES_SPECS,
|
||||
)
|
||||
|
||||
from .paper_broker import PaperBroker
|
||||
|
||||
__all__ = [
|
||||
# Enums
|
||||
"AssetClass",
|
||||
|
|
@ -154,4 +156,6 @@ __all__ = [
|
|||
"IBKRBroker",
|
||||
"IB_INSYNC_AVAILABLE",
|
||||
"FUTURES_SPECS",
|
||||
# Paper Broker
|
||||
"PaperBroker",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,687 @@
|
|||
"""Paper Broker implementation for simulation trading.
|
||||
|
||||
Issue #26: [EXEC-25] Paper broker - simulation mode
|
||||
|
||||
This module provides a simulated broker for paper trading, backtesting,
|
||||
and testing without real API connections. It simulates order execution,
|
||||
position tracking, and account management.
|
||||
|
||||
Features:
|
||||
- Simulated order execution with configurable fill behavior
|
||||
- Position tracking with P&L calculations
|
||||
- Account balance management
|
||||
- Slippage simulation
|
||||
- Market data simulation
|
||||
- No external dependencies required
|
||||
|
||||
Example:
|
||||
>>> from tradingagents.execution import PaperBroker, OrderRequest, OrderSide
|
||||
>>>
|
||||
>>> broker = PaperBroker(
|
||||
... initial_cash=100000,
|
||||
... slippage_percent=0.05,
|
||||
... )
|
||||
>>>
|
||||
>>> await broker.connect()
|
||||
>>> order = await broker.submit_order(
|
||||
... OrderRequest.market("AAPL", OrderSide.BUY, 100)
|
||||
... )
|
||||
>>> print(f"Order filled at {order.avg_fill_price}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
|
||||
from .broker_base import (
|
||||
AccountInfo,
|
||||
AssetClass,
|
||||
AssetInfo,
|
||||
AuthenticationError,
|
||||
BrokerBase,
|
||||
BrokerError,
|
||||
ConnectionError,
|
||||
InsufficientFundsError,
|
||||
InvalidOrderError,
|
||||
Order,
|
||||
OrderError,
|
||||
OrderRequest,
|
||||
OrderSide,
|
||||
OrderStatus,
|
||||
OrderType,
|
||||
Position,
|
||||
PositionError,
|
||||
PositionSide,
|
||||
Quote,
|
||||
RateLimitError,
|
||||
TimeInForce,
|
||||
)
|
||||
|
||||
|
||||
class PaperBroker(BrokerBase):
|
||||
"""Simulated paper trading broker.
|
||||
|
||||
Provides a fully simulated trading environment for testing,
|
||||
backtesting, and paper trading without any real API connections.
|
||||
|
||||
Attributes:
|
||||
initial_cash: Starting cash balance
|
||||
slippage_percent: Slippage applied to fills (0.05 = 0.05%)
|
||||
fill_probability: Probability of order fills (0.0-1.0)
|
||||
market_open: Whether market is simulated as open
|
||||
|
||||
Example:
|
||||
>>> broker = PaperBroker(initial_cash=100000)
|
||||
>>> await broker.connect()
|
||||
>>> order = await broker.submit_order(
|
||||
... OrderRequest.market("AAPL", OrderSide.BUY, 10)
|
||||
... )
|
||||
>>> positions = await broker.get_positions()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_cash: Decimal = Decimal("100000"),
|
||||
slippage_percent: Decimal = Decimal("0.05"),
|
||||
fill_probability: float = 1.0,
|
||||
market_open: bool = True,
|
||||
price_provider: Optional[Callable[[str], Decimal]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize paper broker.
|
||||
|
||||
Args:
|
||||
initial_cash: Starting cash balance (default: 100,000)
|
||||
slippage_percent: Slippage as percentage (default: 0.05%)
|
||||
fill_probability: Probability of fills (default: 1.0 = 100%)
|
||||
market_open: Whether to simulate market as open
|
||||
price_provider: Optional function to get prices for symbols
|
||||
**kwargs: Additional arguments passed to BrokerBase.
|
||||
"""
|
||||
super().__init__(
|
||||
name="Paper",
|
||||
supported_asset_classes=[
|
||||
AssetClass.EQUITY,
|
||||
AssetClass.ETF,
|
||||
AssetClass.CRYPTO,
|
||||
AssetClass.FUTURE,
|
||||
AssetClass.OPTION,
|
||||
AssetClass.FOREX,
|
||||
],
|
||||
paper_trading=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._initial_cash = Decimal(str(initial_cash))
|
||||
self._cash = self._initial_cash
|
||||
self._slippage_percent = Decimal(str(slippage_percent))
|
||||
self._fill_probability = fill_probability
|
||||
self._market_open = market_open
|
||||
self._price_provider = price_provider
|
||||
|
||||
# Internal state
|
||||
self._orders: Dict[str, Order] = {}
|
||||
self._positions: Dict[str, Position] = {}
|
||||
self._order_counter = 0
|
||||
|
||||
# Simulated price cache
|
||||
self._prices: Dict[str, Decimal] = {}
|
||||
|
||||
# Default prices for common symbols
|
||||
self._default_prices = {
|
||||
"AAPL": Decimal("175.00"),
|
||||
"MSFT": Decimal("380.00"),
|
||||
"GOOGL": Decimal("140.00"),
|
||||
"AMZN": Decimal("155.00"),
|
||||
"NVDA": Decimal("480.00"),
|
||||
"META": Decimal("360.00"),
|
||||
"TSLA": Decimal("250.00"),
|
||||
"SPY": Decimal("470.00"),
|
||||
"QQQ": Decimal("400.00"),
|
||||
"IWM": Decimal("200.00"),
|
||||
"BTCUSD": Decimal("45000.00"),
|
||||
"ETHUSD": Decimal("2500.00"),
|
||||
"ES": Decimal("4700.00"),
|
||||
"NQ": Decimal("16500.00"),
|
||||
}
|
||||
|
||||
@property
|
||||
def cash(self) -> Decimal:
|
||||
"""Get current cash balance."""
|
||||
return self._cash
|
||||
|
||||
@property
|
||||
def initial_cash(self) -> Decimal:
|
||||
"""Get initial cash balance."""
|
||||
return self._initial_cash
|
||||
|
||||
def set_price(self, symbol: str, price: Decimal) -> None:
|
||||
"""Set simulated price for a symbol.
|
||||
|
||||
Args:
|
||||
symbol: Symbol to set price for
|
||||
price: Price to set
|
||||
"""
|
||||
self._prices[symbol] = Decimal(str(price))
|
||||
|
||||
def get_simulated_price(self, symbol: str) -> Decimal:
|
||||
"""Get simulated price for a symbol.
|
||||
|
||||
Args:
|
||||
symbol: Symbol to get price for
|
||||
|
||||
Returns:
|
||||
Simulated price
|
||||
"""
|
||||
# Check custom price provider first
|
||||
if self._price_provider:
|
||||
return self._price_provider(symbol)
|
||||
|
||||
# Check cached prices
|
||||
if symbol in self._prices:
|
||||
return self._prices[symbol]
|
||||
|
||||
# Check default prices
|
||||
if symbol in self._default_prices:
|
||||
return self._default_prices[symbol]
|
||||
|
||||
# Generate random price for unknown symbols
|
||||
return Decimal("100.00") + Decimal(str(random.uniform(-10, 10)))
|
||||
|
||||
def _require_connection(self) -> None:
|
||||
"""Require broker to be connected."""
|
||||
if not self.is_connected:
|
||||
raise ConnectionError("Not connected to Paper broker. Call connect() first.")
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to paper broker (always succeeds).
|
||||
|
||||
Returns:
|
||||
True always
|
||||
"""
|
||||
self._connected = True
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from paper broker."""
|
||||
self._connected = False
|
||||
|
||||
async def is_market_open(self) -> bool:
|
||||
"""Check if simulated market is open.
|
||||
|
||||
Returns:
|
||||
The configured market_open value
|
||||
"""
|
||||
return self._market_open
|
||||
|
||||
def set_market_open(self, is_open: bool) -> None:
|
||||
"""Set market open status.
|
||||
|
||||
Args:
|
||||
is_open: Whether market should be open
|
||||
"""
|
||||
self._market_open = is_open
|
||||
|
||||
async def get_account(self) -> AccountInfo:
|
||||
"""Get simulated account information.
|
||||
|
||||
Returns:
|
||||
AccountInfo with current simulated account state.
|
||||
"""
|
||||
self._require_connection()
|
||||
|
||||
# Calculate portfolio value
|
||||
portfolio_value = self._cash
|
||||
for position in self._positions.values():
|
||||
portfolio_value += position.market_value
|
||||
|
||||
return AccountInfo(
|
||||
account_id="PAPER-" + str(uuid.uuid4())[:8].upper(),
|
||||
account_type="paper",
|
||||
status="active",
|
||||
cash=self._cash,
|
||||
portfolio_value=portfolio_value,
|
||||
buying_power=self._cash, # Simplified: no margin
|
||||
equity=portfolio_value,
|
||||
)
|
||||
|
||||
def _generate_order_id(self) -> str:
|
||||
"""Generate unique order ID."""
|
||||
self._order_counter += 1
|
||||
return f"PAPER-{self._order_counter}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
def _calculate_fill_price(
|
||||
self,
|
||||
symbol: str,
|
||||
side: OrderSide,
|
||||
order_type: OrderType,
|
||||
limit_price: Optional[Decimal] = None,
|
||||
) -> Optional[Decimal]:
|
||||
"""Calculate fill price with slippage.
|
||||
|
||||
Args:
|
||||
symbol: Symbol being traded
|
||||
side: Order side
|
||||
order_type: Order type
|
||||
limit_price: Limit price if applicable
|
||||
|
||||
Returns:
|
||||
Fill price or None if order shouldn't fill
|
||||
"""
|
||||
base_price = self.get_simulated_price(symbol)
|
||||
|
||||
if order_type == OrderType.LIMIT:
|
||||
if limit_price is None:
|
||||
return None
|
||||
# For limit orders, check if price is favorable
|
||||
if side == OrderSide.BUY:
|
||||
if base_price > limit_price:
|
||||
return None # Market price above limit
|
||||
return limit_price
|
||||
else:
|
||||
if base_price < limit_price:
|
||||
return None # Market price below limit
|
||||
return limit_price
|
||||
|
||||
# Apply slippage for market orders
|
||||
slippage_factor = self._slippage_percent / Decimal("100")
|
||||
if side == OrderSide.BUY:
|
||||
# Slippage increases price for buys
|
||||
fill_price = base_price * (Decimal("1") + slippage_factor)
|
||||
else:
|
||||
# Slippage decreases price for sells
|
||||
fill_price = base_price * (Decimal("1") - slippage_factor)
|
||||
|
||||
return fill_price.quantize(Decimal("0.01"))
|
||||
|
||||
def _should_fill(self) -> bool:
|
||||
"""Determine if order should fill based on fill probability."""
|
||||
return random.random() < self._fill_probability
|
||||
|
||||
def _update_position(
|
||||
self,
|
||||
symbol: str,
|
||||
side: OrderSide,
|
||||
quantity: Decimal,
|
||||
fill_price: Decimal,
|
||||
) -> None:
|
||||
"""Update position after fill.
|
||||
|
||||
Args:
|
||||
symbol: Symbol traded
|
||||
side: Order side
|
||||
quantity: Quantity filled
|
||||
fill_price: Fill price
|
||||
"""
|
||||
if symbol in self._positions:
|
||||
position = self._positions[symbol]
|
||||
|
||||
if side == OrderSide.BUY:
|
||||
# Add to position
|
||||
new_quantity = position.quantity + quantity
|
||||
total_cost = (position.avg_entry_price * position.quantity) + (fill_price * quantity)
|
||||
new_avg_price = total_cost / new_quantity if new_quantity > 0 else fill_price
|
||||
position.quantity = new_quantity
|
||||
position.avg_entry_price = new_avg_price
|
||||
else:
|
||||
# Reduce position
|
||||
position.quantity -= quantity
|
||||
if position.quantity <= 0:
|
||||
del self._positions[symbol]
|
||||
return
|
||||
|
||||
# Update market value and P&L
|
||||
current_price = self.get_simulated_price(symbol)
|
||||
position.current_price = current_price
|
||||
position.market_value = position.quantity * current_price
|
||||
position.cost_basis = position.quantity * position.avg_entry_price
|
||||
position.unrealized_pnl = position.market_value - position.cost_basis
|
||||
if position.cost_basis > 0:
|
||||
position.unrealized_pnl_percent = (
|
||||
position.unrealized_pnl / position.cost_basis * Decimal("100")
|
||||
)
|
||||
else:
|
||||
if side == OrderSide.BUY:
|
||||
# Create new long position
|
||||
current_price = self.get_simulated_price(symbol)
|
||||
self._positions[symbol] = Position(
|
||||
symbol=symbol,
|
||||
quantity=quantity,
|
||||
side=PositionSide.LONG,
|
||||
avg_entry_price=fill_price,
|
||||
current_price=current_price,
|
||||
market_value=quantity * current_price,
|
||||
cost_basis=quantity * fill_price,
|
||||
unrealized_pnl=quantity * (current_price - fill_price),
|
||||
unrealized_pnl_percent=(
|
||||
(current_price - fill_price) / fill_price * Decimal("100")
|
||||
if fill_price > 0 else Decimal("0")
|
||||
),
|
||||
)
|
||||
# For sells without existing position, we'd need short selling logic
|
||||
# For simplicity, ignore sells without positions
|
||||
|
||||
async def submit_order(self, request: OrderRequest) -> Order:
|
||||
"""Submit a simulated order.
|
||||
|
||||
Args:
|
||||
request: Order request details.
|
||||
|
||||
Returns:
|
||||
Order with execution details.
|
||||
|
||||
Raises:
|
||||
InvalidOrderError: If order parameters are invalid.
|
||||
InsufficientFundsError: If insufficient funds.
|
||||
"""
|
||||
self._require_connection()
|
||||
|
||||
# Validate order
|
||||
if request.quantity <= 0:
|
||||
raise InvalidOrderError("Order quantity must be positive")
|
||||
|
||||
# Generate order ID
|
||||
order_id = self._generate_order_id()
|
||||
|
||||
# Calculate fill price
|
||||
fill_price = self._calculate_fill_price(
|
||||
request.symbol,
|
||||
request.side,
|
||||
request.order_type,
|
||||
request.limit_price,
|
||||
)
|
||||
|
||||
# Determine if order should fill
|
||||
should_fill = self._should_fill() and fill_price is not None
|
||||
|
||||
if should_fill:
|
||||
# Check funds for buys
|
||||
if request.side == OrderSide.BUY:
|
||||
required_funds = request.quantity * fill_price
|
||||
if required_funds > self._cash:
|
||||
raise InsufficientFundsError(
|
||||
f"Insufficient funds: need ${required_funds}, have ${self._cash}"
|
||||
)
|
||||
# Deduct cash
|
||||
self._cash -= required_funds
|
||||
|
||||
# For sells, add cash back
|
||||
else:
|
||||
proceeds = request.quantity * fill_price
|
||||
self._cash += proceeds
|
||||
|
||||
# Update position
|
||||
self._update_position(
|
||||
request.symbol,
|
||||
request.side,
|
||||
request.quantity,
|
||||
fill_price,
|
||||
)
|
||||
|
||||
status = OrderStatus.FILLED
|
||||
filled_qty = request.quantity
|
||||
avg_fill = fill_price
|
||||
filled_at = datetime.now(timezone.utc)
|
||||
else:
|
||||
status = OrderStatus.NEW
|
||||
filled_qty = Decimal("0")
|
||||
avg_fill = None
|
||||
filled_at = None
|
||||
|
||||
# Create order
|
||||
order = Order(
|
||||
broker_order_id=order_id,
|
||||
client_order_id=request.client_order_id or "",
|
||||
symbol=request.symbol,
|
||||
side=request.side,
|
||||
quantity=request.quantity,
|
||||
order_type=request.order_type,
|
||||
status=status,
|
||||
limit_price=request.limit_price,
|
||||
stop_price=request.stop_price,
|
||||
time_in_force=request.time_in_force,
|
||||
filled_quantity=filled_qty,
|
||||
filled_avg_price=avg_fill,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
filled_at=filled_at,
|
||||
)
|
||||
|
||||
self._orders[order_id] = order
|
||||
return order
|
||||
|
||||
async def cancel_order(self, order_id: str) -> Order:
|
||||
"""Cancel a simulated order.
|
||||
|
||||
Args:
|
||||
order_id: Order ID to cancel.
|
||||
|
||||
Returns:
|
||||
Cancelled order.
|
||||
|
||||
Raises:
|
||||
OrderError: If order not found or cannot be cancelled.
|
||||
"""
|
||||
self._require_connection()
|
||||
|
||||
if order_id not in self._orders:
|
||||
raise OrderError(f"Order {order_id} not found")
|
||||
|
||||
order = self._orders[order_id]
|
||||
|
||||
# Can only cancel unfilled orders
|
||||
if order.status == OrderStatus.FILLED:
|
||||
raise OrderError("Cannot cancel filled order")
|
||||
|
||||
order.status = OrderStatus.CANCELLED
|
||||
order.cancelled_at = datetime.now(timezone.utc)
|
||||
|
||||
return order
|
||||
|
||||
async def replace_order(
|
||||
self,
|
||||
order_id: str,
|
||||
quantity: Optional[Decimal] = None,
|
||||
limit_price: Optional[Decimal] = None,
|
||||
stop_price: Optional[Decimal] = None,
|
||||
time_in_force: Optional[TimeInForce] = None,
|
||||
) -> Order:
|
||||
"""Replace a simulated order.
|
||||
|
||||
Creates a new order with updated parameters.
|
||||
"""
|
||||
self._require_connection()
|
||||
|
||||
if order_id not in self._orders:
|
||||
raise OrderError(f"Order {order_id} not found")
|
||||
|
||||
old_order = self._orders[order_id]
|
||||
|
||||
# Cancel old order
|
||||
if old_order.status != OrderStatus.FILLED:
|
||||
old_order.status = OrderStatus.REPLACED
|
||||
|
||||
# Create new order request
|
||||
request = OrderRequest(
|
||||
symbol=old_order.symbol,
|
||||
side=old_order.side,
|
||||
quantity=quantity or old_order.quantity,
|
||||
order_type=old_order.order_type,
|
||||
time_in_force=time_in_force or old_order.time_in_force,
|
||||
limit_price=limit_price or old_order.limit_price,
|
||||
stop_price=stop_price or old_order.stop_price,
|
||||
)
|
||||
|
||||
return await self.submit_order(request)
|
||||
|
||||
async def get_order(self, order_id: str) -> Order:
|
||||
"""Get order by ID.
|
||||
|
||||
Args:
|
||||
order_id: Order ID.
|
||||
|
||||
Returns:
|
||||
Order details.
|
||||
|
||||
Raises:
|
||||
OrderError: If order not found.
|
||||
"""
|
||||
self._require_connection()
|
||||
|
||||
if order_id not in self._orders:
|
||||
raise OrderError(f"Order {order_id} not found")
|
||||
|
||||
return self._orders[order_id]
|
||||
|
||||
async def get_orders(
|
||||
self,
|
||||
status: Optional[OrderStatus] = None,
|
||||
limit: int = 100,
|
||||
symbols: Optional[List[str]] = None,
|
||||
) -> List[Order]:
|
||||
"""Get orders with optional filters.
|
||||
|
||||
Args:
|
||||
status: Filter by status.
|
||||
limit: Maximum number to return.
|
||||
symbols: Filter by symbols.
|
||||
|
||||
Returns:
|
||||
List of matching orders.
|
||||
"""
|
||||
self._require_connection()
|
||||
|
||||
orders = list(self._orders.values())
|
||||
|
||||
# Apply filters
|
||||
if status:
|
||||
orders = [o for o in orders if o.status == status]
|
||||
if symbols:
|
||||
orders = [o for o in orders if o.symbol in symbols]
|
||||
|
||||
# Sort by creation time, most recent first
|
||||
orders.sort(key=lambda o: o.created_at or datetime.min, reverse=True)
|
||||
|
||||
return orders[:limit]
|
||||
|
||||
async def get_positions(self) -> List[Position]:
|
||||
"""Get all positions.
|
||||
|
||||
Returns:
|
||||
List of current positions.
|
||||
"""
|
||||
self._require_connection()
|
||||
|
||||
# Update current prices
|
||||
for symbol, position in self._positions.items():
|
||||
current_price = self.get_simulated_price(symbol)
|
||||
position.current_price = current_price
|
||||
position.market_value = position.quantity * current_price
|
||||
position.unrealized_pnl = position.market_value - position.cost_basis
|
||||
if position.cost_basis > 0:
|
||||
position.unrealized_pnl_percent = (
|
||||
position.unrealized_pnl / position.cost_basis * Decimal("100")
|
||||
)
|
||||
|
||||
return list(self._positions.values())
|
||||
|
||||
async def get_position(self, symbol: str) -> Optional[Position]:
|
||||
"""Get position for a specific symbol.
|
||||
|
||||
Args:
|
||||
symbol: Symbol to get position for.
|
||||
|
||||
Returns:
|
||||
Position if exists, None otherwise.
|
||||
"""
|
||||
self._require_connection()
|
||||
|
||||
if symbol in self._positions:
|
||||
position = self._positions[symbol]
|
||||
# Update current price
|
||||
current_price = self.get_simulated_price(symbol)
|
||||
position.current_price = current_price
|
||||
position.market_value = position.quantity * current_price
|
||||
position.unrealized_pnl = position.market_value - position.cost_basis
|
||||
return position
|
||||
|
||||
return None
|
||||
|
||||
async def get_quote(self, symbol: str) -> Quote:
|
||||
"""Get simulated quote.
|
||||
|
||||
Args:
|
||||
symbol: Symbol to get quote for.
|
||||
|
||||
Returns:
|
||||
Simulated quote data.
|
||||
"""
|
||||
self._require_connection()
|
||||
|
||||
base_price = self.get_simulated_price(symbol)
|
||||
|
||||
# Simulate bid/ask spread
|
||||
spread = base_price * Decimal("0.001") # 0.1% spread
|
||||
bid = base_price - spread / 2
|
||||
ask = base_price + spread / 2
|
||||
|
||||
return Quote(
|
||||
symbol=symbol,
|
||||
bid_price=bid.quantize(Decimal("0.01")),
|
||||
ask_price=ask.quantize(Decimal("0.01")),
|
||||
last_price=base_price,
|
||||
bid_size=random.randint(100, 1000),
|
||||
ask_size=random.randint(100, 1000),
|
||||
volume=random.randint(100000, 10000000),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
async def get_asset(self, symbol: str) -> AssetInfo:
|
||||
"""Get simulated asset information.
|
||||
|
||||
Args:
|
||||
symbol: Symbol to get info for.
|
||||
|
||||
Returns:
|
||||
Simulated asset information.
|
||||
"""
|
||||
self._require_connection()
|
||||
|
||||
# Determine asset class based on symbol patterns
|
||||
if symbol.endswith("USD"):
|
||||
asset_class = AssetClass.CRYPTO
|
||||
elif symbol in ["ES", "NQ", "CL", "GC"]:
|
||||
asset_class = AssetClass.FUTURE
|
||||
elif symbol in ["SPY", "QQQ", "IWM", "VTI"]:
|
||||
asset_class = AssetClass.ETF
|
||||
else:
|
||||
asset_class = AssetClass.EQUITY
|
||||
|
||||
return AssetInfo(
|
||||
symbol=symbol,
|
||||
name=f"{symbol} (Paper)",
|
||||
asset_class=asset_class,
|
||||
exchange="PAPER",
|
||||
tradable=True,
|
||||
shortable=True,
|
||||
marginable=True,
|
||||
fractionable=False,
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset broker to initial state.
|
||||
|
||||
Clears all positions and orders, resets cash to initial amount.
|
||||
"""
|
||||
self._cash = self._initial_cash
|
||||
self._orders.clear()
|
||||
self._positions.clear()
|
||||
self._order_counter = 0
|
||||
|
||||
|
||||
# Export
|
||||
__all__ = ["PaperBroker"]
|
||||
Loading…
Reference in New Issue