feat(execution): add Order Manager for lifecycle management - Issue #27 (47 tests)

This commit is contained in:
Andrew Kaszubski 2025-12-26 21:24:54 +11:00
parent 834d18fb51
commit 6863e3ed87
3 changed files with 1327 additions and 0 deletions

View File

@ -0,0 +1,659 @@
"""Tests for Order Manager implementation.
Issue #27: [EXEC-26] Order types and manager - market, limit, stop, trailing
"""
from decimal import Decimal
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
import pytest
from tradingagents.execution import (
OrderManager,
OrderEvent,
OrderValidationResult,
OrderStateChange,
PaperBroker,
OrderRequest,
OrderSide,
OrderType,
OrderStatus,
TimeInForce,
InvalidOrderError,
VALID_TRANSITIONS,
TERMINAL_STATES,
OPEN_STATES,
)
class TestOrderValidationResult:
"""Test OrderValidationResult dataclass."""
def test_default_valid(self):
"""Test default result is valid."""
result = OrderValidationResult()
assert result.valid is True
assert result.errors == []
assert result.warnings == []
def test_invalid_with_errors(self):
"""Test result with errors."""
result = OrderValidationResult(
valid=False,
errors=["Error 1", "Error 2"],
)
assert result.valid is False
assert len(result.errors) == 2
def test_valid_with_warnings(self):
"""Test result with warnings but still valid."""
result = OrderValidationResult(
valid=True,
warnings=["Warning 1"],
)
assert result.valid is True
assert len(result.warnings) == 1
class TestOrderStateChange:
"""Test OrderStateChange dataclass."""
def test_state_change_creation(self):
"""Test creating state change record."""
change = OrderStateChange(
order_id="TEST-123",
from_status=OrderStatus.NEW,
to_status=OrderStatus.FILLED,
event=OrderEvent.FILLED,
)
assert change.order_id == "TEST-123"
assert change.from_status == OrderStatus.NEW
assert change.to_status == OrderStatus.FILLED
assert change.event == OrderEvent.FILLED
assert isinstance(change.timestamp, datetime)
def test_state_change_with_metadata(self):
"""Test state change with metadata."""
change = OrderStateChange(
order_id="TEST-123",
from_status=None,
to_status=OrderStatus.NEW,
event=OrderEvent.SUBMITTED,
metadata={"broker": "paper"},
)
assert change.metadata == {"broker": "paper"}
class TestOrderManagerInit:
"""Test OrderManager initialization."""
def test_default_initialization(self):
"""Test default initialization."""
manager = OrderManager()
assert manager.order_count == 0
assert manager.open_order_count == 0
def test_custom_max_orders(self):
"""Test initialization with custom max orders."""
manager = OrderManager(max_orders=100)
assert manager._max_orders == 100
def test_validation_disabled(self):
"""Test initialization with validation disabled."""
manager = OrderManager(validate_before_submit=False)
assert manager._validate_before_submit is False
class TestOrderValidation:
"""Test OrderManager order validation."""
def test_valid_market_order(self):
"""Test valid market order."""
manager = OrderManager()
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
result = manager.validate_order(request)
assert result.valid is True
def test_invalid_quantity_zero(self):
"""Test invalid order with zero quantity."""
manager = OrderManager()
request = OrderRequest(
symbol="AAPL",
side=OrderSide.BUY,
quantity=Decimal("0"),
)
result = manager.validate_order(request)
assert result.valid is False
assert any("positive" in e.lower() for e in result.errors)
def test_invalid_quantity_negative(self):
"""Test invalid order with negative quantity."""
manager = OrderManager()
request = OrderRequest(
symbol="AAPL",
side=OrderSide.BUY,
quantity=Decimal("-10"),
)
result = manager.validate_order(request)
assert result.valid is False
def test_limit_order_missing_price(self):
"""Test limit order without limit price raises at construction."""
# OrderRequest validates in __post_init__
with pytest.raises(ValueError, match="limit_price"):
OrderRequest(
symbol="AAPL",
side=OrderSide.BUY,
quantity=Decimal("100"),
order_type=OrderType.LIMIT,
limit_price=None,
)
def test_limit_order_with_price(self):
"""Test valid limit order."""
manager = OrderManager()
request = OrderRequest.limit(
"AAPL", OrderSide.BUY, Decimal("100"), Decimal("150.00")
)
result = manager.validate_order(request)
assert result.valid is True
def test_stop_order_missing_stop_price(self):
"""Test stop order without stop price raises at construction."""
# OrderRequest validates in __post_init__
with pytest.raises(ValueError, match="stop_price"):
OrderRequest(
symbol="AAPL",
side=OrderSide.SELL,
quantity=Decimal("100"),
order_type=OrderType.STOP,
stop_price=None,
)
def test_stop_limit_order_missing_both(self):
"""Test stop limit order missing both prices raises at construction."""
# OrderRequest validates in __post_init__
with pytest.raises(ValueError, match="stop_price|limit_price"):
OrderRequest(
symbol="AAPL",
side=OrderSide.SELL,
quantity=Decimal("100"),
order_type=OrderType.STOP_LIMIT,
)
def test_trailing_stop_missing_trail(self):
"""Test trailing stop without trail parameters raises at construction."""
# OrderRequest validates in __post_init__
with pytest.raises(ValueError, match="trail"):
OrderRequest(
symbol="AAPL",
side=OrderSide.SELL,
quantity=Decimal("100"),
order_type=OrderType.TRAILING_STOP,
)
def test_trailing_stop_with_percent(self):
"""Test valid trailing stop with percent."""
manager = OrderManager()
request = OrderRequest(
symbol="AAPL",
side=OrderSide.SELL,
quantity=Decimal("100"),
order_type=OrderType.TRAILING_STOP,
trail_percent=Decimal("5.0"),
)
result = manager.validate_order(request)
assert result.valid is True
def test_trailing_stop_high_percent_warning(self):
"""Test trailing stop with high percent warns."""
manager = OrderManager()
request = OrderRequest(
symbol="AAPL",
side=OrderSide.SELL,
quantity=Decimal("100"),
order_type=OrderType.TRAILING_STOP,
trail_percent=Decimal("60.0"),
)
result = manager.validate_order(request)
assert result.valid is True
assert len(result.warnings) > 0
def test_empty_symbol(self):
"""Test order with empty symbol."""
manager = OrderManager()
request = OrderRequest(
symbol="",
side=OrderSide.BUY,
quantity=Decimal("100"),
)
result = manager.validate_order(request)
assert result.valid is False
assert any("symbol" in e.lower() for e in result.errors)
class TestStateTransitions:
"""Test order state machine transitions."""
def test_valid_transitions_defined(self):
"""Test all statuses have defined transitions."""
for status in OrderStatus:
assert status in VALID_TRANSITIONS
def test_terminal_states_immutable(self):
"""Test terminal states have no valid transitions."""
for status in TERMINAL_STATES:
assert len(VALID_TRANSITIONS[status]) == 0
def test_open_states_have_transitions(self):
"""Test open states have transitions."""
for status in OPEN_STATES:
assert len(VALID_TRANSITIONS[status]) > 0
def test_is_valid_transition(self):
"""Test valid transition checking."""
manager = OrderManager()
assert manager.is_valid_transition(OrderStatus.NEW, OrderStatus.FILLED)
assert manager.is_valid_transition(OrderStatus.NEW, OrderStatus.CANCELLED)
assert not manager.is_valid_transition(OrderStatus.FILLED, OrderStatus.NEW)
def test_is_terminal(self):
"""Test terminal status checking."""
manager = OrderManager()
assert manager.is_terminal(OrderStatus.FILLED) is True
assert manager.is_terminal(OrderStatus.CANCELLED) is True
assert manager.is_terminal(OrderStatus.NEW) is False
def test_is_open(self):
"""Test open status checking."""
manager = OrderManager()
assert manager.is_open(OrderStatus.NEW) is True
assert manager.is_open(OrderStatus.PARTIALLY_FILLED) is True
assert manager.is_open(OrderStatus.FILLED) is False
class TestOrderSubmission:
"""Test OrderManager order submission."""
@pytest.mark.asyncio
async def test_submit_order(self):
"""Test basic order submission."""
manager = OrderManager()
broker = PaperBroker()
broker.set_price("AAPL", Decimal("100"))
await broker.connect()
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
order = await manager.submit_order(request, broker)
assert order.symbol == "AAPL"
assert order.status == OrderStatus.FILLED
assert manager.order_count == 1
@pytest.mark.asyncio
async def test_submit_order_validation_fails(self):
"""Test submission fails on invalid order."""
manager = OrderManager()
broker = PaperBroker()
await broker.connect()
request = OrderRequest(
symbol="AAPL",
side=OrderSide.BUY,
quantity=Decimal("-10"),
)
with pytest.raises(InvalidOrderError):
await manager.submit_order(request, broker)
@pytest.mark.asyncio
async def test_submit_order_validation_disabled(self):
"""Test submission with validation disabled."""
manager = OrderManager(validate_before_submit=False)
broker = PaperBroker()
broker.set_price("AAPL", Decimal("100"))
await broker.connect()
# This would normally fail validation but passes with disabled
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
order = await manager.submit_order(request, broker)
assert order is not None
@pytest.mark.asyncio
async def test_order_tracking(self):
"""Test order is tracked after submission."""
manager = OrderManager()
broker = PaperBroker()
broker.set_price("AAPL", Decimal("100"))
await broker.connect()
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
order = await manager.submit_order(request, broker)
tracked = manager.get_order(order.broker_order_id)
assert tracked is not None
assert tracked.broker_order_id == order.broker_order_id
@pytest.mark.asyncio
async def test_order_history_recorded(self):
"""Test order history is recorded."""
manager = OrderManager()
broker = PaperBroker()
broker.set_price("AAPL", Decimal("100"))
await broker.connect()
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
order = await manager.submit_order(request, broker)
history = manager.get_order_history(order.broker_order_id)
assert len(history) >= 1
assert history[0].event == OrderEvent.SUBMITTED
class TestOrderCallbacks:
"""Test OrderManager event callbacks."""
@pytest.mark.asyncio
async def test_register_callback(self):
"""Test registering callback."""
manager = OrderManager()
callback_called = []
async def callback(order, event, metadata):
callback_called.append((order, event))
manager.register_callback(OrderEvent.FILLED, callback)
broker = PaperBroker()
broker.set_price("AAPL", Decimal("100"))
await broker.connect()
await manager.submit_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")),
broker,
)
assert len(callback_called) == 1
assert callback_called[0][1] == OrderEvent.FILLED
@pytest.mark.asyncio
async def test_unregister_callback(self):
"""Test unregistering callback."""
manager = OrderManager()
callback_called = []
async def callback(order, event, metadata):
callback_called.append(True)
manager.register_callback(OrderEvent.FILLED, callback)
manager.unregister_callback(OrderEvent.FILLED, callback)
broker = PaperBroker()
broker.set_price("AAPL", Decimal("100"))
await broker.connect()
await manager.submit_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")),
broker,
)
assert len(callback_called) == 0
@pytest.mark.asyncio
async def test_callback_error_doesnt_break_flow(self):
"""Test callback error doesn't break order flow."""
manager = OrderManager()
async def bad_callback(order, event, metadata):
raise Exception("Callback error")
manager.register_callback(OrderEvent.FILLED, bad_callback)
broker = PaperBroker()
broker.set_price("AAPL", Decimal("100"))
await broker.connect()
# Should not raise despite callback error
order = await manager.submit_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")),
broker,
)
assert order is not None
class TestOrderCancellation:
"""Test OrderManager order cancellation."""
@pytest.mark.asyncio
async def test_cancel_order(self):
"""Test cancelling an order."""
manager = OrderManager()
broker = PaperBroker(fill_probability=0.0)
await broker.connect()
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
order = await manager.submit_order(request, broker)
cancelled = await manager.cancel_order(order.broker_order_id, broker)
assert cancelled.status == OrderStatus.CANCELLED
@pytest.mark.asyncio
async def test_cancel_updates_tracking(self):
"""Test cancellation updates tracked order."""
manager = OrderManager()
broker = PaperBroker(fill_probability=0.0)
await broker.connect()
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
order = await manager.submit_order(request, broker)
await manager.cancel_order(order.broker_order_id, broker)
tracked = manager.get_order(order.broker_order_id)
assert tracked.status == OrderStatus.CANCELLED
class TestOrderReplacement:
"""Test OrderManager order replacement."""
@pytest.mark.asyncio
async def test_replace_order(self):
"""Test replacing an order."""
manager = OrderManager()
broker = PaperBroker(fill_probability=0.0)
await broker.connect()
request = OrderRequest.limit(
"AAPL", OrderSide.BUY, Decimal("10"), Decimal("100")
)
order = await manager.submit_order(request, broker)
new_order = await manager.replace_order(
order.broker_order_id,
broker,
quantity=Decimal("20"),
)
assert new_order.quantity == Decimal("20")
assert new_order.broker_order_id != order.broker_order_id
class TestOrderQueries:
"""Test OrderManager query methods."""
@pytest.mark.asyncio
async def test_get_orders_all(self):
"""Test getting all orders."""
manager = OrderManager()
broker = PaperBroker()
broker.set_price("AAPL", Decimal("10"))
broker.set_price("MSFT", Decimal("10"))
await broker.connect()
await manager.submit_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
)
await manager.submit_order(
OrderRequest.market("MSFT", OrderSide.BUY, Decimal("10")), broker
)
orders = manager.get_orders()
assert len(orders) == 2
@pytest.mark.asyncio
async def test_get_orders_by_symbol(self):
"""Test filtering orders by symbol."""
manager = OrderManager()
broker = PaperBroker()
broker.set_price("AAPL", Decimal("10"))
broker.set_price("MSFT", Decimal("10"))
await broker.connect()
await manager.submit_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
)
await manager.submit_order(
OrderRequest.market("MSFT", OrderSide.BUY, Decimal("10")), broker
)
orders = manager.get_orders(symbol="AAPL")
assert len(orders) == 1
assert orders[0].symbol == "AAPL"
@pytest.mark.asyncio
async def test_get_orders_by_status(self):
"""Test filtering orders by status."""
manager = OrderManager()
broker = PaperBroker()
broker.set_price("AAPL", Decimal("10"))
await broker.connect()
# Create filled order
await manager.submit_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
)
# Create unfilled order
broker._fill_probability = 0.0
await manager.submit_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
)
filled_orders = manager.get_orders(status=OrderStatus.FILLED)
assert len(filled_orders) == 1
@pytest.mark.asyncio
async def test_get_open_orders(self):
"""Test getting open orders only."""
manager = OrderManager()
broker = PaperBroker(fill_probability=0.0)
await broker.connect()
await manager.submit_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
)
open_orders = manager.get_open_orders()
assert len(open_orders) == 1
class TestOrderCleanup:
"""Test OrderManager order cleanup."""
@pytest.mark.asyncio
async def test_clear_completed_orders(self):
"""Test clearing completed orders."""
manager = OrderManager()
broker = PaperBroker()
broker.set_price("AAPL", Decimal("10"))
await broker.connect()
# Create filled orders
await manager.submit_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
)
await manager.submit_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
)
assert manager.order_count == 2
removed = manager.clear_completed_orders()
assert removed == 2
assert manager.order_count == 0
@pytest.mark.asyncio
async def test_max_orders_limit(self):
"""Test orders are trimmed when max reached."""
manager = OrderManager(max_orders=5)
broker = PaperBroker()
broker.set_price("AAPL", Decimal("1"))
await broker.connect()
# Submit more than max orders
for _ in range(10):
await manager.submit_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1")), broker
)
assert manager.order_count <= 5
class TestOrderStatusUpdate:
"""Test OrderManager status updates."""
@pytest.mark.asyncio
async def test_update_order_status(self):
"""Test updating order status."""
manager = OrderManager()
broker = PaperBroker(fill_probability=0.0)
await broker.connect()
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
order = await manager.submit_order(request, broker)
# Simulate status change
order.status = OrderStatus.FILLED
order.filled_quantity = order.quantity
await manager.update_order_status(order)
tracked = manager.get_order(order.broker_order_id)
assert tracked.status == OrderStatus.FILLED
class TestOrderEvents:
"""Test OrderEvent enum."""
def test_all_events_defined(self):
"""Test all expected events are defined."""
expected_events = [
"CREATED", "SUBMITTED", "ACCEPTED", "REJECTED",
"PARTIALLY_FILLED", "FILLED", "PENDING_CANCEL",
"CANCELLED", "REPLACED", "EXPIRED", "ERROR"
]
for event_name in expected_events:
assert hasattr(OrderEvent, event_name)
class TestStateConstants:
"""Test state machine constants."""
def test_terminal_states_complete(self):
"""Test all terminal states are included."""
assert OrderStatus.FILLED in TERMINAL_STATES
assert OrderStatus.CANCELLED in TERMINAL_STATES
assert OrderStatus.REJECTED in TERMINAL_STATES
assert OrderStatus.EXPIRED in TERMINAL_STATES
assert OrderStatus.REPLACED in TERMINAL_STATES
def test_open_states_complete(self):
"""Test all open states are included."""
assert OrderStatus.PENDING_NEW in OPEN_STATES
assert OrderStatus.NEW in OPEN_STATES
assert OrderStatus.PARTIALLY_FILLED in OPEN_STATES
assert OrderStatus.PENDING_CANCEL in OPEN_STATES
def test_no_overlap(self):
"""Test terminal and open states don't overlap."""
overlap = TERMINAL_STATES & OPEN_STATES
assert len(overlap) == 0

