# TradingAgents/tests/unit/execution/test_broker_base.py
#
# 1278 lines
# 42 KiB
# Python

"""Tests for Broker Base Interface module.
Issue #22: [EXEC-21] Broker base interface - abstract broker class
"""
import pytest
from datetime import datetime
from decimal import Decimal
from typing import Dict, List, Optional
from unittest.mock import AsyncMock, MagicMock, patch
from tradingagents.execution import (
# Enums
AssetClass,
OrderSide,
OrderType,
TimeInForce,
OrderStatus,
PositionSide,
# Data Classes
OrderRequest,
Order,
Position,
AccountInfo,
Quote,
AssetInfo,
# Exceptions
BrokerError,
ConnectionError,
AuthenticationError,
OrderError,
InsufficientFundsError,
InvalidOrderError,
PositionError,
RateLimitError,
# Abstract Base Class
BrokerBase,
)
# =============================================================================
# Mock Broker Implementation for Testing
# =============================================================================
class MockBroker(BrokerBase):
    """In-memory BrokerBase implementation used to exercise the abstract interface.

    Orders, positions, quotes and assets are held in plain dicts on the
    instance. The ``add_*`` and ``fill_order`` helpers let tests seed or
    mutate that state directly instead of going through a broker API.
    """

    def __init__(self, **kwargs):
        """Create the mock; extra keyword arguments go straight to BrokerBase."""
        super().__init__(
            name="MockBroker",
            supported_asset_classes=[AssetClass.EQUITY, AssetClass.ETF, AssetClass.CRYPTO],
            **kwargs,
        )
        self._orders: Dict[str, Order] = {}
        self._positions: Dict[str, Position] = {}
        self._account: Optional[AccountInfo] = None
        self._quotes: Dict[str, Quote] = {}
        self._assets: Dict[str, AssetInfo] = {}

    # -- internal helpers ---------------------------------------------------

    def _require_connection(self) -> None:
        """Raise ConnectionError unless ``connect()`` has been called."""
        if not self._connected:
            raise ConnectionError("Not connected")

    def _lookup_order(self, order_id: str) -> Order:
        """Return the stored order, raising OrderError if the id is unknown."""
        try:
            return self._orders[order_id]
        except KeyError:
            raise OrderError(f"Order {order_id} not found") from None

    def _next_order_id(self) -> str:
        """Return the next sequential broker order id.

        Orders are never removed from ``_orders``, so a count-based id is unique.
        """
        return f"ORD-{len(self._orders) + 1}"

    # -- BrokerBase interface -----------------------------------------------

    async def connect(self) -> bool:
        """Mark the broker connected and install a fixed demo margin account."""
        self._connected = True
        self._account = AccountInfo(
            account_id="TEST123",
            account_type="margin",
            status="active",
            currency="USD",
            cash=Decimal("100000"),
            portfolio_value=Decimal("150000"),
            buying_power=Decimal("200000"),
            equity=Decimal("150000"),
        )
        return True

    async def disconnect(self) -> None:
        """Mark the broker disconnected; stored orders/positions are kept."""
        self._connected = False

    async def is_market_open(self) -> bool:
        """The mock market is always open."""
        return True

    async def get_account(self) -> AccountInfo:
        """Return the account created by ``connect()``.

        Raises:
            ConnectionError: if the broker is not connected.
        """
        self._require_connection()
        return self._account

    async def submit_order(self, request: OrderRequest) -> Order:
        """Record ``request`` as a NEW order and return it.

        Raises:
            ConnectionError: if the broker is not connected.
        """
        self._require_connection()
        # One timestamp so created_at and submitted_at agree exactly.
        now = datetime.now()
        order = Order(
            broker_order_id=self._next_order_id(),
            client_order_id=request.client_order_id,
            symbol=request.symbol,
            side=request.side,
            quantity=request.quantity,
            order_type=request.order_type,
            status=OrderStatus.NEW,
            limit_price=request.limit_price,
            stop_price=request.stop_price,
            time_in_force=request.time_in_force,
            created_at=now,
            submitted_at=now,
        )
        self._orders[order.broker_order_id] = order
        return order

    async def cancel_order(self, order_id: str) -> Order:
        """Mark an existing order CANCELLED and stamp ``cancelled_at``.

        Raises:
            OrderError: if ``order_id`` is unknown.
        """
        order = self._lookup_order(order_id)
        order.status = OrderStatus.CANCELLED
        order.cancelled_at = datetime.now()
        return order

    async def replace_order(
        self,
        order_id: str,
        quantity: Optional[Decimal] = None,
        limit_price: Optional[Decimal] = None,
        stop_price: Optional[Decimal] = None,
        time_in_force: Optional[TimeInForce] = None,
    ) -> Order:
        """Mark an order REPLACED and store a NEW copy with the given overrides.

        Any override left as ``None`` is copied from the original order.

        Raises:
            OrderError: if ``order_id`` is unknown.
        """
        old_order = self._lookup_order(order_id)
        old_order.status = OrderStatus.REPLACED
        now = datetime.now()
        new_order = Order(
            broker_order_id=self._next_order_id(),
            client_order_id=old_order.client_order_id,
            symbol=old_order.symbol,
            side=old_order.side,
            # ``is None`` rather than ``or``: ``x or old`` would silently drop an
            # explicitly supplied falsy value such as Decimal("0").
            quantity=old_order.quantity if quantity is None else quantity,
            order_type=old_order.order_type,
            status=OrderStatus.NEW,
            limit_price=old_order.limit_price if limit_price is None else limit_price,
            stop_price=old_order.stop_price if stop_price is None else stop_price,
            time_in_force=(
                old_order.time_in_force if time_in_force is None else time_in_force
            ),
            created_at=now,
            submitted_at=now,
        )
        self._orders[new_order.broker_order_id] = new_order
        return new_order

    async def get_order(self, order_id: str) -> Order:
        """Return a stored order.

        Raises:
            OrderError: if ``order_id`` is unknown.
        """
        return self._lookup_order(order_id)

    async def get_orders(
        self,
        status: Optional[OrderStatus] = None,
        limit: int = 100,
        symbols: Optional[List[str]] = None,
    ) -> List[Order]:
        """Return up to ``limit`` orders in insertion order, optionally filtered.

        ``status`` keeps only orders in that status; a non-empty ``symbols``
        list keeps only orders for those symbols (an empty list is no filter).
        """
        orders = list(self._orders.values())
        if status is not None:
            orders = [o for o in orders if o.status == status]
        if symbols:
            orders = [o for o in orders if o.symbol in symbols]
        return orders[:limit]

    async def get_positions(self) -> List[Position]:
        """Return all seeded positions."""
        return list(self._positions.values())

    async def get_position(self, symbol: str) -> Optional[Position]:
        """Return the position for ``symbol`` or ``None`` if there is none."""
        return self._positions.get(symbol)

    async def get_quote(self, symbol: str) -> Quote:
        """Return a seeded quote, or a fixed default quote around 100.00."""
        if symbol in self._quotes:
            return self._quotes[symbol]
        return Quote(
            symbol=symbol,
            bid_price=Decimal("100.00"),
            bid_size=Decimal("100"),
            ask_price=Decimal("100.05"),
            ask_size=Decimal("200"),
            last_price=Decimal("100.02"),
            volume=1000000,
            timestamp=datetime.now(),
        )

    async def get_asset(self, symbol: str) -> AssetInfo:
        """Return seeded asset info, or a default tradable/shortable equity."""
        if symbol in self._assets:
            return self._assets[symbol]
        return AssetInfo(
            symbol=symbol,
            name=f"{symbol} Inc.",
            asset_class=AssetClass.EQUITY,
            tradable=True,
            marginable=True,
            shortable=True,
        )

    # -- test helpers -------------------------------------------------------

    def add_position(self, position: Position) -> None:
        """Seed (or overwrite) the position for ``position.symbol``."""
        self._positions[position.symbol] = position

    def add_quote(self, quote: Quote) -> None:
        """Seed (or overwrite) the quote for ``quote.symbol``."""
        self._quotes[quote.symbol] = quote

    def add_asset(self, asset: AssetInfo) -> None:
        """Seed (or overwrite) the asset info for ``asset.symbol``."""
        self._assets[asset.symbol] = asset

    def fill_order(self, order_id: str, avg_price: Decimal) -> Optional[Order]:
        """Simulate a complete fill at ``avg_price``.

        Returns the filled order, or ``None`` if ``order_id`` is unknown.
        """
        order = self._orders.get(order_id)
        if order is not None:
            order.status = OrderStatus.FILLED
            order.filled_quantity = order.quantity
            order.filled_avg_price = avg_price
            order.filled_at = datetime.now()
        return order
