diff --git a/tests/unit/execution/__init__.py b/tests/unit/execution/__init__.py new file mode 100644 index 00000000..2990ea39 --- /dev/null +++ b/tests/unit/execution/__init__.py @@ -0,0 +1 @@ +"""Tests for execution module.""" diff --git a/tests/unit/execution/test_broker_base.py b/tests/unit/execution/test_broker_base.py new file mode 100644 index 00000000..9d3327ae --- /dev/null +++ b/tests/unit/execution/test_broker_base.py @@ -0,0 +1,1277 @@ +"""Tests for Broker Base Interface module. + +Issue #22: [EXEC-21] Broker base interface - abstract broker class +""" + +import pytest +from datetime import datetime +from decimal import Decimal +from typing import Dict, List, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +from tradingagents.execution import ( + # Enums + AssetClass, + OrderSide, + OrderType, + TimeInForce, + OrderStatus, + PositionSide, + # Data Classes + OrderRequest, + Order, + Position, + AccountInfo, + Quote, + AssetInfo, + # Exceptions + BrokerError, + ConnectionError, + AuthenticationError, + OrderError, + InsufficientFundsError, + InvalidOrderError, + PositionError, + RateLimitError, + # Abstract Base Class + BrokerBase, +) + + +# ============================================================================= +# Mock Broker Implementation for Testing +# ============================================================================= + + +class MockBroker(BrokerBase): + """Mock broker implementation for testing abstract base class.""" + + def __init__(self, **kwargs): + super().__init__( + name="MockBroker", + supported_asset_classes=[AssetClass.EQUITY, AssetClass.ETF, AssetClass.CRYPTO], + **kwargs, + ) + self._orders: Dict[str, Order] = {} + self._positions: Dict[str, Position] = {} + self._account: Optional[AccountInfo] = None + self._quotes: Dict[str, Quote] = {} + self._assets: Dict[str, AssetInfo] = {} + + async def connect(self) -> bool: + self._connected = True + self._account = AccountInfo( + account_id="TEST123", + account_type="margin", + status="active", + currency="USD", + cash=Decimal("100000"), + portfolio_value=Decimal("150000"), + buying_power=Decimal("200000"), + equity=Decimal("150000"), + ) + return True + + async def disconnect(self) -> None: + self._connected = False + + async def is_market_open(self) -> bool: + return True + + async def get_account(self) -> AccountInfo: + if not self._connected: + raise ConnectionError("Not connected") + return self._account + + async def submit_order(self, request: OrderRequest) -> Order: + if not self._connected: + raise ConnectionError("Not connected") + + order = Order( + broker_order_id=f"ORD-{len(self._orders) + 1}", + client_order_id=request.client_order_id, + symbol=request.symbol, + side=request.side, + quantity=request.quantity, + order_type=request.order_type, + status=OrderStatus.NEW, + limit_price=request.limit_price, + stop_price=request.stop_price, + time_in_force=request.time_in_force, + created_at=datetime.now(), + submitted_at=datetime.now(), + ) + self._orders[order.broker_order_id] = order + return order + + async def cancel_order(self, order_id: str) -> Order: + if order_id not in self._orders: + raise OrderError(f"Order {order_id} not found") + order = self._orders[order_id] + order.status = OrderStatus.CANCELLED + order.cancelled_at = datetime.now() + return order + + async def replace_order( + self, + order_id: str, + quantity: Optional[Decimal] = None, + limit_price: Optional[Decimal] = None, + stop_price: Optional[Decimal] = None, + time_in_force: Optional[TimeInForce] = None, + ) -> Order: + if order_id not in self._orders: + raise OrderError(f"Order {order_id} not found") + + old_order = self._orders[order_id] + old_order.status = OrderStatus.REPLACED + + new_order = Order( + broker_order_id=f"ORD-{len(self._orders) + 1}", + client_order_id=old_order.client_order_id, + symbol=old_order.symbol, + side=old_order.side, + quantity=quantity or old_order.quantity, + order_type=old_order.order_type, + status=OrderStatus.NEW, + limit_price=limit_price or old_order.limit_price, + stop_price=stop_price or old_order.stop_price, + time_in_force=time_in_force or old_order.time_in_force, + created_at=datetime.now(), + submitted_at=datetime.now(), + ) + self._orders[new_order.broker_order_id] = new_order + return new_order + + async def get_order(self, order_id: str) -> Order: + if order_id not in self._orders: + raise OrderError(f"Order {order_id} not found") + return self._orders[order_id] + + async def get_orders( + self, + status: Optional[OrderStatus] = None, + limit: int = 100, + symbols: Optional[List[str]] = None, + ) -> List[Order]: + orders = list(self._orders.values()) + if status: + orders = [o for o in orders if o.status == status] + if symbols: + orders = [o for o in orders if o.symbol in symbols] + return orders[:limit] + + async def get_positions(self) -> List[Position]: + return list(self._positions.values()) + + async def get_position(self, symbol: str) -> Optional[Position]: + return self._positions.get(symbol) + + async def get_quote(self, symbol: str) -> Quote: + if symbol in self._quotes: + return self._quotes[symbol] + # Return default quote + return Quote( + symbol=symbol, + bid_price=Decimal("100.00"), + bid_size=Decimal("100"), + ask_price=Decimal("100.05"), + ask_size=Decimal("200"), + last_price=Decimal("100.02"), + volume=1000000, + timestamp=datetime.now(), + ) + + async def get_asset(self, symbol: str) -> AssetInfo: + if symbol in self._assets: + return self._assets[symbol] + # Return default asset + return AssetInfo( + symbol=symbol, + name=f"{symbol} Inc.", + asset_class=AssetClass.EQUITY, + tradable=True, + marginable=True, + shortable=True, + ) + + # Helper methods for testing + def add_position(self, position: Position) -> None: + self._positions[position.symbol] = position + + def add_quote(self, quote: Quote) -> None: + self._quotes[quote.symbol] = quote + + def add_asset(self, asset: AssetInfo) -> None: + self._assets[asset.symbol] = asset + + def fill_order(self, order_id: str, avg_price: Decimal) -> Order: + if order_id in self._orders: + order = self._orders[order_id] + order.status = OrderStatus.FILLED + order.filled_quantity = order.quantity + order.filled_avg_price = avg_price + order.filled_at = datetime.now() + return self._orders.get(order_id) + + +# ============================================================================= +# Enum Tests +# ============================================================================= + + +class TestAssetClass: + """Tests for AssetClass enum.""" + + def test_all_asset_classes(self): + """Test all asset classes are defined.""" + assert AssetClass.EQUITY.value == "equity" + assert AssetClass.ETF.value == "etf" + assert AssetClass.OPTION.value == "option" + assert AssetClass.FUTURE.value == "future" + assert AssetClass.CRYPTO.value == "crypto" + assert AssetClass.FOREX.value == "forex" + assert AssetClass.BOND.value == "bond" + assert AssetClass.INDEX.value == "index" + + +class TestOrderSide: + """Tests for OrderSide enum.""" + + def test_order_sides(self): + """Test order sides are defined.""" + assert OrderSide.BUY.value == "buy" + assert OrderSide.SELL.value == "sell" + + +class TestOrderType: + """Tests for OrderType enum.""" + + def test_order_types(self): + """Test all order types are defined.""" + assert OrderType.MARKET.value == "market" + assert OrderType.LIMIT.value == "limit" + assert OrderType.STOP.value == "stop" + assert OrderType.STOP_LIMIT.value == "stop_limit" + assert OrderType.TRAILING_STOP.value == "trailing_stop" + + +class TestTimeInForce: + """Tests for TimeInForce enum.""" + + def test_time_in_force_values(self): + """Test all time in force values are defined.""" + assert TimeInForce.DAY.value == "day" + assert TimeInForce.GTC.value == "gtc" + assert TimeInForce.IOC.value == "ioc" + assert TimeInForce.FOK.value == "fok" + assert TimeInForce.OPG.value == "opg" + assert TimeInForce.CLS.value == "cls" + assert TimeInForce.GTD.value == "gtd" + + +class TestOrderStatus: + """Tests for OrderStatus enum.""" + + def test_order_status_values(self): + """Test all order status values are defined.""" + assert OrderStatus.PENDING_NEW.value == "pending_new" + assert OrderStatus.NEW.value == "new" + assert OrderStatus.PARTIALLY_FILLED.value == "partially_filled" + assert OrderStatus.FILLED.value == "filled" + assert OrderStatus.PENDING_CANCEL.value == "pending_cancel" + assert OrderStatus.CANCELLED.value == "cancelled" + assert OrderStatus.REJECTED.value == "rejected" + assert OrderStatus.EXPIRED.value == "expired" + assert OrderStatus.REPLACED.value == "replaced" + + +class TestPositionSide: + """Tests for PositionSide enum.""" + + def test_position_sides(self): + """Test position sides are defined.""" + assert PositionSide.LONG.value == "long" + assert PositionSide.SHORT.value == "short" + + +# ============================================================================= +# OrderRequest Tests +# ============================================================================= + + +class TestOrderRequest: + """Tests for OrderRequest dataclass.""" + + def test_create_market_order(self): + """Test creating a market order.""" + request = OrderRequest.market("AAPL", OrderSide.BUY, 100) + + assert request.symbol == "AAPL" + assert request.side == OrderSide.BUY + assert request.quantity == Decimal("100") + assert request.order_type == OrderType.MARKET + assert request.time_in_force == TimeInForce.DAY + assert request.client_order_id is not None + + def test_create_limit_order(self): + """Test creating a limit order.""" + request = OrderRequest.limit( + symbol="AAPL", + side=OrderSide.BUY, + quantity=100, + limit_price=150.00, + ) + + assert request.symbol == "AAPL" + assert request.side == OrderSide.BUY + assert request.quantity == Decimal("100") + assert request.order_type == OrderType.LIMIT + assert request.limit_price == Decimal("150.00") + assert request.time_in_force == TimeInForce.GTC + + def test_create_stop_order(self): + """Test creating a stop order.""" + request = OrderRequest.stop( + symbol="AAPL", + side=OrderSide.SELL, + quantity=100, + stop_price=145.00, + ) + + assert request.order_type == OrderType.STOP + assert request.stop_price == Decimal("145.00") + + def test_create_stop_limit_order(self): + """Test creating a stop-limit order.""" + request = OrderRequest.stop_limit( + symbol="AAPL", + side=OrderSide.SELL, + quantity=100, + stop_price=145.00, + limit_price=144.50, + ) + + assert request.order_type == OrderType.STOP_LIMIT + assert request.stop_price == Decimal("145.00") + assert request.limit_price == Decimal("144.50") + + def test_create_trailing_stop_percent(self): + """Test creating a trailing stop order with percent.""" + request = OrderRequest.trailing_stop( + symbol="AAPL", + side=OrderSide.SELL, + quantity=100, + trail_percent=5.0, + ) + + assert request.order_type == OrderType.TRAILING_STOP + assert request.trail_percent == Decimal("5.0") + assert request.trail_amount is None + + def test_create_trailing_stop_amount(self): + """Test creating a trailing stop order with amount.""" + request = OrderRequest.trailing_stop( + symbol="AAPL", + side=OrderSide.SELL, + quantity=100, + trail_amount=5.00, + ) + + assert request.order_type == OrderType.TRAILING_STOP + assert request.trail_amount == Decimal("5.00") + assert request.trail_percent is None + + def test_limit_order_requires_limit_price(self): + """Test that limit orders require limit price.""" + with pytest.raises(ValueError, match="Limit orders require limit_price"): + OrderRequest( + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("100"), + order_type=OrderType.LIMIT, + ) + + def test_stop_order_requires_stop_price(self): + """Test that stop orders require stop price.""" + with pytest.raises(ValueError, match="Stop orders require stop_price"): + OrderRequest( + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("100"), + order_type=OrderType.STOP, + ) + + def test_stop_limit_requires_both_prices(self): + """Test that stop-limit orders require both prices.""" + with pytest.raises(ValueError, match="Stop-limit orders require both"): + OrderRequest( + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("100"), + order_type=OrderType.STOP_LIMIT, + limit_price=Decimal("150"), + # Missing stop_price + ) + + def test_trailing_stop_requires_trail_value(self): + """Test that trailing stop orders require trail value.""" + with pytest.raises(ValueError, match="Trailing stop orders require"): + OrderRequest( + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("100"), + order_type=OrderType.TRAILING_STOP, + ) + + def test_order_with_metadata(self): + """Test order with custom metadata.""" + request = OrderRequest.market( + "AAPL", + OrderSide.BUY, + 100, + metadata={"strategy": "momentum", "signal_strength": 0.85}, + ) + + assert request.metadata["strategy"] == "momentum" + assert request.metadata["signal_strength"] == 0.85 + + +# ============================================================================= +# Order Tests +# ============================================================================= + + +class TestOrder: + """Tests for Order dataclass.""" + + def test_order_is_open(self): + """Test is_open property.""" + order = Order( + broker_order_id="ORD-1", + client_order_id="CLT-1", + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("100"), + order_type=OrderType.MARKET, + status=OrderStatus.NEW, + ) + assert order.is_open is True + + order.status = OrderStatus.FILLED + assert order.is_open is False + + def test_order_is_filled(self): + """Test is_filled property.""" + order = Order( + broker_order_id="ORD-1", + client_order_id="CLT-1", + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("100"), + order_type=OrderType.MARKET, + status=OrderStatus.FILLED, + ) + assert order.is_filled is True + + def test_order_is_cancelled(self): + """Test is_cancelled property.""" + order = Order( + broker_order_id="ORD-1", + client_order_id="CLT-1", + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("100"), + order_type=OrderType.MARKET, + status=OrderStatus.CANCELLED, + ) + assert order.is_cancelled is True + + def test_remaining_quantity(self): + """Test remaining_quantity property.""" + order = Order( + broker_order_id="ORD-1", + client_order_id="CLT-1", + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("100"), + order_type=OrderType.MARKET, + status=OrderStatus.PARTIALLY_FILLED, + filled_quantity=Decimal("60"), + ) + assert order.remaining_quantity == Decimal("40") + + def test_fill_percent(self): + """Test fill_percent property.""" + order = Order( + broker_order_id="ORD-1", + client_order_id="CLT-1", + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("100"), + order_type=OrderType.MARKET, + status=OrderStatus.PARTIALLY_FILLED, + filled_quantity=Decimal("60"), + ) + assert order.fill_percent == 60.0 + + +# ============================================================================= +# Position Tests +# ============================================================================= + + +class TestPosition: + """Tests for Position dataclass.""" + + def test_position_long(self): + """Test long position properties.""" + position = Position( + symbol="AAPL", + quantity=Decimal("100"), + side=PositionSide.LONG, + avg_entry_price=Decimal("150.00"), + current_price=Decimal("160.00"), + market_value=Decimal("16000.00"), + cost_basis=Decimal("15000.00"), + unrealized_pnl=Decimal("1000.00"), + unrealized_pnl_percent=Decimal("6.67"), + ) + + assert position.is_long is True + assert position.is_short is False + assert position.abs_quantity == Decimal("100") + + def test_position_short(self): + """Test short position properties.""" + position = Position( + symbol="AAPL", + quantity=Decimal("-100"), + side=PositionSide.SHORT, + avg_entry_price=Decimal("150.00"), + current_price=Decimal("140.00"), + market_value=Decimal("-14000.00"), + cost_basis=Decimal("-15000.00"), + unrealized_pnl=Decimal("1000.00"), + unrealized_pnl_percent=Decimal("6.67"), + ) + + assert position.is_long is False + assert position.is_short is True + assert position.abs_quantity == Decimal("100") + + +# ============================================================================= +# AccountInfo Tests +# ============================================================================= + + +class TestAccountInfo: + """Tests for AccountInfo dataclass.""" + + def test_account_is_active(self): + """Test is_active property.""" + account = AccountInfo( + account_id="TEST123", + account_type="margin", + status="active", + ) + assert account.is_active is True + + account.status = "inactive" + assert account.is_active is False + + def test_account_defaults(self): + """Test account default values.""" + account = AccountInfo( + account_id="TEST123", + account_type="cash", + status="active", + ) + + assert account.currency == "USD" + assert account.cash == Decimal("0") + assert account.portfolio_value == Decimal("0") + assert account.is_pattern_day_trader is False + + +# ============================================================================= +# Quote Tests +# ============================================================================= + + +class TestQuote: + """Tests for Quote dataclass.""" + + def test_quote_mid_price(self): + """Test mid_price calculation.""" + quote = Quote( + symbol="AAPL", + bid_price=Decimal("100.00"), + ask_price=Decimal("100.10"), + ) + assert quote.mid_price == Decimal("100.05") + + def test_quote_mid_price_fallback(self): + """Test mid_price fallback to last_price.""" + quote = Quote( + symbol="AAPL", + last_price=Decimal("100.05"), + ) + assert quote.mid_price == Decimal("100.05") + + def test_quote_spread(self): + """Test spread calculation.""" + quote = Quote( + symbol="AAPL", + bid_price=Decimal("100.00"), + ask_price=Decimal("100.10"), + ) + assert quote.spread == Decimal("0.10") + + def test_quote_spread_percent(self): + """Test spread_percent calculation.""" + quote = Quote( + symbol="AAPL", + bid_price=Decimal("100.00"), + ask_price=Decimal("100.10"), + ) + # spread / mid_price * 100 = 0.10 / 100.05 * 100 ≈ 0.0999% + assert quote.spread_percent is not None + assert abs(quote.spread_percent - 0.0999) < 0.001 + + +# ============================================================================= +# AssetInfo Tests +# ============================================================================= + + +class TestAssetInfo: + """Tests for AssetInfo dataclass.""" + + def test_asset_defaults(self): + """Test asset default values.""" + asset = AssetInfo( + symbol="AAPL", + name="Apple Inc.", + ) + + assert asset.asset_class == AssetClass.EQUITY + assert asset.tradable is True + assert asset.marginable is True + assert asset.shortable is True + assert asset.fractionable is False + + +# ============================================================================= +# Exception Tests +# ============================================================================= + + +class TestBrokerExceptions: + """Tests for broker exceptions.""" + + def test_broker_error(self): + """Test BrokerError exception.""" + error = BrokerError("Something went wrong", code="ERR001", details={"key": "value"}) + + assert str(error) == "Something went wrong" + assert error.code == "ERR001" + assert error.details == {"key": "value"} + + def test_connection_error(self): + """Test ConnectionError exception.""" + error = ConnectionError("Failed to connect") + assert isinstance(error, BrokerError) + + def test_authentication_error(self): + """Test AuthenticationError exception.""" + error = AuthenticationError("Invalid credentials") + assert isinstance(error, BrokerError) + + def test_order_error(self): + """Test OrderError exception.""" + error = OrderError("Order failed") + assert isinstance(error, BrokerError) + + def test_insufficient_funds_error(self): + """Test InsufficientFundsError exception.""" + error = InsufficientFundsError("Not enough buying power") + assert isinstance(error, OrderError) + + def test_invalid_order_error(self): + """Test InvalidOrderError exception.""" + error = InvalidOrderError("Invalid order parameters") + assert isinstance(error, OrderError) + + def test_position_error(self): + """Test PositionError exception.""" + error = PositionError("Position not found") + assert isinstance(error, BrokerError) + + def test_rate_limit_error(self): + """Test RateLimitError exception.""" + error = RateLimitError("Rate limit exceeded", retry_after=60.0) + + assert isinstance(error, BrokerError) + assert error.retry_after == 60.0 + + +# ============================================================================= +# BrokerBase Tests +# ============================================================================= + + +class TestBrokerBase: + """Tests for BrokerBase abstract class.""" + + @pytest.fixture + def broker(self): + """Create a mock broker instance.""" + return MockBroker(paper_trading=True) + + def test_broker_properties(self, broker): + """Test broker properties.""" + assert broker.name == "MockBroker" + assert broker.is_paper_trading is True + assert broker.is_connected is False + + def test_supported_asset_classes(self, broker): + """Test supported asset classes.""" + assert AssetClass.EQUITY in broker.supported_asset_classes + assert AssetClass.ETF in broker.supported_asset_classes + assert AssetClass.CRYPTO in broker.supported_asset_classes + assert broker.supports_asset_class(AssetClass.EQUITY) is True + assert broker.supports_asset_class(AssetClass.FOREX) is False + + @pytest.mark.asyncio + async def test_connect(self, broker): + """Test connect method.""" + result = await broker.connect() + assert result is True + assert broker.is_connected is True + + @pytest.mark.asyncio + async def test_disconnect(self, broker): + """Test disconnect method.""" + await broker.connect() + await broker.disconnect() + assert broker.is_connected is False + + @pytest.mark.asyncio + async def test_get_account(self, broker): + """Test get_account method.""" + await broker.connect() + account = await broker.get_account() + + assert account.account_id == "TEST123" + assert account.cash == Decimal("100000") + assert account.is_active is True + + @pytest.mark.asyncio + async def test_submit_market_order(self, broker): + """Test submitting a market order.""" + await broker.connect() + + request = OrderRequest.market("AAPL", OrderSide.BUY, 100) + order = await broker.submit_order(request) + + assert order.symbol == "AAPL" + assert order.side == OrderSide.BUY + assert order.quantity == Decimal("100") + assert order.status == OrderStatus.NEW + + @pytest.mark.asyncio + async def test_submit_limit_order(self, broker): + """Test submitting a limit order.""" + await broker.connect() + + request = OrderRequest.limit("AAPL", OrderSide.BUY, 100, 150.00) + order = await broker.submit_order(request) + + assert order.order_type == OrderType.LIMIT + assert order.limit_price == Decimal("150.00") + + @pytest.mark.asyncio + async def test_cancel_order(self, broker): + """Test cancelling an order.""" + await broker.connect() + + request = OrderRequest.market("AAPL", OrderSide.BUY, 100) + order = await broker.submit_order(request) + + cancelled = await broker.cancel_order(order.broker_order_id) + assert cancelled.status == OrderStatus.CANCELLED + assert cancelled.cancelled_at is not None + + @pytest.mark.asyncio + async def test_cancel_nonexistent_order(self, broker): + """Test cancelling a non-existent order.""" + await broker.connect() + + with pytest.raises(OrderError, match="not found"): + await broker.cancel_order("NONEXISTENT") + + @pytest.mark.asyncio + async def test_replace_order(self, broker): + """Test replacing an order.""" + await broker.connect() + + request = OrderRequest.limit("AAPL", OrderSide.BUY, 100, 150.00) + order = await broker.submit_order(request) + + new_order = await broker.replace_order( + order.broker_order_id, + quantity=Decimal("200"), + limit_price=Decimal("155.00"), + ) + + assert new_order.quantity == Decimal("200") + assert new_order.limit_price == Decimal("155.00") + assert new_order.status == OrderStatus.NEW + + @pytest.mark.asyncio + async def test_get_order(self, broker): + """Test getting an order.""" + await broker.connect() + + request = OrderRequest.market("AAPL", OrderSide.BUY, 100) + submitted = await broker.submit_order(request) + + order = await broker.get_order(submitted.broker_order_id) + assert order.broker_order_id == submitted.broker_order_id + + @pytest.mark.asyncio + async def test_get_orders_filter_by_status(self, broker): + """Test getting orders filtered by status.""" + await broker.connect() + + # Submit some orders + await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, 100)) + order2 = await broker.submit_order(OrderRequest.market("GOOGL", OrderSide.BUY, 50)) + await broker.cancel_order(order2.broker_order_id) + + # Get only NEW orders + orders = await broker.get_orders(status=OrderStatus.NEW) + assert len(orders) == 1 + assert orders[0].symbol == "AAPL" + + @pytest.mark.asyncio + async def test_get_orders_filter_by_symbols(self, broker): + """Test getting orders filtered by symbols.""" + await broker.connect() + + await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, 100)) + await broker.submit_order(OrderRequest.market("GOOGL", OrderSide.BUY, 50)) + await broker.submit_order(OrderRequest.market("MSFT", OrderSide.BUY, 75)) + + orders = await broker.get_orders(symbols=["AAPL", "MSFT"]) + assert len(orders) == 2 + symbols = {o.symbol for o in orders} + assert symbols == {"AAPL", "MSFT"} + + @pytest.mark.asyncio + async def test_cancel_all_orders(self, broker): + """Test cancelling all orders.""" + await broker.connect() + + await broker.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, 100)) + await broker.submit_order(OrderRequest.market("GOOGL", OrderSide.BUY, 50)) + await broker.submit_order(OrderRequest.market("MSFT", OrderSide.BUY, 75)) + + cancelled = await broker.cancel_all_orders() + assert len(cancelled) == 3 + + # All orders should be cancelled + orders = await broker.get_orders(status=OrderStatus.NEW) + assert len(orders) == 0 + + @pytest.mark.asyncio + async def test_get_positions(self, broker): + """Test getting positions.""" + await broker.connect() + + # Add a test position + broker.add_position( + Position( + symbol="AAPL", + quantity=Decimal("100"), + side=PositionSide.LONG, + avg_entry_price=Decimal("150.00"), + current_price=Decimal("160.00"), + market_value=Decimal("16000.00"), + cost_basis=Decimal("15000.00"), + unrealized_pnl=Decimal("1000.00"), + unrealized_pnl_percent=Decimal("6.67"), + ) + ) + + positions = await broker.get_positions() + assert len(positions) == 1 + assert positions[0].symbol == "AAPL" + + @pytest.mark.asyncio + async def test_get_position(self, broker): + """Test getting a specific position.""" + await broker.connect() + + broker.add_position( + Position( + symbol="AAPL", + quantity=Decimal("100"), + side=PositionSide.LONG, + avg_entry_price=Decimal("150.00"), + current_price=Decimal("160.00"), + market_value=Decimal("16000.00"), + cost_basis=Decimal("15000.00"), + unrealized_pnl=Decimal("1000.00"), + unrealized_pnl_percent=Decimal("6.67"), + ) + ) + + position = await broker.get_position("AAPL") + assert position is not None + assert position.symbol == "AAPL" + + position = await broker.get_position("NONEXISTENT") + assert position is None + + @pytest.mark.asyncio + async def test_close_position(self, broker): + """Test closing a position.""" + await broker.connect() + + broker.add_position( + Position( + symbol="AAPL", + quantity=Decimal("100"), + side=PositionSide.LONG, + avg_entry_price=Decimal("150.00"), + current_price=Decimal("160.00"), + market_value=Decimal("16000.00"), + cost_basis=Decimal("15000.00"), + unrealized_pnl=Decimal("1000.00"), + unrealized_pnl_percent=Decimal("6.67"), + ) + ) + + order = await broker.close_position("AAPL") + assert order.symbol == "AAPL" + assert order.side == OrderSide.SELL # Closing long = sell + assert order.quantity == Decimal("100") + + @pytest.mark.asyncio + async def test_close_position_partial(self, broker): + """Test partially closing a position.""" + await broker.connect() + + broker.add_position( + Position( + symbol="AAPL", + quantity=Decimal("100"), + side=PositionSide.LONG, + avg_entry_price=Decimal("150.00"), + current_price=Decimal("160.00"), + market_value=Decimal("16000.00"), + cost_basis=Decimal("15000.00"), + unrealized_pnl=Decimal("1000.00"), + unrealized_pnl_percent=Decimal("6.67"), + ) + ) + + order = await broker.close_position("AAPL", quantity=Decimal("50")) + assert order.quantity == Decimal("50") + + @pytest.mark.asyncio + async def test_close_nonexistent_position(self, broker): + """Test closing a non-existent position.""" + await broker.connect() + + with pytest.raises(PositionError, match="No position found"): + await broker.close_position("NONEXISTENT") + + @pytest.mark.asyncio + async def test_close_short_position(self, broker): + """Test closing a short position.""" + await broker.connect() + + broker.add_position( + Position( + symbol="AAPL", + quantity=Decimal("-100"), + side=PositionSide.SHORT, + avg_entry_price=Decimal("150.00"), + current_price=Decimal("140.00"), + market_value=Decimal("-14000.00"), + cost_basis=Decimal("-15000.00"), + unrealized_pnl=Decimal("1000.00"), + unrealized_pnl_percent=Decimal("6.67"), + ) + ) + + order = await broker.close_position("AAPL") + assert order.side == OrderSide.BUY # Closing short = buy + + @pytest.mark.asyncio + async def test_close_all_positions(self, broker): + """Test closing all positions.""" + await broker.connect() + + broker.add_position( + Position( + symbol="AAPL", + quantity=Decimal("100"), + side=PositionSide.LONG, + avg_entry_price=Decimal("150.00"), + current_price=Decimal("160.00"), + market_value=Decimal("16000.00"), + cost_basis=Decimal("15000.00"), + unrealized_pnl=Decimal("1000.00"), + unrealized_pnl_percent=Decimal("6.67"), + ) + ) + broker.add_position( + Position( + symbol="GOOGL", + quantity=Decimal("50"), + side=PositionSide.LONG, + avg_entry_price=Decimal("2800.00"), + current_price=Decimal("2900.00"), + market_value=Decimal("145000.00"), + cost_basis=Decimal("140000.00"), + unrealized_pnl=Decimal("5000.00"), + unrealized_pnl_percent=Decimal("3.57"), + ) + ) + + orders = await broker.close_all_positions() + assert len(orders) == 2 + + @pytest.mark.asyncio + async def test_get_quote(self, broker): + """Test getting a quote.""" + await broker.connect() + + quote = await broker.get_quote("AAPL") + assert quote.symbol == "AAPL" + assert quote.bid_price is not None + assert quote.ask_price is not None + + @pytest.mark.asyncio + async def test_get_quotes(self, broker): + """Test getting multiple quotes.""" + await broker.connect() + + quotes = await broker.get_quotes(["AAPL", "GOOGL", "MSFT"]) + assert len(quotes) == 3 + assert "AAPL" in quotes + assert "GOOGL" in quotes + assert "MSFT" in quotes + + @pytest.mark.asyncio + async def test_get_asset(self, broker): + """Test getting asset information.""" + await broker.connect() + + asset = await broker.get_asset("AAPL") + assert asset.symbol == "AAPL" + assert asset.tradable is True + + @pytest.mark.asyncio + async def test_validate_order_valid(self, broker): + """Test validating a valid order.""" + await broker.connect() + + request = OrderRequest.market("AAPL", OrderSide.BUY, 100) + errors = await broker.validate_order(request) + assert len(errors) == 0 + + @pytest.mark.asyncio + async def test_validate_order_zero_quantity(self, broker): + """Test validating order with zero quantity.""" + await broker.connect() + + request = OrderRequest( + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("0"), + order_type=OrderType.MARKET, + ) + errors = await broker.validate_order(request) + assert "Quantity must be positive" in errors + + @pytest.mark.asyncio + async def test_validate_order_non_tradable(self, broker): + """Test validating order for non-tradable asset.""" + await broker.connect() + + # Add a non-tradable asset + broker.add_asset( + AssetInfo( + symbol="DELISTED", + name="Delisted Stock", + tradable=False, + ) + ) + + request = OrderRequest.market("DELISTED", OrderSide.BUY, 100) + errors = await broker.validate_order(request) + assert any("not currently tradable" in e for e in errors) + + @pytest.mark.asyncio + async def test_validate_limit_order_no_price(self, broker): + """Test validating limit order without price.""" + await broker.connect() + + # Manually construct invalid request (bypassing __post_init__) + request = OrderRequest.__new__(OrderRequest) + request.symbol = "AAPL" + request.side = OrderSide.BUY + request.quantity = Decimal("100") + request.order_type = OrderType.LIMIT + request.limit_price = None + request.stop_price = None + request.time_in_force = TimeInForce.GTC + request.client_order_id = "test" + request.extended_hours = False + request.trail_amount = None + request.trail_percent = None + request.take_profit_price = None + request.stop_loss_price = None + request.metadata = {} + + errors = await broker.validate_order(request) + assert any("Limit price must be positive" in e for e in errors) + + def test_broker_repr(self, broker): + """Test broker string representation.""" + repr_str = repr(broker) + assert "MockBroker" in repr_str + assert "paper_trading=True" in repr_str + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +class TestBrokerWorkflow: + """Integration tests for complete trading workflows.""" + + @pytest.fixture + def broker(self): + """Create a mock broker instance.""" + return MockBroker(paper_trading=True) + + @pytest.mark.asyncio + async def test_full_trade_workflow(self, broker): + """Test a complete trading workflow.""" + # 1. Connect + await broker.connect() + assert broker.is_connected + + # 2. Check account + account = await broker.get_account() + assert account.buying_power > 0 + + # 3. Get quote + quote = await broker.get_quote("AAPL") + assert quote.last_price is not None + + # 4. Submit order + request = OrderRequest.market("AAPL", OrderSide.BUY, 100) + order = await broker.submit_order(request) + assert order.status == OrderStatus.NEW + + # 5. Order gets filled (simulate) + broker.fill_order(order.broker_order_id, Decimal("150.00")) + + # 6. Check filled order + filled_order = await broker.get_order(order.broker_order_id) + assert filled_order.is_filled + assert filled_order.filled_avg_price == Decimal("150.00") + + # 7. Add position (simulate broker updating positions) + broker.add_position( + Position( + symbol="AAPL", + quantity=Decimal("100"), + side=PositionSide.LONG, + avg_entry_price=Decimal("150.00"), + current_price=Decimal("155.00"), + market_value=Decimal("15500.00"), + cost_basis=Decimal("15000.00"), + unrealized_pnl=Decimal("500.00"), + unrealized_pnl_percent=Decimal("3.33"), + ) + ) + + # 8. Check positions + positions = await broker.get_positions() + assert len(positions) == 1 + assert positions[0].symbol == "AAPL" + + # 9. Close position + close_order = await broker.close_position("AAPL") + assert close_order.side == OrderSide.SELL + assert close_order.quantity == Decimal("100") + + # 10. Disconnect + await broker.disconnect() + assert not broker.is_connected + + @pytest.mark.asyncio + async def test_order_modification_workflow(self, broker): + """Test order modification workflow.""" + await broker.connect() + + # Submit limit order + request = OrderRequest.limit("AAPL", OrderSide.BUY, 100, 145.00) + order = await broker.submit_order(request) + + # Modify the order + new_order = await broker.replace_order( + order.broker_order_id, + limit_price=Decimal("147.50"), + ) + + assert new_order.limit_price == Decimal("147.50") + + # Original order should be replaced + original = await broker.get_order(order.broker_order_id) + assert original.status == OrderStatus.REPLACED + + @pytest.mark.asyncio + async def test_risk_management_workflow(self, broker): + """Test risk management with stop orders.""" + await broker.connect() + + # Submit main order + buy_request = OrderRequest.market("AAPL", OrderSide.BUY, 100) + await broker.submit_order(buy_request) + + # Submit stop loss + stop_request = OrderRequest.stop( + symbol="AAPL", + side=OrderSide.SELL, + quantity=100, + stop_price=140.00, + ) + stop_order = await broker.submit_order(stop_request) + + assert stop_order.order_type == OrderType.STOP + assert stop_order.stop_price == Decimal("140.00") + + # Cancel stop if price moves up + await broker.cancel_order(stop_order.broker_order_id) + + # Submit new stop at higher price (trailing manually) + new_stop_request = OrderRequest.stop( + symbol="AAPL", + side=OrderSide.SELL, + quantity=100, + stop_price=145.00, + ) + new_stop = await broker.submit_order(new_stop_request) + + assert new_stop.stop_price == Decimal("145.00") diff --git a/tradingagents/execution/__init__.py b/tradingagents/execution/__init__.py new file mode 100644 index 00000000..c243db80 --- /dev/null +++ b/tradingagents/execution/__init__.py @@ -0,0 +1,115 @@ +"""Execution module for broker integrations and order management. + +This module provides a unified interface for interacting with various brokers +(Alpaca, IBKR, Paper) and managing order execution. + +Issue #22: [EXEC-21] Broker base interface - abstract broker class + +Submodules: + broker_base: Abstract base class for broker implementations + +Classes: + Enums: + - AssetClass: Supported asset classes (EQUITY, ETF, CRYPTO, etc.) + - OrderSide: Order side (BUY, SELL) + - OrderType: Order types (MARKET, LIMIT, STOP, etc.) + - TimeInForce: Order duration (DAY, GTC, IOC, etc.) + - OrderStatus: Order execution status + - PositionSide: Position side (LONG, SHORT) + + Data Classes: + - OrderRequest: Request to submit an order + - Order: Order information returned from broker + - Position: Current position in an asset + - AccountInfo: Broker account information + - Quote: Current quote/price data + - AssetInfo: Asset/instrument information + + Exceptions: + - BrokerError: Base exception for broker errors + - ConnectionError: Error connecting to broker + - AuthenticationError: Authentication failed + - OrderError: Error submitting or managing order + - InsufficientFundsError: Insufficient funds for order + - InvalidOrderError: Invalid order parameters + - PositionError: Error with position operations + - RateLimitError: Rate limit exceeded + + Abstract Base Class: + - BrokerBase: Abstract base class for broker implementations + +Example: + >>> from tradingagents.execution import ( + ... BrokerBase, + ... OrderRequest, + ... OrderSide, + ... OrderType, + ... ) + >>> + >>> # Create a market buy order + >>> order_request = OrderRequest.market("AAPL", OrderSide.BUY, 100) + >>> + >>> # Or with more options + >>> order_request = OrderRequest.limit( + ... symbol="AAPL", + ... side=OrderSide.BUY, + ... quantity=100, + ... limit_price=150.00, + ... ) +""" + +from .broker_base import ( + # Enums + AssetClass, + OrderSide, + OrderType, + TimeInForce, + OrderStatus, + PositionSide, + # Data Classes + OrderRequest, + Order, + Position, + AccountInfo, + Quote, + AssetInfo, + # Exceptions + BrokerError, + ConnectionError, + AuthenticationError, + OrderError, + InsufficientFundsError, + InvalidOrderError, + PositionError, + RateLimitError, + # Abstract Base Class + BrokerBase, +) + +__all__ = [ + # Enums + "AssetClass", + "OrderSide", + "OrderType", + "TimeInForce", + "OrderStatus", + "PositionSide", + # Data Classes + "OrderRequest", + "Order", + "Position", + "AccountInfo", + "Quote", + "AssetInfo", + # Exceptions + "BrokerError", + "ConnectionError", + "AuthenticationError", + "OrderError", + "InsufficientFundsError", + "InvalidOrderError", + "PositionError", + "RateLimitError", + # Abstract Base Class + "BrokerBase", +] diff --git a/tradingagents/execution/broker_base.py b/tradingagents/execution/broker_base.py new file mode 100644 index 00000000..a02a695d --- /dev/null +++ b/tradingagents/execution/broker_base.py @@ -0,0 +1,1030 @@ +"""Abstract Broker Base Interface. + +This module defines the abstract base class for all broker implementations. +Concrete broker implementations (Alpaca, IBKR, Paper) inherit from this class +and implement the abstract methods for their specific APIs. + +Issue #22: [EXEC-21] Broker base interface - abstract broker class + +Design Principles: + - Uniform interface across all brokers + - Async-first for I/O operations + - Type-safe with dataclasses + - Support for multiple asset classes + - Extensible for broker-specific features +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from decimal import Decimal +from enum import Enum +from typing import Any, Dict, List, Optional, Union +import uuid + + +class AssetClass(Enum): + """Supported asset classes.""" + EQUITY = "equity" # Stocks + ETF = "etf" # Exchange-traded funds + OPTION = "option" # Options contracts + FUTURE = "future" # Futures contracts + CRYPTO = "crypto" # Cryptocurrency + FOREX = "forex" # Foreign exchange + BOND = "bond" # Fixed income + INDEX = "index" # Market indices + + +class OrderSide(Enum): + """Order side (buy or sell).""" + BUY = "buy" + SELL = "sell" + + +class OrderType(Enum): + """Order type. + + MARKET: Execute at current market price + LIMIT: Execute at specified price or better + STOP: Trigger market order at stop price + STOP_LIMIT: Trigger limit order at stop price + TRAILING_STOP: Stop that trails price by specified amount/percent + """ + MARKET = "market" + LIMIT = "limit" + STOP = "stop" + STOP_LIMIT = "stop_limit" + TRAILING_STOP = "trailing_stop" + + +class TimeInForce(Enum): + """Time in force (order duration). + + DAY: Valid until end of regular trading hours + GTC: Good till cancelled + IOC: Immediate or cancel (partial fills allowed) + FOK: Fill or kill (all or nothing) + OPG: On open (execute at market open) + CLS: On close (execute at market close) + GTD: Good till date + """ + DAY = "day" + GTC = "gtc" + IOC = "ioc" + FOK = "fok" + OPG = "opg" + CLS = "cls" + GTD = "gtd" + + +class OrderStatus(Enum): + """Order execution status. + + PENDING_NEW: Order submitted, awaiting confirmation + NEW: Order accepted by broker + PARTIALLY_FILLED: Order partially executed + FILLED: Order fully executed + PENDING_CANCEL: Cancel request submitted + CANCELLED: Order cancelled + REJECTED: Order rejected by broker + EXPIRED: Order expired (time in force elapsed) + REPLACED: Order was replaced by new order + """ + PENDING_NEW = "pending_new" + NEW = "new" + PARTIALLY_FILLED = "partially_filled" + FILLED = "filled" + PENDING_CANCEL = "pending_cancel" + CANCELLED = "cancelled" + REJECTED = "rejected" + EXPIRED = "expired" + REPLACED = "replaced" + + +class PositionSide(Enum): + """Position side.""" + LONG = "long" + SHORT = "short" + + +@dataclass +class OrderRequest: + """Request to submit an order. + + Attributes: + symbol: Trading symbol + side: Buy or sell + quantity: Number of shares/contracts + order_type: Type of order + limit_price: Limit price (for limit/stop-limit orders) + stop_price: Stop price (for stop/stop-limit orders) + time_in_force: Order duration + client_order_id: Optional client-defined order ID + extended_hours: Allow extended hours trading + trail_amount: Trail amount for trailing stop (absolute) + trail_percent: Trail percent for trailing stop + take_profit_price: Take profit price (OCO orders) + stop_loss_price: Stop loss price (OCO orders) + metadata: Additional broker-specific metadata + """ + symbol: str + side: OrderSide + quantity: Decimal + order_type: OrderType = OrderType.MARKET + limit_price: Optional[Decimal] = None + stop_price: Optional[Decimal] = None + time_in_force: TimeInForce = TimeInForce.DAY + client_order_id: Optional[str] = None + extended_hours: bool = False + trail_amount: Optional[Decimal] = None + trail_percent: Optional[Decimal] = None + take_profit_price: Optional[Decimal] = None + stop_loss_price: Optional[Decimal] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + """Generate client order ID if not provided.""" + if self.client_order_id is None: + self.client_order_id = str(uuid.uuid4()) + + # Validate order type requirements + if self.order_type == OrderType.LIMIT and self.limit_price is None: + raise ValueError("Limit orders require limit_price") + if self.order_type == OrderType.STOP and self.stop_price is None: + raise ValueError("Stop orders require stop_price") + if self.order_type == OrderType.STOP_LIMIT: + if self.limit_price is None or self.stop_price is None: + raise ValueError("Stop-limit orders require both limit_price and stop_price") + if self.order_type == OrderType.TRAILING_STOP: + if self.trail_amount is None and self.trail_percent is None: + raise ValueError("Trailing stop orders require trail_amount or trail_percent") + + @classmethod + def market( + cls, + symbol: str, + side: OrderSide, + quantity: Union[Decimal, float, int], + time_in_force: TimeInForce = TimeInForce.DAY, + **kwargs, + ) -> "OrderRequest": + """Create a market order request.""" + return cls( + symbol=symbol, + side=side, + quantity=Decimal(str(quantity)), + order_type=OrderType.MARKET, + time_in_force=time_in_force, + **kwargs, + ) + + @classmethod + def limit( + cls, + symbol: str, + side: OrderSide, + quantity: Union[Decimal, float, int], + limit_price: Union[Decimal, float, int], + time_in_force: TimeInForce = TimeInForce.GTC, + **kwargs, + ) -> "OrderRequest": + """Create a limit order request.""" + return cls( + symbol=symbol, + side=side, + quantity=Decimal(str(quantity)), + order_type=OrderType.LIMIT, + limit_price=Decimal(str(limit_price)), + time_in_force=time_in_force, + **kwargs, + ) + + @classmethod + def stop( + cls, + symbol: str, + side: OrderSide, + quantity: Union[Decimal, float, int], + stop_price: Union[Decimal, float, int], + time_in_force: TimeInForce = TimeInForce.GTC, + **kwargs, + ) -> "OrderRequest": + """Create a stop order request.""" + return cls( + symbol=symbol, + side=side, + quantity=Decimal(str(quantity)), + order_type=OrderType.STOP, + stop_price=Decimal(str(stop_price)), + time_in_force=time_in_force, + **kwargs, + ) + + @classmethod + def stop_limit( + cls, + symbol: str, + side: OrderSide, + quantity: Union[Decimal, float, int], + stop_price: Union[Decimal, float, int], + limit_price: Union[Decimal, float, int], + time_in_force: TimeInForce = TimeInForce.GTC, + **kwargs, + ) -> "OrderRequest": + """Create a stop-limit order request.""" + return cls( + symbol=symbol, + side=side, + quantity=Decimal(str(quantity)), + order_type=OrderType.STOP_LIMIT, + stop_price=Decimal(str(stop_price)), + limit_price=Decimal(str(limit_price)), + time_in_force=time_in_force, + **kwargs, + ) + + @classmethod + def trailing_stop( + cls, + symbol: str, + side: OrderSide, + quantity: Union[Decimal, float, int], + trail_percent: Optional[Union[Decimal, float]] = None, + trail_amount: Optional[Union[Decimal, float]] = None, + time_in_force: TimeInForce = TimeInForce.GTC, + **kwargs, + ) -> "OrderRequest": + """Create a trailing stop order request.""" + return cls( + symbol=symbol, + side=side, + quantity=Decimal(str(quantity)), + order_type=OrderType.TRAILING_STOP, + trail_percent=Decimal(str(trail_percent)) if trail_percent else None, + trail_amount=Decimal(str(trail_amount)) if trail_amount else None, + time_in_force=time_in_force, + **kwargs, + ) + + +@dataclass +class Order: + """Order information returned from broker. + + Attributes: + broker_order_id: Broker-assigned order ID + client_order_id: Client-assigned order ID + symbol: Trading symbol + side: Buy or sell + quantity: Ordered quantity + order_type: Type of order + status: Current order status + limit_price: Limit price (if applicable) + stop_price: Stop price (if applicable) + time_in_force: Order duration + filled_quantity: Quantity filled so far + filled_avg_price: Average fill price + created_at: Order creation timestamp + updated_at: Last update timestamp + submitted_at: Submission timestamp + filled_at: Fill completion timestamp (if filled) + cancelled_at: Cancellation timestamp (if cancelled) + expired_at: Expiration timestamp (if expired) + extended_hours: Whether extended hours allowed + trail_amount: Trail amount (if trailing stop) + trail_percent: Trail percent (if trailing stop) + legs: Child orders (for bracket/OCO orders) + reject_reason: Reason for rejection (if rejected) + metadata: Additional broker-specific data + """ + broker_order_id: str + client_order_id: str + symbol: str + side: OrderSide + quantity: Decimal + order_type: OrderType + status: OrderStatus + limit_price: Optional[Decimal] = None + stop_price: Optional[Decimal] = None + time_in_force: TimeInForce = TimeInForce.DAY + filled_quantity: Decimal = field(default_factory=lambda: Decimal("0")) + filled_avg_price: Optional[Decimal] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + submitted_at: Optional[datetime] = None + filled_at: Optional[datetime] = None + cancelled_at: Optional[datetime] = None + expired_at: Optional[datetime] = None + extended_hours: bool = False + trail_amount: Optional[Decimal] = None + trail_percent: Optional[Decimal] = None + legs: List["Order"] = field(default_factory=list) + reject_reason: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def is_open(self) -> bool: + """Check if order is still open.""" + return self.status in ( + OrderStatus.PENDING_NEW, + OrderStatus.NEW, + OrderStatus.PARTIALLY_FILLED, + OrderStatus.PENDING_CANCEL, + ) + + @property + def is_filled(self) -> bool: + """Check if order is completely filled.""" + return self.status == OrderStatus.FILLED + + @property + def is_cancelled(self) -> bool: + """Check if order is cancelled.""" + return self.status == OrderStatus.CANCELLED + + @property + def remaining_quantity(self) -> Decimal: + """Calculate remaining unfilled quantity.""" + return self.quantity - self.filled_quantity + + @property + def fill_percent(self) -> float: + """Calculate fill percentage.""" + if self.quantity == 0: + return 0.0 + return float(self.filled_quantity / self.quantity * 100) + + +@dataclass +class Position: + """Current position in an asset. + + Attributes: + symbol: Trading symbol + quantity: Position quantity (positive for long, negative for short) + side: Position side (long/short) + avg_entry_price: Average entry price + current_price: Current market price + market_value: Current market value + cost_basis: Total cost basis + unrealized_pnl: Unrealized profit/loss + unrealized_pnl_percent: Unrealized P&L as percentage + realized_pnl: Realized profit/loss (if tracked) + asset_class: Asset class + exchange: Exchange where traded + asset_id: Broker's asset ID + metadata: Additional broker-specific data + """ + symbol: str + quantity: Decimal + side: PositionSide + avg_entry_price: Decimal + current_price: Decimal + market_value: Decimal + cost_basis: Decimal + unrealized_pnl: Decimal + unrealized_pnl_percent: Decimal + realized_pnl: Optional[Decimal] = None + asset_class: AssetClass = AssetClass.EQUITY + exchange: Optional[str] = None + asset_id: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def is_long(self) -> bool: + """Check if position is long.""" + return self.side == PositionSide.LONG + + @property + def is_short(self) -> bool: + """Check if position is short.""" + return self.side == PositionSide.SHORT + + @property + def abs_quantity(self) -> Decimal: + """Get absolute quantity.""" + return abs(self.quantity) + + +@dataclass +class AccountInfo: + """Broker account information. + + Attributes: + account_id: Broker account ID + account_type: Account type (e.g., 'cash', 'margin') + status: Account status + currency: Base currency + cash: Available cash balance + portfolio_value: Total portfolio value + buying_power: Available buying power + equity: Account equity + margin_used: Margin currently in use + margin_available: Available margin + initial_margin: Initial margin requirement + maintenance_margin: Maintenance margin requirement + pending_transfer_in: Pending incoming transfers + pending_transfer_out: Pending outgoing transfers + day_trades_remaining: PDT day trades remaining (if applicable) + is_pattern_day_trader: Whether flagged as PDT + created_at: Account creation date + metadata: Additional broker-specific data + """ + account_id: str + account_type: str + status: str + currency: str = "USD" + cash: Decimal = field(default_factory=lambda: Decimal("0")) + portfolio_value: Decimal = field(default_factory=lambda: Decimal("0")) + buying_power: Decimal = field(default_factory=lambda: Decimal("0")) + equity: Decimal = field(default_factory=lambda: Decimal("0")) + margin_used: Optional[Decimal] = None + margin_available: Optional[Decimal] = None + initial_margin: Optional[Decimal] = None + maintenance_margin: Optional[Decimal] = None + pending_transfer_in: Optional[Decimal] = None + pending_transfer_out: Optional[Decimal] = None + day_trades_remaining: Optional[int] = None + is_pattern_day_trader: bool = False + created_at: Optional[datetime] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def is_active(self) -> bool: + """Check if account is active.""" + return self.status.lower() in ("active", "approved", "enabled") + + +@dataclass +class Quote: + """Current quote/price data. + + Attributes: + symbol: Trading symbol + bid_price: Current bid price + bid_size: Bid size + ask_price: Current ask price + ask_size: Ask size + last_price: Last trade price + last_size: Last trade size + volume: Trading volume + timestamp: Quote timestamp + exchange: Exchange code + conditions: Trade conditions + metadata: Additional data + """ + symbol: str + bid_price: Optional[Decimal] = None + bid_size: Optional[Decimal] = None + ask_price: Optional[Decimal] = None + ask_size: Optional[Decimal] = None + last_price: Optional[Decimal] = None + last_size: Optional[Decimal] = None + volume: Optional[int] = None + timestamp: Optional[datetime] = None + exchange: Optional[str] = None + conditions: List[str] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def mid_price(self) -> Optional[Decimal]: + """Calculate mid price between bid and ask.""" + if self.bid_price is not None and self.ask_price is not None: + return (self.bid_price + self.ask_price) / 2 + return self.last_price + + @property + def spread(self) -> Optional[Decimal]: + """Calculate bid-ask spread.""" + if self.bid_price is not None and self.ask_price is not None: + return self.ask_price - self.bid_price + return None + + @property + def spread_percent(self) -> Optional[float]: + """Calculate spread as percentage of mid price.""" + if self.spread is not None and self.mid_price is not None and self.mid_price > 0: + return float(self.spread / self.mid_price * 100) + return None + + +@dataclass +class AssetInfo: + """Asset/instrument information. + + Attributes: + symbol: Trading symbol + name: Full name + asset_class: Asset class + exchange: Primary exchange + tradable: Whether currently tradable + marginable: Whether marginable + shortable: Whether shortable + easy_to_borrow: Whether easy to borrow for shorting + fractionable: Whether fractional shares allowed + min_order_size: Minimum order size + min_trade_increment: Minimum trade increment + price_increment: Price increment (tick size) + maintenance_margin_req: Maintenance margin requirement + attributes: Additional attributes list + metadata: Additional broker-specific data + """ + symbol: str + name: str + asset_class: AssetClass = AssetClass.EQUITY + exchange: Optional[str] = None + tradable: bool = True + marginable: bool = True + shortable: bool = True + easy_to_borrow: bool = True + fractionable: bool = False + min_order_size: Optional[Decimal] = None + min_trade_increment: Optional[Decimal] = None + price_increment: Optional[Decimal] = None + maintenance_margin_req: Optional[Decimal] = None + attributes: List[str] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + + +class BrokerError(Exception): + """Base exception for broker errors.""" + + def __init__(self, message: str, code: Optional[str] = None, details: Optional[Dict] = None): + super().__init__(message) + self.code = code + self.details = details or {} + + +class ConnectionError(BrokerError): + """Error connecting to broker.""" + pass + + +class AuthenticationError(BrokerError): + """Authentication failed.""" + pass + + +class OrderError(BrokerError): + """Error submitting or managing order.""" + pass + + +class InsufficientFundsError(OrderError): + """Insufficient funds for order.""" + pass + + +class InvalidOrderError(OrderError): + """Invalid order parameters.""" + pass + + +class PositionError(BrokerError): + """Error with position operations.""" + pass + + +class RateLimitError(BrokerError): + """Rate limit exceeded.""" + + def __init__(self, message: str, retry_after: Optional[float] = None, **kwargs): + super().__init__(message, **kwargs) + self.retry_after = retry_after + + +class BrokerBase(ABC): + """Abstract base class for broker implementations. + + All broker implementations must inherit from this class and implement + the abstract methods. This provides a uniform interface for the trading + system regardless of which broker is used. + + Example: + >>> class AlpacaBroker(BrokerBase): + ... async def connect(self) -> bool: + ... # Connect to Alpaca API + ... return True + ... # ... implement other abstract methods + >>> + >>> broker = AlpacaBroker(api_key="...", api_secret="...") + >>> await broker.connect() + >>> order = await broker.submit_order( + ... OrderRequest.market("AAPL", OrderSide.BUY, 100) + ... ) + """ + + def __init__( + self, + name: str, + supported_asset_classes: Optional[List[AssetClass]] = None, + paper_trading: bool = False, + **kwargs, + ): + """Initialize broker base. + + Args: + name: Broker name + supported_asset_classes: List of supported asset classes + paper_trading: Whether this is paper trading mode + **kwargs: Additional broker-specific configuration + """ + self._name = name + self._supported_asset_classes = supported_asset_classes or [AssetClass.EQUITY] + self._paper_trading = paper_trading + self._connected = False + self._config = kwargs + + @property + def name(self) -> str: + """Get broker name.""" + return self._name + + @property + def supported_asset_classes(self) -> List[AssetClass]: + """Get list of supported asset classes.""" + return self._supported_asset_classes + + @property + def is_paper_trading(self) -> bool: + """Check if broker is in paper trading mode.""" + return self._paper_trading + + @property + def is_connected(self) -> bool: + """Check if broker is connected.""" + return self._connected + + def supports_asset_class(self, asset_class: AssetClass) -> bool: + """Check if broker supports a specific asset class. + + Args: + asset_class: Asset class to check + + Returns: + True if supported, False otherwise + """ + return asset_class in self._supported_asset_classes + + # ========================================================================== + # Connection Management + # ========================================================================== + + @abstractmethod + async def connect(self) -> bool: + """Connect to broker API. + + Returns: + True if connection successful, False otherwise + + Raises: + ConnectionError: If connection fails + AuthenticationError: If authentication fails + """ + pass + + @abstractmethod + async def disconnect(self) -> None: + """Disconnect from broker API.""" + pass + + @abstractmethod + async def is_market_open(self) -> bool: + """Check if market is currently open. + + Returns: + True if market is open, False otherwise + """ + pass + + # ========================================================================== + # Account Information + # ========================================================================== + + @abstractmethod + async def get_account(self) -> AccountInfo: + """Get account information. + + Returns: + AccountInfo object with account details + + Raises: + ConnectionError: If not connected + BrokerError: If account retrieval fails + """ + pass + + # ========================================================================== + # Order Management + # ========================================================================== + + @abstractmethod + async def submit_order(self, request: OrderRequest) -> Order: + """Submit a new order. + + Args: + request: Order request details + + Returns: + Order object representing submitted order + + Raises: + ConnectionError: If not connected + InvalidOrderError: If order parameters invalid + InsufficientFundsError: If insufficient buying power + OrderError: If order submission fails + """ + pass + + @abstractmethod + async def cancel_order(self, order_id: str) -> Order: + """Cancel an existing order. + + Args: + order_id: Broker order ID to cancel + + Returns: + Updated order object + + Raises: + OrderError: If cancellation fails + """ + pass + + @abstractmethod + async def replace_order( + self, + order_id: str, + quantity: Optional[Decimal] = None, + limit_price: Optional[Decimal] = None, + stop_price: Optional[Decimal] = None, + time_in_force: Optional[TimeInForce] = None, + ) -> Order: + """Replace/modify an existing order. + + Args: + order_id: Broker order ID to replace + quantity: New quantity (optional) + limit_price: New limit price (optional) + stop_price: New stop price (optional) + time_in_force: New time in force (optional) + + Returns: + New order object + + Raises: + OrderError: If replacement fails + """ + pass + + @abstractmethod + async def get_order(self, order_id: str) -> Order: + """Get order by ID. + + Args: + order_id: Broker order ID + + Returns: + Order object + + Raises: + OrderError: If order not found + """ + pass + + @abstractmethod + async def get_orders( + self, + status: Optional[OrderStatus] = None, + limit: int = 100, + symbols: Optional[List[str]] = None, + ) -> List[Order]: + """Get orders with optional filters. + + Args: + status: Filter by order status + limit: Maximum number of orders to return + symbols: Filter by symbols + + Returns: + List of Order objects + """ + pass + + async def cancel_all_orders(self, symbols: Optional[List[str]] = None) -> List[Order]: + """Cancel all open orders. + + Args: + symbols: Optional list of symbols to cancel orders for + + Returns: + List of cancelled orders + """ + open_orders = await self.get_orders( + status=OrderStatus.NEW, + symbols=symbols, + ) + + # Also get partially filled orders + partial_orders = await self.get_orders( + status=OrderStatus.PARTIALLY_FILLED, + symbols=symbols, + ) + + cancelled = [] + for order in open_orders + partial_orders: + try: + cancelled_order = await self.cancel_order(order.broker_order_id) + cancelled.append(cancelled_order) + except OrderError: + # Order may have been filled between query and cancel + pass + + return cancelled + + # ========================================================================== + # Position Management + # ========================================================================== + + @abstractmethod + async def get_positions(self) -> List[Position]: + """Get all current positions. + + Returns: + List of Position objects + """ + pass + + @abstractmethod + async def get_position(self, symbol: str) -> Optional[Position]: + """Get position for a specific symbol. + + Args: + symbol: Trading symbol + + Returns: + Position object or None if no position + """ + pass + + async def close_position( + self, + symbol: str, + quantity: Optional[Decimal] = None, + ) -> Order: + """Close a position partially or completely. + + Args: + symbol: Symbol to close + quantity: Quantity to close (None for entire position) + + Returns: + Order object for the closing trade + + Raises: + PositionError: If position doesn't exist + """ + position = await self.get_position(symbol) + if position is None: + raise PositionError(f"No position found for {symbol}") + + close_qty = quantity if quantity is not None else position.abs_quantity + + # Determine side based on position + side = OrderSide.SELL if position.is_long else OrderSide.BUY + + return await self.submit_order( + OrderRequest.market(symbol, side, close_qty) + ) + + async def close_all_positions(self) -> List[Order]: + """Close all positions. + + Returns: + List of orders for closing trades + """ + positions = await self.get_positions() + orders = [] + + for position in positions: + try: + order = await self.close_position(position.symbol) + orders.append(order) + except (OrderError, PositionError): + # Position may have been closed between query and close + pass + + return orders + + # ========================================================================== + # Market Data + # ========================================================================== + + @abstractmethod + async def get_quote(self, symbol: str) -> Quote: + """Get current quote for a symbol. + + Args: + symbol: Trading symbol + + Returns: + Quote object with bid/ask/last prices + """ + pass + + async def get_quotes(self, symbols: List[str]) -> Dict[str, Quote]: + """Get quotes for multiple symbols. + + Default implementation calls get_quote for each symbol. + Override for batch operations if supported by broker. + + Args: + symbols: List of trading symbols + + Returns: + Dict mapping symbol to Quote + """ + quotes = {} + for symbol in symbols: + try: + quotes[symbol] = await self.get_quote(symbol) + except BrokerError: + pass + return quotes + + @abstractmethod + async def get_asset(self, symbol: str) -> AssetInfo: + """Get asset information. + + Args: + symbol: Trading symbol + + Returns: + AssetInfo object with asset details + """ + pass + + # ========================================================================== + # Utility Methods + # ========================================================================== + + async def validate_order(self, request: OrderRequest) -> List[str]: + """Validate an order request before submission. + + Args: + request: Order request to validate + + Returns: + List of validation error messages (empty if valid) + """ + errors = [] + + # Check basic parameters + if request.quantity <= 0: + errors.append("Quantity must be positive") + + # Check asset is tradable + try: + asset = await self.get_asset(request.symbol) + if not asset.tradable: + errors.append(f"{request.symbol} is not currently tradable") + except BrokerError: + errors.append(f"Could not validate asset {request.symbol}") + + # Check limit price for limit orders + if request.order_type in (OrderType.LIMIT, OrderType.STOP_LIMIT): + if request.limit_price is None or request.limit_price <= 0: + errors.append("Limit price must be positive for limit orders") + + # Check stop price for stop orders + if request.order_type in (OrderType.STOP, OrderType.STOP_LIMIT, OrderType.TRAILING_STOP): + if request.order_type != OrderType.TRAILING_STOP: + if request.stop_price is None or request.stop_price <= 0: + errors.append("Stop price must be positive for stop orders") + + # Check buying power + if request.side == OrderSide.BUY: + try: + account = await self.get_account() + quote = await self.get_quote(request.symbol) + estimated_cost = request.quantity * ( + request.limit_price or quote.ask_price or quote.last_price or Decimal("0") + ) + if estimated_cost > account.buying_power: + errors.append( + f"Insufficient buying power. Required: {estimated_cost}, " + f"Available: {account.buying_power}" + ) + except BrokerError: + pass # Skip buying power check if we can't get data + + return errors + + def __repr__(self) -> str: + """String representation.""" + return ( + f"{self.__class__.__name__}(" + f"name='{self._name}', " + f"paper_trading={self._paper_trading}, " + f"connected={self._connected})" + )