View File

@ -114,6 +114,16 @@ from .ibkr_broker import (
from .paper_broker import PaperBroker
from .order_manager import (
OrderManager,
OrderEvent,
OrderValidationResult,
OrderStateChange,
VALID_TRANSITIONS,
TERMINAL_STATES,
OPEN_STATES,
)
__all__ = [
# Enums
"AssetClass",
@ -158,4 +168,12 @@ __all__ = [
"FUTURES_SPECS",
# Paper Broker
"PaperBroker",
# Order Manager
"OrderManager",
"OrderEvent",
"OrderValidationResult",
"OrderStateChange",
"VALID_TRANSITIONS",
"TERMINAL_STATES",
"OPEN_STATES",
]

View File

@ -0,0 +1,650 @@
"""Order Manager for order lifecycle management.
Issue #27: [EXEC-26] Order types and manager - market, limit, stop, trailing
This module provides order lifecycle management including validation,
state transitions, and event notifications.
Features:
- Order validation before submission
- Order state machine with valid transitions
- Order tracking and retrieval
- Event callbacks for order state changes
- Support for all order types: market, limit, stop, stop_limit, trailing_stop
Example:
>>> from tradingagents.execution import OrderManager, OrderRequest, OrderSide
>>>
>>> manager = OrderManager()
>>> request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
>>> order = await manager.submit_order(request, broker)
>>> print(f"Order {order.broker_order_id} status: {order.status}")
"""
from __future__ import annotations
import asyncio
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Awaitable
from .broker_base import (
BrokerBase,
Order,
OrderError,
OrderRequest,
OrderSide,
OrderStatus,
OrderType,
TimeInForce,
InvalidOrderError,
)
class OrderEvent(Enum):
"""Order lifecycle events."""
CREATED = "created"
SUBMITTED = "submitted"
ACCEPTED = "accepted"
REJECTED = "rejected"
PARTIALLY_FILLED = "partially_filled"
FILLED = "filled"
PENDING_CANCEL = "pending_cancel"
CANCELLED = "cancelled"
REPLACED = "replaced"
EXPIRED = "expired"
ERROR = "error"
# Valid state transitions for order state machine
VALID_TRANSITIONS: Dict[OrderStatus, Set[OrderStatus]] = {
OrderStatus.PENDING_NEW: {
OrderStatus.NEW,
OrderStatus.REJECTED,
OrderStatus.CANCELLED,
},
OrderStatus.NEW: {
OrderStatus.PARTIALLY_FILLED,
OrderStatus.FILLED,
OrderStatus.PENDING_CANCEL,
OrderStatus.CANCELLED,
OrderStatus.EXPIRED,
OrderStatus.REPLACED,
},
OrderStatus.PARTIALLY_FILLED: {
OrderStatus.PARTIALLY_FILLED,
OrderStatus.FILLED,
OrderStatus.PENDING_CANCEL,
OrderStatus.CANCELLED,
},
OrderStatus.FILLED: set(), # Terminal state
OrderStatus.PENDING_CANCEL: {
OrderStatus.CANCELLED,
OrderStatus.FILLED, # Can fill while cancel is pending
OrderStatus.PARTIALLY_FILLED,
},
OrderStatus.CANCELLED: set(), # Terminal state
OrderStatus.REJECTED: set(), # Terminal state
OrderStatus.EXPIRED: set(), # Terminal state
OrderStatus.REPLACED: set(), # Terminal state
}
# Terminal states (order cannot change after reaching these)
TERMINAL_STATES: Set[OrderStatus] = {
OrderStatus.FILLED,
OrderStatus.CANCELLED,
OrderStatus.REJECTED,
OrderStatus.EXPIRED,
OrderStatus.REPLACED,
}
# Open states (order can still be filled or cancelled)
OPEN_STATES: Set[OrderStatus] = {
OrderStatus.PENDING_NEW,
OrderStatus.NEW,
OrderStatus.PARTIALLY_FILLED,
OrderStatus.PENDING_CANCEL,
}
@dataclass
class OrderValidationResult:
"""Result of order validation.
Attributes:
valid: Whether the order is valid
errors: List of validation error messages
warnings: List of validation warning messages
"""
valid: bool = True
errors: List[str] = field(default_factory=list)
warnings: List[str] = field(default_factory=list)
@dataclass
class OrderStateChange:
"""Record of an order state change.
Attributes:
order_id: Order identifier
from_status: Previous status
to_status: New status
event: Event that triggered the change
timestamp: When the change occurred
metadata: Additional change details
"""
order_id: str
from_status: Optional[OrderStatus]
to_status: OrderStatus
event: OrderEvent
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
metadata: Dict[str, Any] = field(default_factory=dict)
# Callback type for order events
OrderEventCallback = Callable[[Order, OrderEvent, Dict[str, Any]], Awaitable[None]]
class OrderManager:
"""Manages order lifecycle and state transitions.
The OrderManager provides:
- Order validation before submission
- Order state machine with valid transitions
- Order tracking and retrieval
- Event callbacks for order state changes
- Order history and audit trail
Example:
>>> manager = OrderManager()
>>>
>>> # Register callbacks
>>> async def on_fill(order, event, metadata):
... print(f"Order {order.broker_order_id} filled!")
>>> manager.register_callback(OrderEvent.FILLED, on_fill)
>>>
>>> # Submit order
>>> order = await manager.submit_order(request, broker)
"""
def __init__(
self,
max_orders: int = 10000,
validate_before_submit: bool = True,
) -> None:
"""Initialize order manager.
Args:
max_orders: Maximum orders to track (oldest removed when exceeded)
validate_before_submit: Whether to validate orders before submission
"""
self._orders: Dict[str, Order] = {}
self._order_history: Dict[str, List[OrderStateChange]] = {}
self._callbacks: Dict[OrderEvent, List[OrderEventCallback]] = {
event: [] for event in OrderEvent
}
self._max_orders = max_orders
self._validate_before_submit = validate_before_submit
self._lock = asyncio.Lock()
def register_callback(
self,
event: OrderEvent,
callback: OrderEventCallback,
) -> None:
"""Register a callback for an order event.
Args:
event: Event to listen for
callback: Async callback function(order, event, metadata)
"""
self._callbacks[event].append(callback)
def unregister_callback(
self,
event: OrderEvent,
callback: OrderEventCallback,
) -> None:
"""Unregister a callback.
Args:
event: Event type
callback: Callback to remove
"""
if callback in self._callbacks[event]:
self._callbacks[event].remove(callback)
async def _fire_event(
self,
order: Order,
event: OrderEvent,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Fire callbacks for an event.
Args:
order: Order that triggered event
event: Event type
metadata: Additional event data
"""
metadata = metadata or {}
for callback in self._callbacks[event]:
try:
await callback(order, event, metadata)
except Exception:
# Don't let callback errors break order flow
pass
def validate_order(self, request: OrderRequest) -> OrderValidationResult:
"""Validate an order request.
Args:
request: Order request to validate
Returns:
Validation result with errors/warnings
"""
result = OrderValidationResult()
# Validate quantity
if request.quantity <= 0:
result.valid = False
result.errors.append("Quantity must be positive")
# Validate limit price for limit orders
if request.order_type in (OrderType.LIMIT, OrderType.STOP_LIMIT):
if request.limit_price is None:
result.valid = False
result.errors.append(f"{request.order_type.value} order requires limit_price")
elif request.limit_price <= 0:
result.valid = False
result.errors.append("Limit price must be positive")
# Validate stop price for stop orders
if request.order_type in (OrderType.STOP, OrderType.STOP_LIMIT):
if request.stop_price is None:
result.valid = False
result.errors.append(f"{request.order_type.value} order requires stop_price")
elif request.stop_price <= 0:
result.valid = False
result.errors.append("Stop price must be positive")
# Validate trailing stop parameters
if request.order_type == OrderType.TRAILING_STOP:
if request.trail_amount is None and request.trail_percent is None:
result.valid = False
result.errors.append("Trailing stop requires trail_amount or trail_percent")
if request.trail_amount is not None and request.trail_amount <= 0:
result.valid = False
result.errors.append("Trail amount must be positive")
if request.trail_percent is not None:
if request.trail_percent <= 0:
result.valid = False
result.errors.append("Trail percent must be positive")
elif request.trail_percent > Decimal("50"):
result.warnings.append("Trail percent > 50% may execute far from market")
# Validate symbol
if not request.symbol or not request.symbol.strip():
result.valid = False
result.errors.append("Symbol is required")
# Warn about FOK/IOC with limit orders far from market
if request.time_in_force in (TimeInForce.FOK, TimeInForce.IOC):
if request.order_type == OrderType.MARKET:
result.warnings.append(
f"{request.time_in_force.value} with market order may not execute"
)
return result
def is_valid_transition(
self,
from_status: OrderStatus,
to_status: OrderStatus,
) -> bool:
"""Check if a state transition is valid.
Args:
from_status: Current status
to_status: Target status
Returns:
True if transition is valid
"""
return to_status in VALID_TRANSITIONS.get(from_status, set())
def is_terminal(self, status: OrderStatus) -> bool:
"""Check if a status is terminal.
Args:
status: Status to check
Returns:
True if status is terminal
"""
return status in TERMINAL_STATES
def is_open(self, status: OrderStatus) -> bool:
"""Check if a status means order is open.
Args:
status: Status to check
Returns:
True if order is open
"""
return status in OPEN_STATES
async def submit_order(
self,
request: OrderRequest,
broker: BrokerBase,
) -> Order:
"""Submit an order through a broker.
Args:
request: Order request
broker: Broker to submit through
Returns:
Submitted order
Raises:
InvalidOrderError: If validation fails
OrderError: If submission fails
"""
# Validate if enabled
if self._validate_before_submit:
validation = self.validate_order(request)
if not validation.valid:
raise InvalidOrderError(
f"Order validation failed: {'; '.join(validation.errors)}"
)
# Submit to broker
order = await broker.submit_order(request)
# Track the order
async with self._lock:
self._orders[order.broker_order_id] = order
self._order_history[order.broker_order_id] = [
OrderStateChange(
order_id=order.broker_order_id,
from_status=None,
to_status=order.status,
event=OrderEvent.SUBMITTED,
)
]
# Trim old orders if at max
if len(self._orders) > self._max_orders:
# Remove oldest orders
sorted_orders = sorted(
self._orders.items(),
key=lambda x: x[1].created_at or datetime.min,
)
for order_id, _ in sorted_orders[: len(self._orders) - self._max_orders]:
del self._orders[order_id]
self._order_history.pop(order_id, None)
# Fire event
await self._fire_event(order, OrderEvent.SUBMITTED)
# Fire additional events based on status
if order.status == OrderStatus.FILLED:
await self._fire_event(order, OrderEvent.FILLED)
elif order.status == OrderStatus.REJECTED:
await self._fire_event(order, OrderEvent.REJECTED)
return order
async def cancel_order(
self,
order_id: str,
broker: BrokerBase,
) -> Order:
"""Cancel an order.
Args:
order_id: Order to cancel
broker: Broker to cancel through
Returns:
Cancelled order
Raises:
OrderError: If cancel fails
"""
order = await broker.cancel_order(order_id)
async with self._lock:
old_order = self._orders.get(order_id)
old_status = old_order.status if old_order else None
self._orders[order_id] = order
if order_id in self._order_history:
self._order_history[order_id].append(
OrderStateChange(
order_id=order_id,
from_status=old_status,
to_status=order.status,
event=OrderEvent.CANCELLED,
)
)
await self._fire_event(order, OrderEvent.CANCELLED)
return order
async def replace_order(
self,
order_id: str,
broker: BrokerBase,
quantity: Optional[Decimal] = None,
limit_price: Optional[Decimal] = None,
stop_price: Optional[Decimal] = None,
time_in_force: Optional[TimeInForce] = None,
) -> Order:
"""Replace an order with updated parameters.
Args:
order_id: Order to replace
broker: Broker to replace through
quantity: New quantity
limit_price: New limit price
stop_price: New stop price
time_in_force: New time in force
Returns:
New replacement order
"""
new_order = await broker.replace_order(
order_id,
quantity=quantity,
limit_price=limit_price,
stop_price=stop_price,
time_in_force=time_in_force,
)
async with self._lock:
# Mark old order as replaced
if order_id in self._orders:
old_order = self._orders[order_id]
old_order.status = OrderStatus.REPLACED
self._order_history[order_id].append(
OrderStateChange(
order_id=order_id,
from_status=old_order.status,
to_status=OrderStatus.REPLACED,
event=OrderEvent.REPLACED,
metadata={"replaced_by": new_order.broker_order_id},
)
)
# Track new order
self._orders[new_order.broker_order_id] = new_order
self._order_history[new_order.broker_order_id] = [
OrderStateChange(
order_id=new_order.broker_order_id,
from_status=None,
to_status=new_order.status,
event=OrderEvent.SUBMITTED,
metadata={"replaces": order_id},
)
]
await self._fire_event(new_order, OrderEvent.REPLACED)
return new_order
async def update_order_status(
self,
order: Order,
) -> None:
"""Update tracked order status.
Called when order status changes (e.g., from broker callbacks).
Args:
order: Order with updated status
"""
async with self._lock:
old_order = self._orders.get(order.broker_order_id)
old_status = old_order.status if old_order else None
# Validate transition
if old_status and not self.is_valid_transition(old_status, order.status):
# Log warning but allow - broker is authoritative
pass
self._orders[order.broker_order_id] = order
# Record state change
event = self._status_to_event(order.status)
if order.broker_order_id in self._order_history:
self._order_history[order.broker_order_id].append(
OrderStateChange(
order_id=order.broker_order_id,
from_status=old_status,
to_status=order.status,
event=event,
)
)
await self._fire_event(order, event)
def _status_to_event(self, status: OrderStatus) -> OrderEvent:
"""Convert order status to event type."""
mapping = {
OrderStatus.PENDING_NEW: OrderEvent.SUBMITTED,
OrderStatus.NEW: OrderEvent.ACCEPTED,
OrderStatus.PARTIALLY_FILLED: OrderEvent.PARTIALLY_FILLED,
OrderStatus.FILLED: OrderEvent.FILLED,
OrderStatus.PENDING_CANCEL: OrderEvent.PENDING_CANCEL,
OrderStatus.CANCELLED: OrderEvent.CANCELLED,
OrderStatus.REJECTED: OrderEvent.REJECTED,
OrderStatus.EXPIRED: OrderEvent.EXPIRED,
OrderStatus.REPLACED: OrderEvent.REPLACED,
}
return mapping.get(status, OrderEvent.ERROR)
def get_order(self, order_id: str) -> Optional[Order]:
"""Get tracked order by ID.
Args:
order_id: Order identifier
Returns:
Order if found, None otherwise
"""
return self._orders.get(order_id)
def get_orders(
self,
status: Optional[OrderStatus] = None,
symbol: Optional[str] = None,
side: Optional[OrderSide] = None,
) -> List[Order]:
"""Get tracked orders with optional filters.
Args:
status: Filter by status
symbol: Filter by symbol
side: Filter by side
Returns:
List of matching orders
"""
orders = list(self._orders.values())
if status:
orders = [o for o in orders if o.status == status]
if symbol:
orders = [o for o in orders if o.symbol == symbol]
if side:
orders = [o for o in orders if o.side == side]
return orders
def get_open_orders(self) -> List[Order]:
"""Get all open (non-terminal) orders.
Returns:
List of open orders
"""
return [o for o in self._orders.values() if self.is_open(o.status)]
def get_order_history(self, order_id: str) -> List[OrderStateChange]:
"""Get state change history for an order.
Args:
order_id: Order identifier
Returns:
List of state changes
"""
return self._order_history.get(order_id, [])
def clear_completed_orders(self) -> int:
"""Remove all terminal (completed) orders from tracking.
Returns:
Number of orders removed
"""
to_remove = [
order_id
for order_id, order in self._orders.items()
if self.is_terminal(order.status)
]
for order_id in to_remove:
del self._orders[order_id]
self._order_history.pop(order_id, None)
return len(to_remove)
@property
def order_count(self) -> int:
"""Get number of tracked orders."""
return len(self._orders)
@property
def open_order_count(self) -> int:
"""Get number of open orders."""
return len(self.get_open_orders())
# Export
__all__ = [
"OrderManager",
"OrderEvent",
"OrderValidationResult",
"OrderStateChange",
"OrderEventCallback",
"VALID_TRANSITIONS",
"TERMINAL_STATES",
"OPEN_STATES",
]