# =============================================================================
# Enum Tests
# =============================================================================
class TestAssetClass:
    """Pin the serialized string value of every AssetClass member."""

    def test_all_asset_classes(self):
        """Each asset class maps to its lowercase wire value."""
        expected_values = [
            (AssetClass.EQUITY, "equity"),
            (AssetClass.ETF, "etf"),
            (AssetClass.OPTION, "option"),
            (AssetClass.FUTURE, "future"),
            (AssetClass.CRYPTO, "crypto"),
            (AssetClass.FOREX, "forex"),
            (AssetClass.BOND, "bond"),
            (AssetClass.INDEX, "index"),
        ]
        for member, wire_value in expected_values:
            assert member.value == wire_value
class TestOrderSide:
    """Pin the serialized string value of each OrderSide member."""

    def test_order_sides(self):
        """BUY and SELL serialize to lowercase strings."""
        for member, wire_value in ((OrderSide.BUY, "buy"), (OrderSide.SELL, "sell")):
            assert member.value == wire_value
class TestOrderType:
    """Pin the serialized string value of every OrderType member."""

    def test_order_types(self):
        """Each order type maps to its snake_case wire value."""
        expected_values = [
            (OrderType.MARKET, "market"),
            (OrderType.LIMIT, "limit"),
            (OrderType.STOP, "stop"),
            (OrderType.STOP_LIMIT, "stop_limit"),
            (OrderType.TRAILING_STOP, "trailing_stop"),
        ]
        for member, wire_value in expected_values:
            assert member.value == wire_value
