From 6863e3ed870ba6b3bcefc4a2b2ac47109577a980 Mon Sep 17 00:00:00 2001 From: Andrew Kaszubski Date: Fri, 26 Dec 2025 21:24:54 +1100 Subject: [PATCH] feat(execution): add Order Manager for lifecycle management - Issue #27 (47 tests) --- tests/unit/execution/test_order_manager.py | 659 +++++++++++++++++++++ tradingagents/execution/__init__.py | 18 + tradingagents/execution/order_manager.py | 650 ++++++++++++++++++++ 3 files changed, 1327 insertions(+) create mode 100644 tests/unit/execution/test_order_manager.py create mode 100644 tradingagents/execution/order_manager.py diff --git a/tests/unit/execution/test_order_manager.py b/tests/unit/execution/test_order_manager.py new file mode 100644 index 00000000..3a80c5e6 --- /dev/null +++ b/tests/unit/execution/test_order_manager.py @@ -0,0 +1,659 @@ +"""Tests for Order Manager implementation. + +Issue #27: [EXEC-26] Order types and manager - market, limit, stop, trailing +""" + +from decimal import Decimal +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock +import pytest + +from tradingagents.execution import ( + OrderManager, + OrderEvent, + OrderValidationResult, + OrderStateChange, + PaperBroker, + OrderRequest, + OrderSide, + OrderType, + OrderStatus, + TimeInForce, + InvalidOrderError, + VALID_TRANSITIONS, + TERMINAL_STATES, + OPEN_STATES, +) + + +class TestOrderValidationResult: + """Test OrderValidationResult dataclass.""" + + def test_default_valid(self): + """Test default result is valid.""" + result = OrderValidationResult() + assert result.valid is True + assert result.errors == [] + assert result.warnings == [] + + def test_invalid_with_errors(self): + """Test result with errors.""" + result = OrderValidationResult( + valid=False, + errors=["Error 1", "Error 2"], + ) + assert result.valid is False + assert len(result.errors) == 2 + + def test_valid_with_warnings(self): + """Test result with warnings but still valid.""" + result = OrderValidationResult( + valid=True, + warnings=["Warning 1"], + ) + assert result.valid is True + assert len(result.warnings) == 1 + + +class TestOrderStateChange: + """Test OrderStateChange dataclass.""" + + def test_state_change_creation(self): + """Test creating state change record.""" + change = OrderStateChange( + order_id="TEST-123", + from_status=OrderStatus.NEW, + to_status=OrderStatus.FILLED, + event=OrderEvent.FILLED, + ) + assert change.order_id == "TEST-123" + assert change.from_status == OrderStatus.NEW + assert change.to_status == OrderStatus.FILLED + assert change.event == OrderEvent.FILLED + assert isinstance(change.timestamp, datetime) + + def test_state_change_with_metadata(self): + """Test state change with metadata.""" + change = OrderStateChange( + order_id="TEST-123", + from_status=None, + to_status=OrderStatus.NEW, + event=OrderEvent.SUBMITTED, + metadata={"broker": "paper"}, + ) + assert change.metadata == {"broker": "paper"} + + +class TestOrderManagerInit: + """Test OrderManager initialization.""" + + def test_default_initialization(self): + """Test default initialization.""" + manager = OrderManager() + assert manager.order_count == 0 + assert manager.open_order_count == 0 + + def test_custom_max_orders(self): + """Test initialization with custom max orders.""" + manager = OrderManager(max_orders=100) + assert manager._max_orders == 100 + + def test_validation_disabled(self): + """Test initialization with validation disabled.""" + manager = OrderManager(validate_before_submit=False) + assert manager._validate_before_submit is False + + +class TestOrderValidation: + """Test OrderManager order validation.""" + + def test_valid_market_order(self): + """Test valid market order.""" + manager = OrderManager() + request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")) + result = manager.validate_order(request) + assert result.valid is True + + def test_invalid_quantity_zero(self): + """Test invalid order with zero quantity.""" + manager = OrderManager() + request = OrderRequest( + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("0"), + ) + result = manager.validate_order(request) + assert result.valid is False + assert any("positive" in e.lower() for e in result.errors) + + def test_invalid_quantity_negative(self): + """Test invalid order with negative quantity.""" + manager = OrderManager() + request = OrderRequest( + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("-10"), + ) + result = manager.validate_order(request) + assert result.valid is False + + def test_limit_order_missing_price(self): + """Test limit order without limit price raises at construction.""" + # OrderRequest validates in __post_init__ + with pytest.raises(ValueError, match="limit_price"): + OrderRequest( + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("100"), + order_type=OrderType.LIMIT, + limit_price=None, + ) + + def test_limit_order_with_price(self): + """Test valid limit order.""" + manager = OrderManager() + request = OrderRequest.limit( + "AAPL", OrderSide.BUY, Decimal("100"), Decimal("150.00") + ) + result = manager.validate_order(request) + assert result.valid is True + + def test_stop_order_missing_stop_price(self): + """Test stop order without stop price raises at construction.""" + # OrderRequest validates in __post_init__ + with pytest.raises(ValueError, match="stop_price"): + OrderRequest( + symbol="AAPL", + side=OrderSide.SELL, + quantity=Decimal("100"), + order_type=OrderType.STOP, + stop_price=None, + ) + + def test_stop_limit_order_missing_both(self): + """Test stop limit order missing both prices raises at construction.""" + # OrderRequest validates in __post_init__ + with pytest.raises(ValueError, match="stop_price|limit_price"): + OrderRequest( + symbol="AAPL", + side=OrderSide.SELL, + quantity=Decimal("100"), + order_type=OrderType.STOP_LIMIT, + ) + + def test_trailing_stop_missing_trail(self): + """Test trailing stop without trail parameters raises at construction.""" + # OrderRequest validates in __post_init__ + with pytest.raises(ValueError, match="trail"): + OrderRequest( + symbol="AAPL", + side=OrderSide.SELL, + quantity=Decimal("100"), + order_type=OrderType.TRAILING_STOP, + ) + + def test_trailing_stop_with_percent(self): + """Test valid trailing stop with percent.""" + manager = OrderManager() + request = OrderRequest( + symbol="AAPL", + side=OrderSide.SELL, + quantity=Decimal("100"), + order_type=OrderType.TRAILING_STOP, + trail_percent=Decimal("5.0"), + ) + result = manager.validate_order(request) + assert result.valid is True + + def test_trailing_stop_high_percent_warning(self): + """Test trailing stop with high percent warns.""" + manager = OrderManager() + request = OrderRequest( + symbol="AAPL", + side=OrderSide.SELL, + quantity=Decimal("100"), + order_type=OrderType.TRAILING_STOP, + trail_percent=Decimal("60.0"), + ) + result = manager.validate_order(request) + assert result.valid is True + assert len(result.warnings) > 0 + + def test_empty_symbol(self): + """Test order with empty symbol.""" + manager = OrderManager() + request = OrderRequest( + symbol="", + side=OrderSide.BUY, + quantity=Decimal("100"), + ) + result = manager.validate_order(request) + assert result.valid is False + assert any("symbol" in e.lower() for e in result.errors) + + +class TestStateTransitions: + """Test order state machine transitions.""" + + def test_valid_transitions_defined(self): + """Test all statuses have defined transitions.""" + for status in OrderStatus: + assert status in VALID_TRANSITIONS + + def test_terminal_states_immutable(self): + """Test terminal states have no valid transitions.""" + for status in TERMINAL_STATES: + assert len(VALID_TRANSITIONS[status]) == 0 + + def test_open_states_have_transitions(self): + """Test open states have transitions.""" + for status in OPEN_STATES: + assert len(VALID_TRANSITIONS[status]) > 0 + + def test_is_valid_transition(self): + """Test valid transition checking.""" + manager = OrderManager() + assert manager.is_valid_transition(OrderStatus.NEW, OrderStatus.FILLED) + assert manager.is_valid_transition(OrderStatus.NEW, OrderStatus.CANCELLED) + assert not manager.is_valid_transition(OrderStatus.FILLED, OrderStatus.NEW) + + def test_is_terminal(self): + """Test terminal status checking.""" + manager = OrderManager() + assert manager.is_terminal(OrderStatus.FILLED) is True + assert manager.is_terminal(OrderStatus.CANCELLED) is True + assert manager.is_terminal(OrderStatus.NEW) is False + + def test_is_open(self): + """Test open status checking.""" + manager = OrderManager() + assert manager.is_open(OrderStatus.NEW) is True + assert manager.is_open(OrderStatus.PARTIALLY_FILLED) is True + assert manager.is_open(OrderStatus.FILLED) is False + + +class TestOrderSubmission: + """Test OrderManager order submission.""" + + @pytest.mark.asyncio + async def test_submit_order(self): + """Test basic order submission.""" + manager = OrderManager() + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + order = await manager.submit_order(request, broker) + + assert order.symbol == "AAPL" + assert order.status == OrderStatus.FILLED + assert manager.order_count == 1 + + @pytest.mark.asyncio + async def test_submit_order_validation_fails(self): + """Test submission fails on invalid order.""" + manager = OrderManager() + broker = PaperBroker() + await broker.connect() + + request = OrderRequest( + symbol="AAPL", + side=OrderSide.BUY, + quantity=Decimal("-10"), + ) + + with pytest.raises(InvalidOrderError): + await manager.submit_order(request, broker) + + @pytest.mark.asyncio + async def test_submit_order_validation_disabled(self): + """Test submission with validation disabled.""" + manager = OrderManager(validate_before_submit=False) + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + # This would normally fail validation but passes with disabled + request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + order = await manager.submit_order(request, broker) + assert order is not None + + @pytest.mark.asyncio + async def test_order_tracking(self): + """Test order is tracked after submission.""" + manager = OrderManager() + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + order = await manager.submit_order(request, broker) + + tracked = manager.get_order(order.broker_order_id) + assert tracked is not None + assert tracked.broker_order_id == order.broker_order_id + + @pytest.mark.asyncio + async def test_order_history_recorded(self): + """Test order history is recorded.""" + manager = OrderManager() + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + order = await manager.submit_order(request, broker) + + history = manager.get_order_history(order.broker_order_id) + assert len(history) >= 1 + assert history[0].event == OrderEvent.SUBMITTED + + +class TestOrderCallbacks: + """Test OrderManager event callbacks.""" + + @pytest.mark.asyncio + async def test_register_callback(self): + """Test registering callback.""" + manager = OrderManager() + callback_called = [] + + async def callback(order, event, metadata): + callback_called.append((order, event)) + + manager.register_callback(OrderEvent.FILLED, callback) + + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + await manager.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), + broker, + ) + + assert len(callback_called) == 1 + assert callback_called[0][1] == OrderEvent.FILLED + + @pytest.mark.asyncio + async def test_unregister_callback(self): + """Test unregistering callback.""" + manager = OrderManager() + callback_called = [] + + async def callback(order, event, metadata): + callback_called.append(True) + + manager.register_callback(OrderEvent.FILLED, callback) + manager.unregister_callback(OrderEvent.FILLED, callback) + + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + await manager.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), + broker, + ) + + assert len(callback_called) == 0 + + @pytest.mark.asyncio + async def test_callback_error_doesnt_break_flow(self): + """Test callback error doesn't break order flow.""" + manager = OrderManager() + + async def bad_callback(order, event, metadata): + raise Exception("Callback error") + + manager.register_callback(OrderEvent.FILLED, bad_callback) + + broker = PaperBroker() + broker.set_price("AAPL", Decimal("100")) + await broker.connect() + + # Should not raise despite callback error + order = await manager.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), + broker, + ) + assert order is not None + + +class TestOrderCancellation: + """Test OrderManager order cancellation.""" + + @pytest.mark.asyncio + async def test_cancel_order(self): + """Test cancelling an order.""" + manager = OrderManager() + broker = PaperBroker(fill_probability=0.0) + await broker.connect() + + request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + order = await manager.submit_order(request, broker) + + cancelled = await manager.cancel_order(order.broker_order_id, broker) + assert cancelled.status == OrderStatus.CANCELLED + + @pytest.mark.asyncio + async def test_cancel_updates_tracking(self): + """Test cancellation updates tracked order.""" + manager = OrderManager() + broker = PaperBroker(fill_probability=0.0) + await broker.connect() + + request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + order = await manager.submit_order(request, broker) + + await manager.cancel_order(order.broker_order_id, broker) + + tracked = manager.get_order(order.broker_order_id) + assert tracked.status == OrderStatus.CANCELLED + + +class TestOrderReplacement: + """Test OrderManager order replacement.""" + + @pytest.mark.asyncio + async def test_replace_order(self): + """Test replacing an order.""" + manager = OrderManager() + broker = PaperBroker(fill_probability=0.0) + await broker.connect() + + request = OrderRequest.limit( + "AAPL", OrderSide.BUY, Decimal("10"), Decimal("100") + ) + order = await manager.submit_order(request, broker) + + new_order = await manager.replace_order( + order.broker_order_id, + broker, + quantity=Decimal("20"), + ) + + assert new_order.quantity == Decimal("20") + assert new_order.broker_order_id != order.broker_order_id + + +class TestOrderQueries: + """Test OrderManager query methods.""" + + @pytest.mark.asyncio + async def test_get_orders_all(self): + """Test getting all orders.""" + manager = OrderManager() + broker = PaperBroker() + broker.set_price("AAPL", Decimal("10")) + broker.set_price("MSFT", Decimal("10")) + await broker.connect() + + await manager.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker + ) + await manager.submit_order( + OrderRequest.market("MSFT", OrderSide.BUY, Decimal("10")), broker + ) + + orders = manager.get_orders() + assert len(orders) == 2 + + @pytest.mark.asyncio + async def test_get_orders_by_symbol(self): + """Test filtering orders by symbol.""" + manager = OrderManager() + broker = PaperBroker() + broker.set_price("AAPL", Decimal("10")) + broker.set_price("MSFT", Decimal("10")) + await broker.connect() + + await manager.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker + ) + await manager.submit_order( + OrderRequest.market("MSFT", OrderSide.BUY, Decimal("10")), broker + ) + + orders = manager.get_orders(symbol="AAPL") + assert len(orders) == 1 + assert orders[0].symbol == "AAPL" + + @pytest.mark.asyncio + async def test_get_orders_by_status(self): + """Test filtering orders by status.""" + manager = OrderManager() + broker = PaperBroker() + broker.set_price("AAPL", Decimal("10")) + await broker.connect() + + # Create filled order + await manager.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker + ) + + # Create unfilled order + broker._fill_probability = 0.0 + await manager.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker + ) + + filled_orders = manager.get_orders(status=OrderStatus.FILLED) + assert len(filled_orders) == 1 + + @pytest.mark.asyncio + async def test_get_open_orders(self): + """Test getting open orders only.""" + manager = OrderManager() + broker = PaperBroker(fill_probability=0.0) + await broker.connect() + + await manager.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker + ) + + open_orders = manager.get_open_orders() + assert len(open_orders) == 1 + + +class TestOrderCleanup: + """Test OrderManager order cleanup.""" + + @pytest.mark.asyncio + async def test_clear_completed_orders(self): + """Test clearing completed orders.""" + manager = OrderManager() + broker = PaperBroker() + broker.set_price("AAPL", Decimal("10")) + await broker.connect() + + # Create filled orders + await manager.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker + ) + await manager.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker + ) + + assert manager.order_count == 2 + + removed = manager.clear_completed_orders() + assert removed == 2 + assert manager.order_count == 0 + + @pytest.mark.asyncio + async def test_max_orders_limit(self): + """Test orders are trimmed when max reached.""" + manager = OrderManager(max_orders=5) + broker = PaperBroker() + broker.set_price("AAPL", Decimal("1")) + await broker.connect() + + # Submit more than max orders + for _ in range(10): + await manager.submit_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1")), broker + ) + + assert manager.order_count <= 5 + + +class TestOrderStatusUpdate: + """Test OrderManager status updates.""" + + @pytest.mark.asyncio + async def test_update_order_status(self): + """Test updating order status.""" + manager = OrderManager() + broker = PaperBroker(fill_probability=0.0) + await broker.connect() + + request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + order = await manager.submit_order(request, broker) + + # Simulate status change + order.status = OrderStatus.FILLED + order.filled_quantity = order.quantity + await manager.update_order_status(order) + + tracked = manager.get_order(order.broker_order_id) + assert tracked.status == OrderStatus.FILLED + + +class TestOrderEvents: + """Test OrderEvent enum.""" + + def test_all_events_defined(self): + """Test all expected events are defined.""" + expected_events = [ + "CREATED", "SUBMITTED", "ACCEPTED", "REJECTED", + "PARTIALLY_FILLED", "FILLED", "PENDING_CANCEL", + "CANCELLED", "REPLACED", "EXPIRED", "ERROR" + ] + for event_name in expected_events: + assert hasattr(OrderEvent, event_name) + + +class TestStateConstants: + """Test state machine constants.""" + + def test_terminal_states_complete(self): + """Test all terminal states are included.""" + assert OrderStatus.FILLED in TERMINAL_STATES + assert OrderStatus.CANCELLED in TERMINAL_STATES + assert OrderStatus.REJECTED in TERMINAL_STATES + assert OrderStatus.EXPIRED in TERMINAL_STATES + assert OrderStatus.REPLACED in TERMINAL_STATES + + def test_open_states_complete(self): + """Test all open states are included.""" + assert OrderStatus.PENDING_NEW in OPEN_STATES + assert OrderStatus.NEW in OPEN_STATES + assert OrderStatus.PARTIALLY_FILLED in OPEN_STATES + assert OrderStatus.PENDING_CANCEL in OPEN_STATES + + def test_no_overlap(self): + """Test terminal and open states don't overlap.""" + overlap = TERMINAL_STATES & OPEN_STATES + assert len(overlap) == 0 diff --git a/tradingagents/execution/__init__.py b/tradingagents/execution/__init__.py index 1027e755..d26154c6 100644 --- a/tradingagents/execution/__init__.py +++ b/tradingagents/execution/__init__.py @@ -114,6 +114,16 @@ from .ibkr_broker import ( from .paper_broker import PaperBroker +from .order_manager import ( + OrderManager, + OrderEvent, + OrderValidationResult, + OrderStateChange, + VALID_TRANSITIONS, + TERMINAL_STATES, + OPEN_STATES, +) + __all__ = [ # Enums "AssetClass", @@ -158,4 +168,12 @@ __all__ = [ "FUTURES_SPECS", # Paper Broker "PaperBroker", + # Order Manager + "OrderManager", + "OrderEvent", + "OrderValidationResult", + "OrderStateChange", + "VALID_TRANSITIONS", + "TERMINAL_STATES", + "OPEN_STATES", ] diff --git a/tradingagents/execution/order_manager.py b/tradingagents/execution/order_manager.py new file mode 100644 index 00000000..62ddb994 --- /dev/null +++ b/tradingagents/execution/order_manager.py @@ -0,0 +1,650 @@ +"""Order Manager for order lifecycle management. + +Issue #27: [EXEC-26] Order types and manager - market, limit, stop, trailing + +This module provides order lifecycle management including validation, +state transitions, and event notifications. + +Features: + - Order validation before submission + - Order state machine with valid transitions + - Order tracking and retrieval + - Event callbacks for order state changes + - Support for all order types: market, limit, stop, stop_limit, trailing_stop + +Example: + >>> from tradingagents.execution import OrderManager, OrderRequest, OrderSide + >>> + >>> manager = OrderManager() + >>> request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")) + >>> order = await manager.submit_order(request, broker) + >>> print(f"Order {order.broker_order_id} status: {order.status}") +""" + +from __future__ import annotations + +import asyncio +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from decimal import Decimal +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Set, Awaitable + +from .broker_base import ( + BrokerBase, + Order, + OrderError, + OrderRequest, + OrderSide, + OrderStatus, + OrderType, + TimeInForce, + InvalidOrderError, +) + + +class OrderEvent(Enum): + """Order lifecycle events.""" + + CREATED = "created" + SUBMITTED = "submitted" + ACCEPTED = "accepted" + REJECTED = "rejected" + PARTIALLY_FILLED = "partially_filled" + FILLED = "filled" + PENDING_CANCEL = "pending_cancel" + CANCELLED = "cancelled" + REPLACED = "replaced" + EXPIRED = "expired" + ERROR = "error" + + +# Valid state transitions for order state machine +VALID_TRANSITIONS: Dict[OrderStatus, Set[OrderStatus]] = { + OrderStatus.PENDING_NEW: { + OrderStatus.NEW, + OrderStatus.REJECTED, + OrderStatus.CANCELLED, + }, + OrderStatus.NEW: { + OrderStatus.PARTIALLY_FILLED, + OrderStatus.FILLED, + OrderStatus.PENDING_CANCEL, + OrderStatus.CANCELLED, + OrderStatus.EXPIRED, + OrderStatus.REPLACED, + }, + OrderStatus.PARTIALLY_FILLED: { + OrderStatus.PARTIALLY_FILLED, + OrderStatus.FILLED, + OrderStatus.PENDING_CANCEL, + OrderStatus.CANCELLED, + }, + OrderStatus.FILLED: set(), # Terminal state + OrderStatus.PENDING_CANCEL: { + OrderStatus.CANCELLED, + OrderStatus.FILLED, # Can fill while cancel is pending + OrderStatus.PARTIALLY_FILLED, + }, + OrderStatus.CANCELLED: set(), # Terminal state + OrderStatus.REJECTED: set(), # Terminal state + OrderStatus.EXPIRED: set(), # Terminal state + OrderStatus.REPLACED: set(), # Terminal state +} + +# Terminal states (order cannot change after reaching these) +TERMINAL_STATES: Set[OrderStatus] = { + OrderStatus.FILLED, + OrderStatus.CANCELLED, + OrderStatus.REJECTED, + OrderStatus.EXPIRED, + OrderStatus.REPLACED, +} + +# Open states (order can still be filled or cancelled) +OPEN_STATES: Set[OrderStatus] = { + OrderStatus.PENDING_NEW, + OrderStatus.NEW, + OrderStatus.PARTIALLY_FILLED, + OrderStatus.PENDING_CANCEL, +} + + +@dataclass +class OrderValidationResult: + """Result of order validation. + + Attributes: + valid: Whether the order is valid + errors: List of validation error messages + warnings: List of validation warning messages + """ + valid: bool = True + errors: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + + +@dataclass +class OrderStateChange: + """Record of an order state change. + + Attributes: + order_id: Order identifier + from_status: Previous status + to_status: New status + event: Event that triggered the change + timestamp: When the change occurred + metadata: Additional change details + """ + order_id: str + from_status: Optional[OrderStatus] + to_status: OrderStatus + event: OrderEvent + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + metadata: Dict[str, Any] = field(default_factory=dict) + + +# Callback type for order events +OrderEventCallback = Callable[[Order, OrderEvent, Dict[str, Any]], Awaitable[None]] + + +class OrderManager: + """Manages order lifecycle and state transitions. + + The OrderManager provides: + - Order validation before submission + - Order state machine with valid transitions + - Order tracking and retrieval + - Event callbacks for order state changes + - Order history and audit trail + + Example: + >>> manager = OrderManager() + >>> + >>> # Register callbacks + >>> async def on_fill(order, event, metadata): + ... print(f"Order {order.broker_order_id} filled!") + >>> manager.register_callback(OrderEvent.FILLED, on_fill) + >>> + >>> # Submit order + >>> order = await manager.submit_order(request, broker) + """ + + def __init__( + self, + max_orders: int = 10000, + validate_before_submit: bool = True, + ) -> None: + """Initialize order manager. + + Args: + max_orders: Maximum orders to track (oldest removed when exceeded) + validate_before_submit: Whether to validate orders before submission + """ + self._orders: Dict[str, Order] = {} + self._order_history: Dict[str, List[OrderStateChange]] = {} + self._callbacks: Dict[OrderEvent, List[OrderEventCallback]] = { + event: [] for event in OrderEvent + } + self._max_orders = max_orders + self._validate_before_submit = validate_before_submit + self._lock = asyncio.Lock() + + def register_callback( + self, + event: OrderEvent, + callback: OrderEventCallback, + ) -> None: + """Register a callback for an order event. + + Args: + event: Event to listen for + callback: Async callback function(order, event, metadata) + """ + self._callbacks[event].append(callback) + + def unregister_callback( + self, + event: OrderEvent, + callback: OrderEventCallback, + ) -> None: + """Unregister a callback. + + Args: + event: Event type + callback: Callback to remove + """ + if callback in self._callbacks[event]: + self._callbacks[event].remove(callback) + + async def _fire_event( + self, + order: Order, + event: OrderEvent, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Fire callbacks for an event. + + Args: + order: Order that triggered event + event: Event type + metadata: Additional event data + """ + metadata = metadata or {} + for callback in self._callbacks[event]: + try: + await callback(order, event, metadata) + except Exception: + # Don't let callback errors break order flow + pass + + def validate_order(self, request: OrderRequest) -> OrderValidationResult: + """Validate an order request. + + Args: + request: Order request to validate + + Returns: + Validation result with errors/warnings + """ + result = OrderValidationResult() + + # Validate quantity + if request.quantity <= 0: + result.valid = False + result.errors.append("Quantity must be positive") + + # Validate limit price for limit orders + if request.order_type in (OrderType.LIMIT, OrderType.STOP_LIMIT): + if request.limit_price is None: + result.valid = False + result.errors.append(f"{request.order_type.value} order requires limit_price") + elif request.limit_price <= 0: + result.valid = False + result.errors.append("Limit price must be positive") + + # Validate stop price for stop orders + if request.order_type in (OrderType.STOP, OrderType.STOP_LIMIT): + if request.stop_price is None: + result.valid = False + result.errors.append(f"{request.order_type.value} order requires stop_price") + elif request.stop_price <= 0: + result.valid = False + result.errors.append("Stop price must be positive") + + # Validate trailing stop parameters + if request.order_type == OrderType.TRAILING_STOP: + if request.trail_amount is None and request.trail_percent is None: + result.valid = False + result.errors.append("Trailing stop requires trail_amount or trail_percent") + if request.trail_amount is not None and request.trail_amount <= 0: + result.valid = False + result.errors.append("Trail amount must be positive") + if request.trail_percent is not None: + if request.trail_percent <= 0: + result.valid = False + result.errors.append("Trail percent must be positive") + elif request.trail_percent > Decimal("50"): + result.warnings.append("Trail percent > 50% may execute far from market") + + # Validate symbol + if not request.symbol or not request.symbol.strip(): + result.valid = False + result.errors.append("Symbol is required") + + # Warn about FOK/IOC with limit orders far from market + if request.time_in_force in (TimeInForce.FOK, TimeInForce.IOC): + if request.order_type == OrderType.MARKET: + result.warnings.append( + f"{request.time_in_force.value} with market order may not execute" + ) + + return result + + def is_valid_transition( + self, + from_status: OrderStatus, + to_status: OrderStatus, + ) -> bool: + """Check if a state transition is valid. + + Args: + from_status: Current status + to_status: Target status + + Returns: + True if transition is valid + """ + return to_status in VALID_TRANSITIONS.get(from_status, set()) + + def is_terminal(self, status: OrderStatus) -> bool: + """Check if a status is terminal. + + Args: + status: Status to check + + Returns: + True if status is terminal + """ + return status in TERMINAL_STATES + + def is_open(self, status: OrderStatus) -> bool: + """Check if a status means order is open. + + Args: + status: Status to check + + Returns: + True if order is open + """ + return status in OPEN_STATES + + async def submit_order( + self, + request: OrderRequest, + broker: BrokerBase, + ) -> Order: + """Submit an order through a broker. + + Args: + request: Order request + broker: Broker to submit through + + Returns: + Submitted order + + Raises: + InvalidOrderError: If validation fails + OrderError: If submission fails + """ + # Validate if enabled + if self._validate_before_submit: + validation = self.validate_order(request) + if not validation.valid: + raise InvalidOrderError( + f"Order validation failed: {'; '.join(validation.errors)}" + ) + + # Submit to broker + order = await broker.submit_order(request) + + # Track the order + async with self._lock: + self._orders[order.broker_order_id] = order + self._order_history[order.broker_order_id] = [ + OrderStateChange( + order_id=order.broker_order_id, + from_status=None, + to_status=order.status, + event=OrderEvent.SUBMITTED, + ) + ] + + # Trim old orders if at max + if len(self._orders) > self._max_orders: + # Remove oldest orders + sorted_orders = sorted( + self._orders.items(), + key=lambda x: x[1].created_at or datetime.min, + ) + for order_id, _ in sorted_orders[: len(self._orders) - self._max_orders]: + del self._orders[order_id] + self._order_history.pop(order_id, None) + + # Fire event + await self._fire_event(order, OrderEvent.SUBMITTED) + + # Fire additional events based on status + if order.status == OrderStatus.FILLED: + await self._fire_event(order, OrderEvent.FILLED) + elif order.status == OrderStatus.REJECTED: + await self._fire_event(order, OrderEvent.REJECTED) + + return order + + async def cancel_order( + self, + order_id: str, + broker: BrokerBase, + ) -> Order: + """Cancel an order. + + Args: + order_id: Order to cancel + broker: Broker to cancel through + + Returns: + Cancelled order + + Raises: + OrderError: If cancel fails + """ + order = await broker.cancel_order(order_id) + + async with self._lock: + old_order = self._orders.get(order_id) + old_status = old_order.status if old_order else None + self._orders[order_id] = order + + if order_id in self._order_history: + self._order_history[order_id].append( + OrderStateChange( + order_id=order_id, + from_status=old_status, + to_status=order.status, + event=OrderEvent.CANCELLED, + ) + ) + + await self._fire_event(order, OrderEvent.CANCELLED) + return order + + async def replace_order( + self, + order_id: str, + broker: BrokerBase, + quantity: Optional[Decimal] = None, + limit_price: Optional[Decimal] = None, + stop_price: Optional[Decimal] = None, + time_in_force: Optional[TimeInForce] = None, + ) -> Order: + """Replace an order with updated parameters. + + Args: + order_id: Order to replace + broker: Broker to replace through + quantity: New quantity + limit_price: New limit price + stop_price: New stop price + time_in_force: New time in force + + Returns: + New replacement order + """ + new_order = await broker.replace_order( + order_id, + quantity=quantity, + limit_price=limit_price, + stop_price=stop_price, + time_in_force=time_in_force, + ) + + async with self._lock: + # Mark old order as replaced + if order_id in self._orders: + old_order = self._orders[order_id] + old_order.status = OrderStatus.REPLACED + self._order_history[order_id].append( + OrderStateChange( + order_id=order_id, + from_status=old_order.status, + to_status=OrderStatus.REPLACED, + event=OrderEvent.REPLACED, + metadata={"replaced_by": new_order.broker_order_id}, + ) + ) + + # Track new order + self._orders[new_order.broker_order_id] = new_order + self._order_history[new_order.broker_order_id] = [ + OrderStateChange( + order_id=new_order.broker_order_id, + from_status=None, + to_status=new_order.status, + event=OrderEvent.SUBMITTED, + metadata={"replaces": order_id}, + ) + ] + + await self._fire_event(new_order, OrderEvent.REPLACED) + return new_order + + async def update_order_status( + self, + order: Order, + ) -> None: + """Update tracked order status. + + Called when order status changes (e.g., from broker callbacks). + + Args: + order: Order with updated status + """ + async with self._lock: + old_order = self._orders.get(order.broker_order_id) + old_status = old_order.status if old_order else None + + # Validate transition + if old_status and not self.is_valid_transition(old_status, order.status): + # Log warning but allow - broker is authoritative + pass + + self._orders[order.broker_order_id] = order + + # Record state change + event = self._status_to_event(order.status) + if order.broker_order_id in self._order_history: + self._order_history[order.broker_order_id].append( + OrderStateChange( + order_id=order.broker_order_id, + from_status=old_status, + to_status=order.status, + event=event, + ) + ) + + await self._fire_event(order, event) + + def _status_to_event(self, status: OrderStatus) -> OrderEvent: + """Convert order status to event type.""" + mapping = { + OrderStatus.PENDING_NEW: OrderEvent.SUBMITTED, + OrderStatus.NEW: OrderEvent.ACCEPTED, + OrderStatus.PARTIALLY_FILLED: OrderEvent.PARTIALLY_FILLED, + OrderStatus.FILLED: OrderEvent.FILLED, + OrderStatus.PENDING_CANCEL: OrderEvent.PENDING_CANCEL, + OrderStatus.CANCELLED: OrderEvent.CANCELLED, + OrderStatus.REJECTED: OrderEvent.REJECTED, + OrderStatus.EXPIRED: OrderEvent.EXPIRED, + OrderStatus.REPLACED: OrderEvent.REPLACED, + } + return mapping.get(status, OrderEvent.ERROR) + + def get_order(self, order_id: str) -> Optional[Order]: + """Get tracked order by ID. + + Args: + order_id: Order identifier + + Returns: + Order if found, None otherwise + """ + return self._orders.get(order_id) + + def get_orders( + self, + status: Optional[OrderStatus] = None, + symbol: Optional[str] = None, + side: Optional[OrderSide] = None, + ) -> List[Order]: + """Get tracked orders with optional filters. + + Args: + status: Filter by status + symbol: Filter by symbol + side: Filter by side + + Returns: + List of matching orders + """ + orders = list(self._orders.values()) + + if status: + orders = [o for o in orders if o.status == status] + if symbol: + orders = [o for o in orders if o.symbol == symbol] + if side: + orders = [o for o in orders if o.side == side] + + return orders + + def get_open_orders(self) -> List[Order]: + """Get all open (non-terminal) orders. + + Returns: + List of open orders + """ + return [o for o in self._orders.values() if self.is_open(o.status)] + + def get_order_history(self, order_id: str) -> List[OrderStateChange]: + """Get state change history for an order. + + Args: + order_id: Order identifier + + Returns: + List of state changes + """ + return self._order_history.get(order_id, []) + + def clear_completed_orders(self) -> int: + """Remove all terminal (completed) orders from tracking. + + Returns: + Number of orders removed + """ + to_remove = [ + order_id + for order_id, order in self._orders.items() + if self.is_terminal(order.status) + ] + + for order_id in to_remove: + del self._orders[order_id] + self._order_history.pop(order_id, None) + + return len(to_remove) + + @property + def order_count(self) -> int: + """Get number of tracked orders.""" + return len(self._orders) + + @property + def open_order_count(self) -> int: + """Get number of open orders.""" + return len(self.get_open_orders()) + + +# Export +__all__ = [ + "OrderManager", + "OrderEvent", + "OrderValidationResult", + "OrderStateChange", + "OrderEventCallback", + "VALID_TRANSITIONS", + "TERMINAL_STATES", + "OPEN_STATES", +]