class TestTimeInForce:
    """Pin the serialized string value of every TimeInForce member."""

    def test_time_in_force_values(self):
        """Each time-in-force code maps to its lowercase abbreviation."""
        expected_values = [
            (TimeInForce.DAY, "day"),
            (TimeInForce.GTC, "gtc"),
            (TimeInForce.IOC, "ioc"),
            (TimeInForce.FOK, "fok"),
            (TimeInForce.OPG, "opg"),
            (TimeInForce.CLS, "cls"),
            (TimeInForce.GTD, "gtd"),
        ]
        for member, wire_value in expected_values:
            assert member.value == wire_value
class TestOrderStatus:
    """Pin the serialized string value of every OrderStatus member."""

    def test_order_status_values(self):
        """Each lifecycle status maps to its snake_case wire value."""
        expected_values = [
            (OrderStatus.PENDING_NEW, "pending_new"),
            (OrderStatus.NEW, "new"),
            (OrderStatus.PARTIALLY_FILLED, "partially_filled"),
            (OrderStatus.FILLED, "filled"),
            (OrderStatus.PENDING_CANCEL, "pending_cancel"),
            (OrderStatus.CANCELLED, "cancelled"),
            (OrderStatus.REJECTED, "rejected"),
            (OrderStatus.EXPIRED, "expired"),
            (OrderStatus.REPLACED, "replaced"),
        ]
        for member, wire_value in expected_values:
            assert member.value == wire_value
class TestPositionSide:
    """Pin the serialized string value of each PositionSide member."""

    def test_position_sides(self):
        """LONG and SHORT serialize to lowercase strings."""
        for member, wire_value in ((PositionSide.LONG, "long"), (PositionSide.SHORT, "short")):
            assert member.value == wire_value
# =============================================================================
# OrderRequest Tests
# =============================================================================
class TestOrderRequest:
    """Cover the OrderRequest factory helpers and constructor-time validation."""

    @staticmethod
    def _direct(order_type, **price_fields):
        """Construct an AAPL buy of 100 shares straight through ``OrderRequest(...)``."""
        return OrderRequest(
            symbol="AAPL",
            side=OrderSide.BUY,
            quantity=Decimal("100"),
            order_type=order_type,
            **price_fields,
        )

    def test_create_market_order(self):
        """market() builds a DAY market order and assigns a client order id."""
        req = OrderRequest.market("AAPL", OrderSide.BUY, 100)
        assert req.symbol == "AAPL"
        assert req.side == OrderSide.BUY
        assert req.quantity == Decimal("100")
        assert req.order_type == OrderType.MARKET
        assert req.time_in_force == TimeInForce.DAY
        assert req.client_order_id is not None

    def test_create_limit_order(self):
        """limit() converts numbers to Decimal and defaults to GTC."""
        req = OrderRequest.limit("AAPL", OrderSide.BUY, 100, 150.00)
        assert req.symbol == "AAPL"
        assert req.side == OrderSide.BUY
        assert req.quantity == Decimal("100")
        assert req.order_type == OrderType.LIMIT
        assert req.limit_price == Decimal("150.00")
        assert req.time_in_force == TimeInForce.GTC

    def test_create_stop_order(self):
        """stop() records the trigger price as a Decimal."""
        req = OrderRequest.stop(
            symbol="AAPL",
            side=OrderSide.SELL,
            quantity=100,
            stop_price=145.00,
        )
        assert req.order_type == OrderType.STOP
        assert req.stop_price == Decimal("145.00")

    def test_create_stop_limit_order(self):
        """stop_limit() keeps both the trigger and the limit price."""
        req = OrderRequest.stop_limit(
            symbol="AAPL",
            side=OrderSide.SELL,
            quantity=100,
            stop_price=145.00,
            limit_price=144.50,
        )
        assert req.order_type == OrderType.STOP_LIMIT
        assert req.stop_price == Decimal("145.00")
        assert req.limit_price == Decimal("144.50")

    def test_create_trailing_stop_percent(self):
        """A percent-based trailing stop leaves trail_amount unset."""
        req = OrderRequest.trailing_stop(
            symbol="AAPL",
            side=OrderSide.SELL,
            quantity=100,
            trail_percent=5.0,
        )
        assert req.order_type == OrderType.TRAILING_STOP
        assert req.trail_percent == Decimal("5.0")
        assert req.trail_amount is None

    def test_create_trailing_stop_amount(self):
        """An amount-based trailing stop leaves trail_percent unset."""
        req = OrderRequest.trailing_stop(
            symbol="AAPL",
            side=OrderSide.SELL,
            quantity=100,
            trail_amount=5.00,
        )
        assert req.order_type == OrderType.TRAILING_STOP
        assert req.trail_amount == Decimal("5.00")
        assert req.trail_percent is None

    def test_limit_order_requires_limit_price(self):
        """A LIMIT order without limit_price is rejected at construction."""
        with pytest.raises(ValueError, match="Limit orders require limit_price"):
            self._direct(OrderType.LIMIT)

    def test_stop_order_requires_stop_price(self):
        """A STOP order without stop_price is rejected at construction."""
        with pytest.raises(ValueError, match="Stop orders require stop_price"):
            self._direct(OrderType.STOP)

    def test_stop_limit_requires_both_prices(self):
        """A STOP_LIMIT order with only a limit price is rejected."""
        with pytest.raises(ValueError, match="Stop-limit orders require both"):
            self._direct(OrderType.STOP_LIMIT, limit_price=Decimal("150"))

    def test_trailing_stop_requires_trail_value(self):
        """A TRAILING_STOP order with neither trail value is rejected."""
        with pytest.raises(ValueError, match="Trailing stop orders require"):
            self._direct(OrderType.TRAILING_STOP)

    def test_order_with_metadata(self):
        """Arbitrary metadata passed to a factory is readable on the request."""
        tags = {"strategy": "momentum", "signal_strength": 0.85}
        req = OrderRequest.market("AAPL", OrderSide.BUY, 100, metadata=tags)
        assert req.metadata["strategy"] == "momentum"
        assert req.metadata["signal_strength"] == 0.85
# =============================================================================
# Order Tests
# =============================================================================
class TestOrder:
    """Check the derived status and fill-progress properties on Order."""

    @staticmethod
    def _make(status, **overrides):
        """Return an AAPL 100-share market buy in ``status`` (plus any extra fields)."""
        return Order(
            broker_order_id="ORD-1",
            client_order_id="CLT-1",
            symbol="AAPL",
            side=OrderSide.BUY,
            quantity=Decimal("100"),
            order_type=OrderType.MARKET,
            status=status,
            **overrides,
        )

    def test_order_is_open(self):
        """is_open is True for NEW and flips to False once FILLED."""
        order = self._make(OrderStatus.NEW)
        assert order.is_open is True
        order.status = OrderStatus.FILLED
        assert order.is_open is False

    def test_order_is_filled(self):
        """is_filled is True for a FILLED order."""
        assert self._make(OrderStatus.FILLED).is_filled is True

    def test_order_is_cancelled(self):
        """is_cancelled is True for a CANCELLED order."""
        assert self._make(OrderStatus.CANCELLED).is_cancelled is True

    def test_remaining_quantity(self):
        """remaining_quantity is quantity minus filled_quantity."""
        partial = self._make(OrderStatus.PARTIALLY_FILLED, filled_quantity=Decimal("60"))
        assert partial.remaining_quantity == Decimal("40")

    def test_fill_percent(self):
        """fill_percent reports filled_quantity as a percentage of quantity."""
        partial = self._make(OrderStatus.PARTIALLY_FILLED, filled_quantity=Decimal("60"))
        assert partial.fill_percent == 60.0
# =============================================================================
# Position Tests
# =============================================================================
class TestPosition:
    """Long/short classification and absolute size on Position."""

    @staticmethod
    def _position(quantity, side, current_price, market_value, cost_basis):
        """Build an AAPL position entered at 150.00 showing +1000.00 unrealized P&L."""
        return Position(
            symbol="AAPL",
            quantity=Decimal(quantity),
            side=side,
            avg_entry_price=Decimal("150.00"),
            current_price=Decimal(current_price),
            market_value=Decimal(market_value),
            cost_basis=Decimal(cost_basis),
            unrealized_pnl=Decimal("1000.00"),
            unrealized_pnl_percent=Decimal("6.67"),
        )

    def test_position_long(self):
        """A positive-quantity LONG position reports is_long and its size."""
        long_pos = self._position("100", PositionSide.LONG, "160.00", "16000.00", "15000.00")
        assert long_pos.is_long is True
        assert long_pos.is_short is False
        assert long_pos.abs_quantity == Decimal("100")

    def test_position_short(self):
        """A negative-quantity SHORT position reports is_short and a positive size."""
        short_pos = self._position("-100", PositionSide.SHORT, "140.00", "-14000.00", "-15000.00")
        assert short_pos.is_long is False
        assert short_pos.is_short is True
        assert short_pos.abs_quantity == Decimal("100")
# =============================================================================
# AccountInfo Tests
# =============================================================================
class TestAccountInfo:
    """Status helper and default balances on AccountInfo."""

    def test_account_is_active(self):
        """is_active follows the status string."""
        acct = AccountInfo(account_id="TEST123", account_type="margin", status="active")
        assert acct.is_active is True
        acct.status = "inactive"
        assert acct.is_active is False

    def test_account_defaults(self):
        """Unspecified fields default to USD, zero balances and no PDT flag."""
        acct = AccountInfo(account_id="TEST123", account_type="cash", status="active")
        assert acct.currency == "USD"
        for balance in (acct.cash, acct.portfolio_value):
            assert balance == Decimal("0")
        assert acct.is_pattern_day_trader is False
# =============================================================================
# Quote Tests
# =============================================================================
class TestQuote:
    """Derived pricing properties (mid, spread, spread percent) on Quote."""

    def test_quote_mid_price(self):
        """mid_price is the average of bid and ask."""
        quote = Quote(
            symbol="AAPL",
            bid_price=Decimal("100.00"),
            ask_price=Decimal("100.10"),
        )
        assert quote.mid_price == Decimal("100.05")

    def test_quote_mid_price_fallback(self):
        """mid_price falls back to last_price when bid/ask are absent."""
        quote = Quote(
            symbol="AAPL",
            last_price=Decimal("100.05"),
        )
        assert quote.mid_price == Decimal("100.05")

    def test_quote_spread(self):
        """spread is ask minus bid."""
        quote = Quote(
            symbol="AAPL",
            bid_price=Decimal("100.00"),
            ask_price=Decimal("100.10"),
        )
        assert quote.spread == Decimal("0.10")

    def test_quote_spread_percent(self):
        """spread_percent expresses the spread relative to the mid price."""
        quote = Quote(
            symbol="AAPL",
            bid_price=Decimal("100.00"),
            ask_price=Decimal("100.10"),
        )
        # spread / mid_price * 100 = 0.10 / 100.05 * 100 ≈ 0.09995%
        assert quote.spread_percent is not None
        # Coerce to float so the check works whether the property returns a
        # float or a Decimal (``Decimal - float`` raises TypeError).
        assert float(quote.spread_percent) == pytest.approx(0.0999, abs=0.001)
# =============================================================================
# AssetInfo Tests
# =============================================================================
class TestAssetInfo:
    """Default field values on AssetInfo."""

    def test_asset_defaults(self):
        """With only symbol and name given, the asset is a tradable, non-fractionable equity."""
        info = AssetInfo(symbol="AAPL", name="Apple Inc.")
        assert info.asset_class == AssetClass.EQUITY
        for flag in ("tradable", "marginable", "shortable"):
            assert getattr(info, flag) is True
        assert info.fractionable is False
# =============================================================================
# Exception Tests
# =============================================================================
class TestBrokerExceptions:
    """Attributes and inheritance of the broker exception hierarchy."""

    def test_broker_error(self):
        """BrokerError keeps its message plus optional code and details."""
        payload = {"key": "value"}
        err = BrokerError("Something went wrong", code="ERR001", details=payload)
        assert str(err) == "Something went wrong"
        assert err.code == "ERR001"
        assert err.details == {"key": "value"}

    def test_connection_error(self):
        """The broker ConnectionError (imported from tradingagents.execution) is a BrokerError."""
        assert isinstance(ConnectionError("Failed to connect"), BrokerError)

    def test_authentication_error(self):
        """AuthenticationError is a BrokerError."""
        assert isinstance(AuthenticationError("Invalid credentials"), BrokerError)

    def test_order_error(self):
        """OrderError is a BrokerError."""
        assert isinstance(OrderError("Order failed"), BrokerError)

    def test_insufficient_funds_error(self):
        """InsufficientFundsError specializes OrderError."""
        assert isinstance(InsufficientFundsError("Not enough buying power"), OrderError)

    def test_invalid_order_error(self):
        """InvalidOrderError specializes OrderError."""
        assert isinstance(InvalidOrderError("Invalid order parameters"), OrderError)

    def test_position_error(self):
        """PositionError is a BrokerError."""
        assert isinstance(PositionError("Position not found"), BrokerError)

    def test_rate_limit_error(self):
        """RateLimitError is a BrokerError and exposes retry_after."""
        err = RateLimitError("Rate limit exceeded", retry_after=60.0)
        assert isinstance(err, BrokerError)
        assert err.retry_after == 60.0
# =============================================================================
# BrokerBase Tests
# =============================================================================
class TestBrokerBase:
    """Exercise the BrokerBase contract through the in-memory MockBroker."""

    @pytest.fixture
    def broker(self):
        """Create a disconnected paper-trading mock broker."""
        return MockBroker(paper_trading=True)

    @staticmethod
    def _long_aapl_position() -> Position:
        """Return a fresh 100-share long AAPL position (entry 150.00, marked 160.00)."""
        return Position(
            symbol="AAPL",
            quantity=Decimal("100"),
            side=PositionSide.LONG,
            avg_entry_price=Decimal("150.00"),
            current_price=Decimal("160.00"),
            market_value=Decimal("16000.00"),
            cost_basis=Decimal("15000.00"),
            unrealized_pnl=Decimal("1000.00"),
            unrealized_pnl_percent=Decimal("6.67"),
        )

    def test_broker_properties(self, broker):
        """A freshly created broker exposes its name and mode and is disconnected."""
        assert broker.name == "MockBroker"
        assert broker.is_paper_trading is True
        assert broker.is_connected is False

    def test_supported_asset_classes(self, broker):
        """Asset class support reflects the list passed to BrokerBase."""
        assert AssetClass.EQUITY in broker.supported_asset_classes
        assert AssetClass.ETF in broker.supported_asset_classes
        assert AssetClass.CRYPTO in broker.supported_asset_classes
        assert broker.supports_asset_class(AssetClass.EQUITY) is True
        assert broker.supports_asset_class(AssetClass.FOREX) is False

    @pytest.mark.asyncio
    async def test_connect(self, broker):
        """connect() succeeds and flips is_connected."""
        result = await broker.connect()
        assert result is True
        assert broker.is_connected is True

    @pytest.mark.asyncio
    async def test_disconnect(self, broker):
        """disconnect() clears is_connected."""
        await broker.connect()
        await broker.disconnect()
        assert broker.is_connected is False

    @pytest.mark.asyncio
    async def test_get_account(self, broker):
        """get_account() returns the account installed on connect."""
        await broker.connect()
        account = await broker.get_account()
        assert account.account_id == "TEST123"
        assert account.cash == Decimal("100000")
        assert account.is_active is True

    @pytest.mark.asyncio
    async def test_submit_market_order(self, broker):
        """A submitted market order is echoed back in NEW status."""
        await broker.connect()
        request = OrderRequest.market("AAPL", OrderSide.BUY, 100)
        order = await broker.submit_order(request)
        assert order.symbol == "AAPL"
        assert order.side == OrderSide.BUY
        assert order.quantity == Decimal("100")
        assert order.status == OrderStatus.NEW

    @pytest.mark.asyncio
    async def test_submit_limit_order(self, broker):
        """A submitted limit order keeps its type and limit price."""
        await broker.connect()
        request = OrderRequest.limit("AAPL", OrderSide.BUY, 100, 150.00)
        order = await broker.submit_order(request)
        assert order.order_type == OrderType.LIMIT
        assert order.limit_price == Decimal("150.00")

    @pytest.mark.asyncio
    async def test_cancel_order(self, broker):
        """Cancelling an order sets CANCELLED and a cancellation timestamp."""
        await broker.connect()
        request = OrderRequest.market("AAPL", OrderSide.BUY, 100)
        order = await broker.submit_order(request)
        cancelled = await broker.cancel_order(order.broker_order_id)
        assert cancelled.status == OrderStatus.CANCELLED
        assert cancelled.cancelled_at is not None

    @pytest.mark.asyncio
    async def test_cancel_nonexistent_order(self, broker):
        """Cancelling an unknown id raises OrderError."""
        await broker.connect()
        with pytest.raises(OrderError, match="not found"):
            await broker.cancel_order("NONEXISTENT")

    @pytest.mark.asyncio
    async def test_replace_order(self, broker):
        """replace_order() returns a NEW order with the overridden fields."""
        await broker.connect()
        request = OrderRequest.limit("AAPL", OrderSide.BUY, 100, 150.00)
        order = await broker.submit_order(request)
        new_order = await broker.replace_order(
            order.broker_order_id,
            quantity=Decimal("200"),
            limit_price=Decimal("155.00"),
        )
        assert new_order.quantity == Decimal("200")
        assert new_order.limit_price == Decimal("155.00")
        assert new_order.status == OrderStatus.NEW

    @pytest.mark.asyncio
    async def test_get_order(self, broker):
        """get_order() finds a previously submitted order by broker id."""
        await broker.connect()
        request = OrderRequest.market("AAPL", OrderSide.BUY, 100)
        submitted = await broker.submit_order(request)
        order = await broker.get_order(submitted.broker_order_id)
        assert order.broker_order_id == submitted.broker_order_id

    @pytest.mark.asyncio
    async def test_get_orders_filter_by_status(self, broker):
        """get_orders(status=...) excludes orders in other statuses."""
        await broker.connect()
        await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, 100))
        order2 = await broker.submit_order(OrderRequest.market("GOOGL", OrderSide.BUY, 50))
        await broker.cancel_order(order2.broker_order_id)
        # Only the still-working AAPL order should come back.
        orders = await broker.get_orders(status=OrderStatus.NEW)
        assert len(orders) == 1
        assert orders[0].symbol == "AAPL"

    @pytest.mark.asyncio
    async def test_get_orders_filter_by_symbols(self, broker):
        """get_orders(symbols=...) keeps only the requested symbols."""
        await broker.connect()
        await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, 100))
        await broker.submit_order(OrderRequest.market("GOOGL", OrderSide.BUY, 50))
        await broker.submit_order(OrderRequest.market("MSFT", OrderSide.BUY, 75))
        orders = await broker.get_orders(symbols=["AAPL", "MSFT"])
        assert len(orders) == 2
        symbols = {o.symbol for o in orders}
        assert symbols == {"AAPL", "MSFT"}

    @pytest.mark.asyncio
    async def test_cancel_all_orders(self, broker):
        """cancel_all_orders() cancels every working order."""
        await broker.connect()
        await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, 100))
        await broker.submit_order(OrderRequest.market("GOOGL", OrderSide.BUY, 50))
        await broker.submit_order(OrderRequest.market("MSFT", OrderSide.BUY, 75))
        cancelled = await broker.cancel_all_orders()
        assert len(cancelled) == 3
        orders = await broker.get_orders(status=OrderStatus.NEW)
        assert len(orders) == 0

    @pytest.mark.asyncio
    async def test_get_positions(self, broker):
        """get_positions() returns seeded positions."""
        await broker.connect()
        broker.add_position(self._long_aapl_position())
        positions = await broker.get_positions()
        assert len(positions) == 1
        assert positions[0].symbol == "AAPL"

    @pytest.mark.asyncio
    async def test_get_position(self, broker):
        """get_position() returns the match, or None for an unknown symbol."""
        await broker.connect()
        broker.add_position(self._long_aapl_position())
        position = await broker.get_position("AAPL")
        assert position is not None
        assert position.symbol == "AAPL"
        position = await broker.get_position("NONEXISTENT")
        assert position is None

    @pytest.mark.asyncio
    async def test_close_position(self, broker):
        """Closing a long position submits a SELL for the full size."""
        await broker.connect()
        broker.add_position(self._long_aapl_position())
        order = await broker.close_position("AAPL")
        assert order.symbol == "AAPL"
        assert order.side == OrderSide.SELL  # Closing long = sell
        assert order.quantity == Decimal("100")

    @pytest.mark.asyncio
    async def test_close_position_partial(self, broker):
        """close_position(quantity=...) closes only the requested amount."""
        await broker.connect()
        broker.add_position(self._long_aapl_position())
        order = await broker.close_position("AAPL", quantity=Decimal("50"))
        assert order.quantity == Decimal("50")

    @pytest.mark.asyncio
    async def test_close_nonexistent_position(self, broker):
        """Closing a symbol with no position raises PositionError."""
        await broker.connect()
        with pytest.raises(PositionError, match="No position found"):
            await broker.close_position("NONEXISTENT")

    @pytest.mark.asyncio
    async def test_close_short_position(self, broker):
        """Closing a short position submits a BUY."""
        await broker.connect()
        broker.add_position(
            Position(
                symbol="AAPL",
                quantity=Decimal("-100"),
                side=PositionSide.SHORT,
                avg_entry_price=Decimal("150.00"),
                current_price=Decimal("140.00"),
                market_value=Decimal("-14000.00"),
                cost_basis=Decimal("-15000.00"),
                unrealized_pnl=Decimal("1000.00"),
                unrealized_pnl_percent=Decimal("6.67"),
            )
        )
        order = await broker.close_position("AAPL")
        assert order.side == OrderSide.BUY  # Closing short = buy

    @pytest.mark.asyncio
    async def test_close_all_positions(self, broker):
        """close_all_positions() submits one closing order per position."""
        await broker.connect()
        broker.add_position(self._long_aapl_position())
        broker.add_position(
            Position(
                symbol="GOOGL",
                quantity=Decimal("50"),
                side=PositionSide.LONG,
                avg_entry_price=Decimal("2800.00"),
                current_price=Decimal("2900.00"),
                market_value=Decimal("145000.00"),
                cost_basis=Decimal("140000.00"),
                unrealized_pnl=Decimal("5000.00"),
                unrealized_pnl_percent=Decimal("3.57"),
            )
        )
        orders = await broker.close_all_positions()
        assert len(orders) == 2

    @pytest.mark.asyncio
    async def test_get_quote(self, broker):
        """get_quote() returns bid and ask for the symbol."""
        await broker.connect()
        quote = await broker.get_quote("AAPL")
        assert quote.symbol == "AAPL"
        assert quote.bid_price is not None
        assert quote.ask_price is not None

    @pytest.mark.asyncio
    async def test_get_quotes(self, broker):
        """get_quotes() returns one entry per requested symbol."""
        await broker.connect()
        quotes = await broker.get_quotes(["AAPL", "GOOGL", "MSFT"])
        assert len(quotes) == 3
        assert "AAPL" in quotes
        assert "GOOGL" in quotes
        assert "MSFT" in quotes

    @pytest.mark.asyncio
    async def test_get_asset(self, broker):
        """get_asset() returns tradable asset info for the symbol."""
        await broker.connect()
        asset = await broker.get_asset("AAPL")
        assert asset.symbol == "AAPL"
        assert asset.tradable is True

    @pytest.mark.asyncio
    async def test_validate_order_valid(self, broker):
        """A well-formed market order produces no validation errors."""
        await broker.connect()
        request = OrderRequest.market("AAPL", OrderSide.BUY, 100)
        errors = await broker.validate_order(request)
        assert len(errors) == 0

    @pytest.mark.asyncio
    async def test_validate_order_zero_quantity(self, broker):
        """A zero quantity is reported as an error."""
        await broker.connect()
        request = OrderRequest(
            symbol="AAPL",
            side=OrderSide.BUY,
            quantity=Decimal("0"),
            order_type=OrderType.MARKET,
        )
        errors = await broker.validate_order(request)
        assert "Quantity must be positive" in errors

    @pytest.mark.asyncio
    async def test_validate_order_non_tradable(self, broker):
        """Orders for a non-tradable asset are reported as errors."""
        await broker.connect()
        broker.add_asset(
            AssetInfo(
                symbol="DELISTED",
                name="Delisted Stock",
                tradable=False,
            )
        )
        request = OrderRequest.market("DELISTED", OrderSide.BUY, 100)
        errors = await broker.validate_order(request)
        assert any("not currently tradable" in e for e in errors)

    @pytest.mark.asyncio
    async def test_validate_limit_order_no_price(self, broker):
        """validate_order() flags a limit order whose limit_price is missing."""
        await broker.connect()
        # Build a valid limit order, then clear its price. This bypasses the
        # constructor's own check without hand-copying every OrderRequest
        # field, which silently goes stale whenever a field is added.
        request = OrderRequest.limit("AAPL", OrderSide.BUY, 100, 150.00)
        request.limit_price = None
        errors = await broker.validate_order(request)
        assert any("Limit price must be positive" in e for e in errors)

    def test_broker_repr(self, broker):
        """repr() names the broker and its trading mode."""
        repr_str = repr(broker)
        assert "MockBroker" in repr_str
        assert "paper_trading=True" in repr_str
# =============================================================================
# Integration Tests
# =============================================================================
class TestBrokerWorkflow:
    """Integration tests for complete trading workflows against MockBroker."""

    @pytest.fixture
    def broker(self):
        """Create a disconnected paper-trading mock broker."""
        return MockBroker(paper_trading=True)

    @pytest.mark.asyncio
    async def test_full_trade_workflow(self, broker):
        """Connect, buy, fill, hold, close and disconnect in sequence."""
        # 1. Connect
        await broker.connect()
        assert broker.is_connected
        # 2. Check account
        account = await broker.get_account()
        assert account.buying_power > 0
        # 3. Get quote
        quote = await broker.get_quote("AAPL")
        assert quote.last_price is not None
        # 4. Submit order
        request = OrderRequest.market("AAPL", OrderSide.BUY, 100)
        order = await broker.submit_order(request)
        assert order.status == OrderStatus.NEW
        # 5. Order gets filled (simulated through the mock helper)
        broker.fill_order(order.broker_order_id, Decimal("150.00"))
        # 6. Check filled order
        filled_order = await broker.get_order(order.broker_order_id)
        assert filled_order.is_filled
        assert filled_order.filled_avg_price == Decimal("150.00")
        # 7. Add position (simulates the broker updating positions after the fill)
        broker.add_position(
            Position(
                symbol="AAPL",
                quantity=Decimal("100"),
                side=PositionSide.LONG,
                avg_entry_price=Decimal("150.00"),
                current_price=Decimal("155.00"),
                market_value=Decimal("15500.00"),
                cost_basis=Decimal("15000.00"),
                unrealized_pnl=Decimal("500.00"),
                unrealized_pnl_percent=Decimal("3.33"),
            )
        )
        # 8. Check positions
        positions = await broker.get_positions()
        assert len(positions) == 1
        assert positions[0].symbol == "AAPL"
        # 9. Close position
        close_order = await broker.close_position("AAPL")
        assert close_order.side == OrderSide.SELL
        assert close_order.quantity == Decimal("100")
        # 10. Disconnect
        await broker.disconnect()
        assert not broker.is_connected

    @pytest.mark.asyncio
    async def test_order_modification_workflow(self, broker):
        """Replacing a limit order yields a new order and retires the original."""
        await broker.connect()
        request = OrderRequest.limit("AAPL", OrderSide.BUY, 100, 145.00)
        order = await broker.submit_order(request)
        new_order = await broker.replace_order(
            order.broker_order_id,
            limit_price=Decimal("147.50"),
        )
        assert new_order.limit_price == Decimal("147.50")
        # Fields not passed to replace_order carry over, under a new broker id.
        assert new_order.quantity == Decimal("100")
        assert new_order.broker_order_id != order.broker_order_id
        # The original order should now be marked as replaced.
        original = await broker.get_order(order.broker_order_id)
        assert original.status == OrderStatus.REPLACED

    @pytest.mark.asyncio
    async def test_risk_management_workflow(self, broker):
        """Manually trail a protective stop: cancel it and submit a higher one."""
        await broker.connect()
        # Entry order
        buy_request = OrderRequest.market("AAPL", OrderSide.BUY, 100)
        buy_order = await broker.submit_order(buy_request)
        # Initial stop loss
        stop_request = OrderRequest.stop(
            symbol="AAPL",
            side=OrderSide.SELL,
            quantity=100,
            stop_price=140.00,
        )
        stop_order = await broker.submit_order(stop_request)
        assert stop_order.order_type == OrderType.STOP
        assert stop_order.stop_price == Decimal("140.00")
        # Price moved up: cancel the old stop and confirm it is no longer working.
        cancelled = await broker.cancel_order(stop_order.broker_order_id)
        assert cancelled.status == OrderStatus.CANCELLED
        # Submit a new stop at a higher price (trailing manually)
        new_stop_request = OrderRequest.stop(
            symbol="AAPL",
            side=OrderSide.SELL,
            quantity=100,
            stop_price=145.00,
        )
        new_stop = await broker.submit_order(new_stop_request)
        assert new_stop.stop_price == Decimal("145.00")
        # Only the entry order and the replacement stop should remain working.
        open_orders = await broker.get_orders(status=OrderStatus.NEW)
        open_ids = {o.broker_order_id for o in open_orders}
        assert open_ids == {buy_order.broker_order_id, new_stop.broker_order_id}