feat(execution): add Order Manager for lifecycle management - Issue #27 (47 tests)
This commit is contained in:
parent
834d18fb51
commit
6863e3ed87
|
|
@ -0,0 +1,659 @@
|
||||||
|
"""Tests for Order Manager implementation.
|
||||||
|
|
||||||
|
Issue #27: [EXEC-26] Order types and manager - market, limit, stop, trailing
|
||||||
|
"""
|
||||||
|
|
||||||
|
from decimal import Decimal
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.execution import (
|
||||||
|
OrderManager,
|
||||||
|
OrderEvent,
|
||||||
|
OrderValidationResult,
|
||||||
|
OrderStateChange,
|
||||||
|
PaperBroker,
|
||||||
|
OrderRequest,
|
||||||
|
OrderSide,
|
||||||
|
OrderType,
|
||||||
|
OrderStatus,
|
||||||
|
TimeInForce,
|
||||||
|
InvalidOrderError,
|
||||||
|
VALID_TRANSITIONS,
|
||||||
|
TERMINAL_STATES,
|
||||||
|
OPEN_STATES,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderValidationResult:
|
||||||
|
"""Test OrderValidationResult dataclass."""
|
||||||
|
|
||||||
|
def test_default_valid(self):
|
||||||
|
"""Test default result is valid."""
|
||||||
|
result = OrderValidationResult()
|
||||||
|
assert result.valid is True
|
||||||
|
assert result.errors == []
|
||||||
|
assert result.warnings == []
|
||||||
|
|
||||||
|
def test_invalid_with_errors(self):
|
||||||
|
"""Test result with errors."""
|
||||||
|
result = OrderValidationResult(
|
||||||
|
valid=False,
|
||||||
|
errors=["Error 1", "Error 2"],
|
||||||
|
)
|
||||||
|
assert result.valid is False
|
||||||
|
assert len(result.errors) == 2
|
||||||
|
|
||||||
|
def test_valid_with_warnings(self):
|
||||||
|
"""Test result with warnings but still valid."""
|
||||||
|
result = OrderValidationResult(
|
||||||
|
valid=True,
|
||||||
|
warnings=["Warning 1"],
|
||||||
|
)
|
||||||
|
assert result.valid is True
|
||||||
|
assert len(result.warnings) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderStateChange:
|
||||||
|
"""Test OrderStateChange dataclass."""
|
||||||
|
|
||||||
|
def test_state_change_creation(self):
|
||||||
|
"""Test creating state change record."""
|
||||||
|
change = OrderStateChange(
|
||||||
|
order_id="TEST-123",
|
||||||
|
from_status=OrderStatus.NEW,
|
||||||
|
to_status=OrderStatus.FILLED,
|
||||||
|
event=OrderEvent.FILLED,
|
||||||
|
)
|
||||||
|
assert change.order_id == "TEST-123"
|
||||||
|
assert change.from_status == OrderStatus.NEW
|
||||||
|
assert change.to_status == OrderStatus.FILLED
|
||||||
|
assert change.event == OrderEvent.FILLED
|
||||||
|
assert isinstance(change.timestamp, datetime)
|
||||||
|
|
||||||
|
def test_state_change_with_metadata(self):
|
||||||
|
"""Test state change with metadata."""
|
||||||
|
change = OrderStateChange(
|
||||||
|
order_id="TEST-123",
|
||||||
|
from_status=None,
|
||||||
|
to_status=OrderStatus.NEW,
|
||||||
|
event=OrderEvent.SUBMITTED,
|
||||||
|
metadata={"broker": "paper"},
|
||||||
|
)
|
||||||
|
assert change.metadata == {"broker": "paper"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderManagerInit:
|
||||||
|
"""Test OrderManager initialization."""
|
||||||
|
|
||||||
|
def test_default_initialization(self):
|
||||||
|
"""Test default initialization."""
|
||||||
|
manager = OrderManager()
|
||||||
|
assert manager.order_count == 0
|
||||||
|
assert manager.open_order_count == 0
|
||||||
|
|
||||||
|
def test_custom_max_orders(self):
|
||||||
|
"""Test initialization with custom max orders."""
|
||||||
|
manager = OrderManager(max_orders=100)
|
||||||
|
assert manager._max_orders == 100
|
||||||
|
|
||||||
|
def test_validation_disabled(self):
|
||||||
|
"""Test initialization with validation disabled."""
|
||||||
|
manager = OrderManager(validate_before_submit=False)
|
||||||
|
assert manager._validate_before_submit is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderValidation:
|
||||||
|
"""Test OrderManager order validation."""
|
||||||
|
|
||||||
|
def test_valid_market_order(self):
|
||||||
|
"""Test valid market order."""
|
||||||
|
manager = OrderManager()
|
||||||
|
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
|
||||||
|
result = manager.validate_order(request)
|
||||||
|
assert result.valid is True
|
||||||
|
|
||||||
|
def test_invalid_quantity_zero(self):
|
||||||
|
"""Test invalid order with zero quantity."""
|
||||||
|
manager = OrderManager()
|
||||||
|
request = OrderRequest(
|
||||||
|
symbol="AAPL",
|
||||||
|
side=OrderSide.BUY,
|
||||||
|
quantity=Decimal("0"),
|
||||||
|
)
|
||||||
|
result = manager.validate_order(request)
|
||||||
|
assert result.valid is False
|
||||||
|
assert any("positive" in e.lower() for e in result.errors)
|
||||||
|
|
||||||
|
def test_invalid_quantity_negative(self):
|
||||||
|
"""Test invalid order with negative quantity."""
|
||||||
|
manager = OrderManager()
|
||||||
|
request = OrderRequest(
|
||||||
|
symbol="AAPL",
|
||||||
|
side=OrderSide.BUY,
|
||||||
|
quantity=Decimal("-10"),
|
||||||
|
)
|
||||||
|
result = manager.validate_order(request)
|
||||||
|
assert result.valid is False
|
||||||
|
|
||||||
|
def test_limit_order_missing_price(self):
|
||||||
|
"""Test limit order without limit price raises at construction."""
|
||||||
|
# OrderRequest validates in __post_init__
|
||||||
|
with pytest.raises(ValueError, match="limit_price"):
|
||||||
|
OrderRequest(
|
||||||
|
symbol="AAPL",
|
||||||
|
side=OrderSide.BUY,
|
||||||
|
quantity=Decimal("100"),
|
||||||
|
order_type=OrderType.LIMIT,
|
||||||
|
limit_price=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_limit_order_with_price(self):
|
||||||
|
"""Test valid limit order."""
|
||||||
|
manager = OrderManager()
|
||||||
|
request = OrderRequest.limit(
|
||||||
|
"AAPL", OrderSide.BUY, Decimal("100"), Decimal("150.00")
|
||||||
|
)
|
||||||
|
result = manager.validate_order(request)
|
||||||
|
assert result.valid is True
|
||||||
|
|
||||||
|
def test_stop_order_missing_stop_price(self):
|
||||||
|
"""Test stop order without stop price raises at construction."""
|
||||||
|
# OrderRequest validates in __post_init__
|
||||||
|
with pytest.raises(ValueError, match="stop_price"):
|
||||||
|
OrderRequest(
|
||||||
|
symbol="AAPL",
|
||||||
|
side=OrderSide.SELL,
|
||||||
|
quantity=Decimal("100"),
|
||||||
|
order_type=OrderType.STOP,
|
||||||
|
stop_price=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_stop_limit_order_missing_both(self):
|
||||||
|
"""Test stop limit order missing both prices raises at construction."""
|
||||||
|
# OrderRequest validates in __post_init__
|
||||||
|
with pytest.raises(ValueError, match="stop_price|limit_price"):
|
||||||
|
OrderRequest(
|
||||||
|
symbol="AAPL",
|
||||||
|
side=OrderSide.SELL,
|
||||||
|
quantity=Decimal("100"),
|
||||||
|
order_type=OrderType.STOP_LIMIT,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_trailing_stop_missing_trail(self):
|
||||||
|
"""Test trailing stop without trail parameters raises at construction."""
|
||||||
|
# OrderRequest validates in __post_init__
|
||||||
|
with pytest.raises(ValueError, match="trail"):
|
||||||
|
OrderRequest(
|
||||||
|
symbol="AAPL",
|
||||||
|
side=OrderSide.SELL,
|
||||||
|
quantity=Decimal("100"),
|
||||||
|
order_type=OrderType.TRAILING_STOP,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_trailing_stop_with_percent(self):
|
||||||
|
"""Test valid trailing stop with percent."""
|
||||||
|
manager = OrderManager()
|
||||||
|
request = OrderRequest(
|
||||||
|
symbol="AAPL",
|
||||||
|
side=OrderSide.SELL,
|
||||||
|
quantity=Decimal("100"),
|
||||||
|
order_type=OrderType.TRAILING_STOP,
|
||||||
|
trail_percent=Decimal("5.0"),
|
||||||
|
)
|
||||||
|
result = manager.validate_order(request)
|
||||||
|
assert result.valid is True
|
||||||
|
|
||||||
|
def test_trailing_stop_high_percent_warning(self):
|
||||||
|
"""Test trailing stop with high percent warns."""
|
||||||
|
manager = OrderManager()
|
||||||
|
request = OrderRequest(
|
||||||
|
symbol="AAPL",
|
||||||
|
side=OrderSide.SELL,
|
||||||
|
quantity=Decimal("100"),
|
||||||
|
order_type=OrderType.TRAILING_STOP,
|
||||||
|
trail_percent=Decimal("60.0"),
|
||||||
|
)
|
||||||
|
result = manager.validate_order(request)
|
||||||
|
assert result.valid is True
|
||||||
|
assert len(result.warnings) > 0
|
||||||
|
|
||||||
|
def test_empty_symbol(self):
|
||||||
|
"""Test order with empty symbol."""
|
||||||
|
manager = OrderManager()
|
||||||
|
request = OrderRequest(
|
||||||
|
symbol="",
|
||||||
|
side=OrderSide.BUY,
|
||||||
|
quantity=Decimal("100"),
|
||||||
|
)
|
||||||
|
result = manager.validate_order(request)
|
||||||
|
assert result.valid is False
|
||||||
|
assert any("symbol" in e.lower() for e in result.errors)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStateTransitions:
|
||||||
|
"""Test order state machine transitions."""
|
||||||
|
|
||||||
|
def test_valid_transitions_defined(self):
|
||||||
|
"""Test all statuses have defined transitions."""
|
||||||
|
for status in OrderStatus:
|
||||||
|
assert status in VALID_TRANSITIONS
|
||||||
|
|
||||||
|
def test_terminal_states_immutable(self):
|
||||||
|
"""Test terminal states have no valid transitions."""
|
||||||
|
for status in TERMINAL_STATES:
|
||||||
|
assert len(VALID_TRANSITIONS[status]) == 0
|
||||||
|
|
||||||
|
def test_open_states_have_transitions(self):
|
||||||
|
"""Test open states have transitions."""
|
||||||
|
for status in OPEN_STATES:
|
||||||
|
assert len(VALID_TRANSITIONS[status]) > 0
|
||||||
|
|
||||||
|
def test_is_valid_transition(self):
|
||||||
|
"""Test valid transition checking."""
|
||||||
|
manager = OrderManager()
|
||||||
|
assert manager.is_valid_transition(OrderStatus.NEW, OrderStatus.FILLED)
|
||||||
|
assert manager.is_valid_transition(OrderStatus.NEW, OrderStatus.CANCELLED)
|
||||||
|
assert not manager.is_valid_transition(OrderStatus.FILLED, OrderStatus.NEW)
|
||||||
|
|
||||||
|
def test_is_terminal(self):
|
||||||
|
"""Test terminal status checking."""
|
||||||
|
manager = OrderManager()
|
||||||
|
assert manager.is_terminal(OrderStatus.FILLED) is True
|
||||||
|
assert manager.is_terminal(OrderStatus.CANCELLED) is True
|
||||||
|
assert manager.is_terminal(OrderStatus.NEW) is False
|
||||||
|
|
||||||
|
def test_is_open(self):
|
||||||
|
"""Test open status checking."""
|
||||||
|
manager = OrderManager()
|
||||||
|
assert manager.is_open(OrderStatus.NEW) is True
|
||||||
|
assert manager.is_open(OrderStatus.PARTIALLY_FILLED) is True
|
||||||
|
assert manager.is_open(OrderStatus.FILLED) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderSubmission:
|
||||||
|
"""Test OrderManager order submission."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_order(self):
|
||||||
|
"""Test basic order submission."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker()
|
||||||
|
broker.set_price("AAPL", Decimal("100"))
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||||
|
order = await manager.submit_order(request, broker)
|
||||||
|
|
||||||
|
assert order.symbol == "AAPL"
|
||||||
|
assert order.status == OrderStatus.FILLED
|
||||||
|
assert manager.order_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_order_validation_fails(self):
|
||||||
|
"""Test submission fails on invalid order."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker()
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
request = OrderRequest(
|
||||||
|
symbol="AAPL",
|
||||||
|
side=OrderSide.BUY,
|
||||||
|
quantity=Decimal("-10"),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidOrderError):
|
||||||
|
await manager.submit_order(request, broker)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_order_validation_disabled(self):
|
||||||
|
"""Test submission with validation disabled."""
|
||||||
|
manager = OrderManager(validate_before_submit=False)
|
||||||
|
broker = PaperBroker()
|
||||||
|
broker.set_price("AAPL", Decimal("100"))
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
# This would normally fail validation but passes with disabled
|
||||||
|
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||||
|
order = await manager.submit_order(request, broker)
|
||||||
|
assert order is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_order_tracking(self):
|
||||||
|
"""Test order is tracked after submission."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker()
|
||||||
|
broker.set_price("AAPL", Decimal("100"))
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||||
|
order = await manager.submit_order(request, broker)
|
||||||
|
|
||||||
|
tracked = manager.get_order(order.broker_order_id)
|
||||||
|
assert tracked is not None
|
||||||
|
assert tracked.broker_order_id == order.broker_order_id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_order_history_recorded(self):
|
||||||
|
"""Test order history is recorded."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker()
|
||||||
|
broker.set_price("AAPL", Decimal("100"))
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||||
|
order = await manager.submit_order(request, broker)
|
||||||
|
|
||||||
|
history = manager.get_order_history(order.broker_order_id)
|
||||||
|
assert len(history) >= 1
|
||||||
|
assert history[0].event == OrderEvent.SUBMITTED
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderCallbacks:
|
||||||
|
"""Test OrderManager event callbacks."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_callback(self):
|
||||||
|
"""Test registering callback."""
|
||||||
|
manager = OrderManager()
|
||||||
|
callback_called = []
|
||||||
|
|
||||||
|
async def callback(order, event, metadata):
|
||||||
|
callback_called.append((order, event))
|
||||||
|
|
||||||
|
manager.register_callback(OrderEvent.FILLED, callback)
|
||||||
|
|
||||||
|
broker = PaperBroker()
|
||||||
|
broker.set_price("AAPL", Decimal("100"))
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
await manager.submit_order(
|
||||||
|
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")),
|
||||||
|
broker,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(callback_called) == 1
|
||||||
|
assert callback_called[0][1] == OrderEvent.FILLED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unregister_callback(self):
|
||||||
|
"""Test unregistering callback."""
|
||||||
|
manager = OrderManager()
|
||||||
|
callback_called = []
|
||||||
|
|
||||||
|
async def callback(order, event, metadata):
|
||||||
|
callback_called.append(True)
|
||||||
|
|
||||||
|
manager.register_callback(OrderEvent.FILLED, callback)
|
||||||
|
manager.unregister_callback(OrderEvent.FILLED, callback)
|
||||||
|
|
||||||
|
broker = PaperBroker()
|
||||||
|
broker.set_price("AAPL", Decimal("100"))
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
await manager.submit_order(
|
||||||
|
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")),
|
||||||
|
broker,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(callback_called) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_error_doesnt_break_flow(self):
|
||||||
|
"""Test callback error doesn't break order flow."""
|
||||||
|
manager = OrderManager()
|
||||||
|
|
||||||
|
async def bad_callback(order, event, metadata):
|
||||||
|
raise Exception("Callback error")
|
||||||
|
|
||||||
|
manager.register_callback(OrderEvent.FILLED, bad_callback)
|
||||||
|
|
||||||
|
broker = PaperBroker()
|
||||||
|
broker.set_price("AAPL", Decimal("100"))
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
# Should not raise despite callback error
|
||||||
|
order = await manager.submit_order(
|
||||||
|
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")),
|
||||||
|
broker,
|
||||||
|
)
|
||||||
|
assert order is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderCancellation:
|
||||||
|
"""Test OrderManager order cancellation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_order(self):
|
||||||
|
"""Test cancelling an order."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker(fill_probability=0.0)
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||||
|
order = await manager.submit_order(request, broker)
|
||||||
|
|
||||||
|
cancelled = await manager.cancel_order(order.broker_order_id, broker)
|
||||||
|
assert cancelled.status == OrderStatus.CANCELLED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_updates_tracking(self):
|
||||||
|
"""Test cancellation updates tracked order."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker(fill_probability=0.0)
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||||
|
order = await manager.submit_order(request, broker)
|
||||||
|
|
||||||
|
await manager.cancel_order(order.broker_order_id, broker)
|
||||||
|
|
||||||
|
tracked = manager.get_order(order.broker_order_id)
|
||||||
|
assert tracked.status == OrderStatus.CANCELLED
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderReplacement:
|
||||||
|
"""Test OrderManager order replacement."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_replace_order(self):
|
||||||
|
"""Test replacing an order."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker(fill_probability=0.0)
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
request = OrderRequest.limit(
|
||||||
|
"AAPL", OrderSide.BUY, Decimal("10"), Decimal("100")
|
||||||
|
)
|
||||||
|
order = await manager.submit_order(request, broker)
|
||||||
|
|
||||||
|
new_order = await manager.replace_order(
|
||||||
|
order.broker_order_id,
|
||||||
|
broker,
|
||||||
|
quantity=Decimal("20"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert new_order.quantity == Decimal("20")
|
||||||
|
assert new_order.broker_order_id != order.broker_order_id
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderQueries:
|
||||||
|
"""Test OrderManager query methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_orders_all(self):
|
||||||
|
"""Test getting all orders."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker()
|
||||||
|
broker.set_price("AAPL", Decimal("10"))
|
||||||
|
broker.set_price("MSFT", Decimal("10"))
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
await manager.submit_order(
|
||||||
|
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
|
||||||
|
)
|
||||||
|
await manager.submit_order(
|
||||||
|
OrderRequest.market("MSFT", OrderSide.BUY, Decimal("10")), broker
|
||||||
|
)
|
||||||
|
|
||||||
|
orders = manager.get_orders()
|
||||||
|
assert len(orders) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_orders_by_symbol(self):
|
||||||
|
"""Test filtering orders by symbol."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker()
|
||||||
|
broker.set_price("AAPL", Decimal("10"))
|
||||||
|
broker.set_price("MSFT", Decimal("10"))
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
await manager.submit_order(
|
||||||
|
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
|
||||||
|
)
|
||||||
|
await manager.submit_order(
|
||||||
|
OrderRequest.market("MSFT", OrderSide.BUY, Decimal("10")), broker
|
||||||
|
)
|
||||||
|
|
||||||
|
orders = manager.get_orders(symbol="AAPL")
|
||||||
|
assert len(orders) == 1
|
||||||
|
assert orders[0].symbol == "AAPL"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_orders_by_status(self):
|
||||||
|
"""Test filtering orders by status."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker()
|
||||||
|
broker.set_price("AAPL", Decimal("10"))
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
# Create filled order
|
||||||
|
await manager.submit_order(
|
||||||
|
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create unfilled order
|
||||||
|
broker._fill_probability = 0.0
|
||||||
|
await manager.submit_order(
|
||||||
|
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
|
||||||
|
)
|
||||||
|
|
||||||
|
filled_orders = manager.get_orders(status=OrderStatus.FILLED)
|
||||||
|
assert len(filled_orders) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_open_orders(self):
|
||||||
|
"""Test getting open orders only."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker(fill_probability=0.0)
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
await manager.submit_order(
|
||||||
|
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
|
||||||
|
)
|
||||||
|
|
||||||
|
open_orders = manager.get_open_orders()
|
||||||
|
assert len(open_orders) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderCleanup:
|
||||||
|
"""Test OrderManager order cleanup."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clear_completed_orders(self):
|
||||||
|
"""Test clearing completed orders."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker()
|
||||||
|
broker.set_price("AAPL", Decimal("10"))
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
# Create filled orders
|
||||||
|
await manager.submit_order(
|
||||||
|
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
|
||||||
|
)
|
||||||
|
await manager.submit_order(
|
||||||
|
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")), broker
|
||||||
|
)
|
||||||
|
|
||||||
|
assert manager.order_count == 2
|
||||||
|
|
||||||
|
removed = manager.clear_completed_orders()
|
||||||
|
assert removed == 2
|
||||||
|
assert manager.order_count == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_orders_limit(self):
|
||||||
|
"""Test orders are trimmed when max reached."""
|
||||||
|
manager = OrderManager(max_orders=5)
|
||||||
|
broker = PaperBroker()
|
||||||
|
broker.set_price("AAPL", Decimal("1"))
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
# Submit more than max orders
|
||||||
|
for _ in range(10):
|
||||||
|
await manager.submit_order(
|
||||||
|
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1")), broker
|
||||||
|
)
|
||||||
|
|
||||||
|
assert manager.order_count <= 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderStatusUpdate:
|
||||||
|
"""Test OrderManager status updates."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_order_status(self):
|
||||||
|
"""Test updating order status."""
|
||||||
|
manager = OrderManager()
|
||||||
|
broker = PaperBroker(fill_probability=0.0)
|
||||||
|
await broker.connect()
|
||||||
|
|
||||||
|
request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||||
|
order = await manager.submit_order(request, broker)
|
||||||
|
|
||||||
|
# Simulate status change
|
||||||
|
order.status = OrderStatus.FILLED
|
||||||
|
order.filled_quantity = order.quantity
|
||||||
|
await manager.update_order_status(order)
|
||||||
|
|
||||||
|
tracked = manager.get_order(order.broker_order_id)
|
||||||
|
assert tracked.status == OrderStatus.FILLED
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrderEvents:
|
||||||
|
"""Test OrderEvent enum."""
|
||||||
|
|
||||||
|
def test_all_events_defined(self):
|
||||||
|
"""Test all expected events are defined."""
|
||||||
|
expected_events = [
|
||||||
|
"CREATED", "SUBMITTED", "ACCEPTED", "REJECTED",
|
||||||
|
"PARTIALLY_FILLED", "FILLED", "PENDING_CANCEL",
|
||||||
|
"CANCELLED", "REPLACED", "EXPIRED", "ERROR"
|
||||||
|
]
|
||||||
|
for event_name in expected_events:
|
||||||
|
assert hasattr(OrderEvent, event_name)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStateConstants:
|
||||||
|
"""Test state machine constants."""
|
||||||
|
|
||||||
|
def test_terminal_states_complete(self):
|
||||||
|
"""Test all terminal states are included."""
|
||||||
|
assert OrderStatus.FILLED in TERMINAL_STATES
|
||||||
|
assert OrderStatus.CANCELLED in TERMINAL_STATES
|
||||||
|
assert OrderStatus.REJECTED in TERMINAL_STATES
|
||||||
|
assert OrderStatus.EXPIRED in TERMINAL_STATES
|
||||||
|
assert OrderStatus.REPLACED in TERMINAL_STATES
|
||||||
|
|
||||||
|
def test_open_states_complete(self):
|
||||||
|
"""Test all open states are included."""
|
||||||
|
assert OrderStatus.PENDING_NEW in OPEN_STATES
|
||||||
|
assert OrderStatus.NEW in OPEN_STATES
|
||||||
|
assert OrderStatus.PARTIALLY_FILLED in OPEN_STATES
|
||||||
|
assert OrderStatus.PENDING_CANCEL in OPEN_STATES
|
||||||
|
|
||||||
|
def test_no_overlap(self):
|
||||||
|
"""Test terminal and open states don't overlap."""
|
||||||
|
overlap = TERMINAL_STATES & OPEN_STATES
|
||||||
|
assert len(overlap) == 0
|
||||||
|
|
@ -114,6 +114,16 @@ from .ibkr_broker import (
|
||||||
|
|
||||||
from .paper_broker import PaperBroker
|
from .paper_broker import PaperBroker
|
||||||
|
|
||||||
|
from .order_manager import (
|
||||||
|
OrderManager,
|
||||||
|
OrderEvent,
|
||||||
|
OrderValidationResult,
|
||||||
|
OrderStateChange,
|
||||||
|
VALID_TRANSITIONS,
|
||||||
|
TERMINAL_STATES,
|
||||||
|
OPEN_STATES,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Enums
|
# Enums
|
||||||
"AssetClass",
|
"AssetClass",
|
||||||
|
|
@ -158,4 +168,12 @@ __all__ = [
|
||||||
"FUTURES_SPECS",
|
"FUTURES_SPECS",
|
||||||
# Paper Broker
|
# Paper Broker
|
||||||
"PaperBroker",
|
"PaperBroker",
|
||||||
|
# Order Manager
|
||||||
|
"OrderManager",
|
||||||
|
"OrderEvent",
|
||||||
|
"OrderValidationResult",
|
||||||
|
"OrderStateChange",
|
||||||
|
"VALID_TRANSITIONS",
|
||||||
|
"TERMINAL_STATES",
|
||||||
|
"OPEN_STATES",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,650 @@
|
||||||
|
"""Order Manager for order lifecycle management.
|
||||||
|
|
||||||
|
Issue #27: [EXEC-26] Order types and manager - market, limit, stop, trailing
|
||||||
|
|
||||||
|
This module provides order lifecycle management including validation,
|
||||||
|
state transitions, and event notifications.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Order validation before submission
|
||||||
|
- Order state machine with valid transitions
|
||||||
|
- Order tracking and retrieval
|
||||||
|
- Event callbacks for order state changes
|
||||||
|
- Support for all order types: market, limit, stop, stop_limit, trailing_stop
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from tradingagents.execution import OrderManager, OrderRequest, OrderSide
|
||||||
|
>>>
|
||||||
|
>>> manager = OrderManager()
|
||||||
|
>>> request = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
|
||||||
|
>>> order = await manager.submit_order(request, broker)
|
||||||
|
>>> print(f"Order {order.broker_order_id} status: {order.status}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from decimal import Decimal
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Set, Awaitable
|
||||||
|
|
||||||
|
from .broker_base import (
|
||||||
|
BrokerBase,
|
||||||
|
Order,
|
||||||
|
OrderError,
|
||||||
|
OrderRequest,
|
||||||
|
OrderSide,
|
||||||
|
OrderStatus,
|
||||||
|
OrderType,
|
||||||
|
TimeInForce,
|
||||||
|
InvalidOrderError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OrderEvent(Enum):
|
||||||
|
"""Order lifecycle events."""
|
||||||
|
|
||||||
|
CREATED = "created"
|
||||||
|
SUBMITTED = "submitted"
|
||||||
|
ACCEPTED = "accepted"
|
||||||
|
REJECTED = "rejected"
|
||||||
|
PARTIALLY_FILLED = "partially_filled"
|
||||||
|
FILLED = "filled"
|
||||||
|
PENDING_CANCEL = "pending_cancel"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
REPLACED = "replaced"
|
||||||
|
EXPIRED = "expired"
|
||||||
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
# Valid state transitions for order state machine
|
||||||
|
VALID_TRANSITIONS: Dict[OrderStatus, Set[OrderStatus]] = {
|
||||||
|
OrderStatus.PENDING_NEW: {
|
||||||
|
OrderStatus.NEW,
|
||||||
|
OrderStatus.REJECTED,
|
||||||
|
OrderStatus.CANCELLED,
|
||||||
|
},
|
||||||
|
OrderStatus.NEW: {
|
||||||
|
OrderStatus.PARTIALLY_FILLED,
|
||||||
|
OrderStatus.FILLED,
|
||||||
|
OrderStatus.PENDING_CANCEL,
|
||||||
|
OrderStatus.CANCELLED,
|
||||||
|
OrderStatus.EXPIRED,
|
||||||
|
OrderStatus.REPLACED,
|
||||||
|
},
|
||||||
|
OrderStatus.PARTIALLY_FILLED: {
|
||||||
|
OrderStatus.PARTIALLY_FILLED,
|
||||||
|
OrderStatus.FILLED,
|
||||||
|
OrderStatus.PENDING_CANCEL,
|
||||||
|
OrderStatus.CANCELLED,
|
||||||
|
},
|
||||||
|
OrderStatus.FILLED: set(), # Terminal state
|
||||||
|
OrderStatus.PENDING_CANCEL: {
|
||||||
|
OrderStatus.CANCELLED,
|
||||||
|
OrderStatus.FILLED, # Can fill while cancel is pending
|
||||||
|
OrderStatus.PARTIALLY_FILLED,
|
||||||
|
},
|
||||||
|
OrderStatus.CANCELLED: set(), # Terminal state
|
||||||
|
OrderStatus.REJECTED: set(), # Terminal state
|
||||||
|
OrderStatus.EXPIRED: set(), # Terminal state
|
||||||
|
OrderStatus.REPLACED: set(), # Terminal state
|
||||||
|
}
|
||||||
|
|
||||||
|
# Terminal states (order cannot change after reaching these)
|
||||||
|
TERMINAL_STATES: Set[OrderStatus] = {
|
||||||
|
OrderStatus.FILLED,
|
||||||
|
OrderStatus.CANCELLED,
|
||||||
|
OrderStatus.REJECTED,
|
||||||
|
OrderStatus.EXPIRED,
|
||||||
|
OrderStatus.REPLACED,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Open states (order can still be filled or cancelled)
|
||||||
|
OPEN_STATES: Set[OrderStatus] = {
|
||||||
|
OrderStatus.PENDING_NEW,
|
||||||
|
OrderStatus.NEW,
|
||||||
|
OrderStatus.PARTIALLY_FILLED,
|
||||||
|
OrderStatus.PENDING_CANCEL,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OrderValidationResult:
|
||||||
|
"""Result of order validation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
valid: Whether the order is valid
|
||||||
|
errors: List of validation error messages
|
||||||
|
warnings: List of validation warning messages
|
||||||
|
"""
|
||||||
|
valid: bool = True
|
||||||
|
errors: List[str] = field(default_factory=list)
|
||||||
|
warnings: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OrderStateChange:
|
||||||
|
"""Record of an order state change.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
order_id: Order identifier
|
||||||
|
from_status: Previous status
|
||||||
|
to_status: New status
|
||||||
|
event: Event that triggered the change
|
||||||
|
timestamp: When the change occurred
|
||||||
|
metadata: Additional change details
|
||||||
|
"""
|
||||||
|
order_id: str
|
||||||
|
from_status: Optional[OrderStatus]
|
||||||
|
to_status: OrderStatus
|
||||||
|
event: OrderEvent
|
||||||
|
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
# Callback type for order events
|
||||||
|
OrderEventCallback = Callable[[Order, OrderEvent, Dict[str, Any]], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
|
class OrderManager:
|
||||||
|
"""Manages order lifecycle and state transitions.
|
||||||
|
|
||||||
|
The OrderManager provides:
|
||||||
|
- Order validation before submission
|
||||||
|
- Order state machine with valid transitions
|
||||||
|
- Order tracking and retrieval
|
||||||
|
- Event callbacks for order state changes
|
||||||
|
- Order history and audit trail
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> manager = OrderManager()
|
||||||
|
>>>
|
||||||
|
>>> # Register callbacks
|
||||||
|
>>> async def on_fill(order, event, metadata):
|
||||||
|
... print(f"Order {order.broker_order_id} filled!")
|
||||||
|
>>> manager.register_callback(OrderEvent.FILLED, on_fill)
|
||||||
|
>>>
|
||||||
|
>>> # Submit order
|
||||||
|
>>> order = await manager.submit_order(request, broker)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_orders: int = 10000,
|
||||||
|
validate_before_submit: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize order manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_orders: Maximum orders to track (oldest removed when exceeded)
|
||||||
|
validate_before_submit: Whether to validate orders before submission
|
||||||
|
"""
|
||||||
|
self._orders: Dict[str, Order] = {}
|
||||||
|
self._order_history: Dict[str, List[OrderStateChange]] = {}
|
||||||
|
self._callbacks: Dict[OrderEvent, List[OrderEventCallback]] = {
|
||||||
|
event: [] for event in OrderEvent
|
||||||
|
}
|
||||||
|
self._max_orders = max_orders
|
||||||
|
self._validate_before_submit = validate_before_submit
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def register_callback(
|
||||||
|
self,
|
||||||
|
event: OrderEvent,
|
||||||
|
callback: OrderEventCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Register a callback for an order event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: Event to listen for
|
||||||
|
callback: Async callback function(order, event, metadata)
|
||||||
|
"""
|
||||||
|
self._callbacks[event].append(callback)
|
||||||
|
|
||||||
|
def unregister_callback(
|
||||||
|
self,
|
||||||
|
event: OrderEvent,
|
||||||
|
callback: OrderEventCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Unregister a callback.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: Event type
|
||||||
|
callback: Callback to remove
|
||||||
|
"""
|
||||||
|
if callback in self._callbacks[event]:
|
||||||
|
self._callbacks[event].remove(callback)
|
||||||
|
|
||||||
|
async def _fire_event(
|
||||||
|
self,
|
||||||
|
order: Order,
|
||||||
|
event: OrderEvent,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Fire callbacks for an event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
order: Order that triggered event
|
||||||
|
event: Event type
|
||||||
|
metadata: Additional event data
|
||||||
|
"""
|
||||||
|
metadata = metadata or {}
|
||||||
|
for callback in self._callbacks[event]:
|
||||||
|
try:
|
||||||
|
await callback(order, event, metadata)
|
||||||
|
except Exception:
|
||||||
|
# Don't let callback errors break order flow
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validate_order(self, request: OrderRequest) -> OrderValidationResult:
|
||||||
|
"""Validate an order request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Order request to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validation result with errors/warnings
|
||||||
|
"""
|
||||||
|
result = OrderValidationResult()
|
||||||
|
|
||||||
|
# Validate quantity
|
||||||
|
if request.quantity <= 0:
|
||||||
|
result.valid = False
|
||||||
|
result.errors.append("Quantity must be positive")
|
||||||
|
|
||||||
|
# Validate limit price for limit orders
|
||||||
|
if request.order_type in (OrderType.LIMIT, OrderType.STOP_LIMIT):
|
||||||
|
if request.limit_price is None:
|
||||||
|
result.valid = False
|
||||||
|
result.errors.append(f"{request.order_type.value} order requires limit_price")
|
||||||
|
elif request.limit_price <= 0:
|
||||||
|
result.valid = False
|
||||||
|
result.errors.append("Limit price must be positive")
|
||||||
|
|
||||||
|
# Validate stop price for stop orders
|
||||||
|
if request.order_type in (OrderType.STOP, OrderType.STOP_LIMIT):
|
||||||
|
if request.stop_price is None:
|
||||||
|
result.valid = False
|
||||||
|
result.errors.append(f"{request.order_type.value} order requires stop_price")
|
||||||
|
elif request.stop_price <= 0:
|
||||||
|
result.valid = False
|
||||||
|
result.errors.append("Stop price must be positive")
|
||||||
|
|
||||||
|
# Validate trailing stop parameters
|
||||||
|
if request.order_type == OrderType.TRAILING_STOP:
|
||||||
|
if request.trail_amount is None and request.trail_percent is None:
|
||||||
|
result.valid = False
|
||||||
|
result.errors.append("Trailing stop requires trail_amount or trail_percent")
|
||||||
|
if request.trail_amount is not None and request.trail_amount <= 0:
|
||||||
|
result.valid = False
|
||||||
|
result.errors.append("Trail amount must be positive")
|
||||||
|
if request.trail_percent is not None:
|
||||||
|
if request.trail_percent <= 0:
|
||||||
|
result.valid = False
|
||||||
|
result.errors.append("Trail percent must be positive")
|
||||||
|
elif request.trail_percent > Decimal("50"):
|
||||||
|
result.warnings.append("Trail percent > 50% may execute far from market")
|
||||||
|
|
||||||
|
# Validate symbol
|
||||||
|
if not request.symbol or not request.symbol.strip():
|
||||||
|
result.valid = False
|
||||||
|
result.errors.append("Symbol is required")
|
||||||
|
|
||||||
|
# Warn about FOK/IOC with limit orders far from market
|
||||||
|
if request.time_in_force in (TimeInForce.FOK, TimeInForce.IOC):
|
||||||
|
if request.order_type == OrderType.MARKET:
|
||||||
|
result.warnings.append(
|
||||||
|
f"{request.time_in_force.value} with market order may not execute"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def is_valid_transition(
|
||||||
|
self,
|
||||||
|
from_status: OrderStatus,
|
||||||
|
to_status: OrderStatus,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if a state transition is valid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
from_status: Current status
|
||||||
|
to_status: Target status
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if transition is valid
|
||||||
|
"""
|
||||||
|
return to_status in VALID_TRANSITIONS.get(from_status, set())
|
||||||
|
|
||||||
|
def is_terminal(self, status: OrderStatus) -> bool:
|
||||||
|
"""Check if a status is terminal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status: Status to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if status is terminal
|
||||||
|
"""
|
||||||
|
return status in TERMINAL_STATES
|
||||||
|
|
||||||
|
def is_open(self, status: OrderStatus) -> bool:
|
||||||
|
"""Check if a status means order is open.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status: Status to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if order is open
|
||||||
|
"""
|
||||||
|
return status in OPEN_STATES
|
||||||
|
|
||||||
|
async def submit_order(
|
||||||
|
self,
|
||||||
|
request: OrderRequest,
|
||||||
|
broker: BrokerBase,
|
||||||
|
) -> Order:
|
||||||
|
"""Submit an order through a broker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Order request
|
||||||
|
broker: Broker to submit through
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Submitted order
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidOrderError: If validation fails
|
||||||
|
OrderError: If submission fails
|
||||||
|
"""
|
||||||
|
# Validate if enabled
|
||||||
|
if self._validate_before_submit:
|
||||||
|
validation = self.validate_order(request)
|
||||||
|
if not validation.valid:
|
||||||
|
raise InvalidOrderError(
|
||||||
|
f"Order validation failed: {'; '.join(validation.errors)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Submit to broker
|
||||||
|
order = await broker.submit_order(request)
|
||||||
|
|
||||||
|
# Track the order
|
||||||
|
async with self._lock:
|
||||||
|
self._orders[order.broker_order_id] = order
|
||||||
|
self._order_history[order.broker_order_id] = [
|
||||||
|
OrderStateChange(
|
||||||
|
order_id=order.broker_order_id,
|
||||||
|
from_status=None,
|
||||||
|
to_status=order.status,
|
||||||
|
event=OrderEvent.SUBMITTED,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Trim old orders if at max
|
||||||
|
if len(self._orders) > self._max_orders:
|
||||||
|
# Remove oldest orders
|
||||||
|
sorted_orders = sorted(
|
||||||
|
self._orders.items(),
|
||||||
|
key=lambda x: x[1].created_at or datetime.min,
|
||||||
|
)
|
||||||
|
for order_id, _ in sorted_orders[: len(self._orders) - self._max_orders]:
|
||||||
|
del self._orders[order_id]
|
||||||
|
self._order_history.pop(order_id, None)
|
||||||
|
|
||||||
|
# Fire event
|
||||||
|
await self._fire_event(order, OrderEvent.SUBMITTED)
|
||||||
|
|
||||||
|
# Fire additional events based on status
|
||||||
|
if order.status == OrderStatus.FILLED:
|
||||||
|
await self._fire_event(order, OrderEvent.FILLED)
|
||||||
|
elif order.status == OrderStatus.REJECTED:
|
||||||
|
await self._fire_event(order, OrderEvent.REJECTED)
|
||||||
|
|
||||||
|
return order
|
||||||
|
|
||||||
|
async def cancel_order(
|
||||||
|
self,
|
||||||
|
order_id: str,
|
||||||
|
broker: BrokerBase,
|
||||||
|
) -> Order:
|
||||||
|
"""Cancel an order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
order_id: Order to cancel
|
||||||
|
broker: Broker to cancel through
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cancelled order
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OrderError: If cancel fails
|
||||||
|
"""
|
||||||
|
order = await broker.cancel_order(order_id)
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
old_order = self._orders.get(order_id)
|
||||||
|
old_status = old_order.status if old_order else None
|
||||||
|
self._orders[order_id] = order
|
||||||
|
|
||||||
|
if order_id in self._order_history:
|
||||||
|
self._order_history[order_id].append(
|
||||||
|
OrderStateChange(
|
||||||
|
order_id=order_id,
|
||||||
|
from_status=old_status,
|
||||||
|
to_status=order.status,
|
||||||
|
event=OrderEvent.CANCELLED,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._fire_event(order, OrderEvent.CANCELLED)
|
||||||
|
return order
|
||||||
|
|
||||||
|
async def replace_order(
|
||||||
|
self,
|
||||||
|
order_id: str,
|
||||||
|
broker: BrokerBase,
|
||||||
|
quantity: Optional[Decimal] = None,
|
||||||
|
limit_price: Optional[Decimal] = None,
|
||||||
|
stop_price: Optional[Decimal] = None,
|
||||||
|
time_in_force: Optional[TimeInForce] = None,
|
||||||
|
) -> Order:
|
||||||
|
"""Replace an order with updated parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
order_id: Order to replace
|
||||||
|
broker: Broker to replace through
|
||||||
|
quantity: New quantity
|
||||||
|
limit_price: New limit price
|
||||||
|
stop_price: New stop price
|
||||||
|
time_in_force: New time in force
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New replacement order
|
||||||
|
"""
|
||||||
|
new_order = await broker.replace_order(
|
||||||
|
order_id,
|
||||||
|
quantity=quantity,
|
||||||
|
limit_price=limit_price,
|
||||||
|
stop_price=stop_price,
|
||||||
|
time_in_force=time_in_force,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
# Mark old order as replaced
|
||||||
|
if order_id in self._orders:
|
||||||
|
old_order = self._orders[order_id]
|
||||||
|
old_order.status = OrderStatus.REPLACED
|
||||||
|
self._order_history[order_id].append(
|
||||||
|
OrderStateChange(
|
||||||
|
order_id=order_id,
|
||||||
|
from_status=old_order.status,
|
||||||
|
to_status=OrderStatus.REPLACED,
|
||||||
|
event=OrderEvent.REPLACED,
|
||||||
|
metadata={"replaced_by": new_order.broker_order_id},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track new order
|
||||||
|
self._orders[new_order.broker_order_id] = new_order
|
||||||
|
self._order_history[new_order.broker_order_id] = [
|
||||||
|
OrderStateChange(
|
||||||
|
order_id=new_order.broker_order_id,
|
||||||
|
from_status=None,
|
||||||
|
to_status=new_order.status,
|
||||||
|
event=OrderEvent.SUBMITTED,
|
||||||
|
metadata={"replaces": order_id},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
await self._fire_event(new_order, OrderEvent.REPLACED)
|
||||||
|
return new_order
|
||||||
|
|
||||||
|
async def update_order_status(
|
||||||
|
self,
|
||||||
|
order: Order,
|
||||||
|
) -> None:
|
||||||
|
"""Update tracked order status.
|
||||||
|
|
||||||
|
Called when order status changes (e.g., from broker callbacks).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
order: Order with updated status
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
old_order = self._orders.get(order.broker_order_id)
|
||||||
|
old_status = old_order.status if old_order else None
|
||||||
|
|
||||||
|
# Validate transition
|
||||||
|
if old_status and not self.is_valid_transition(old_status, order.status):
|
||||||
|
# Log warning but allow - broker is authoritative
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._orders[order.broker_order_id] = order
|
||||||
|
|
||||||
|
# Record state change
|
||||||
|
event = self._status_to_event(order.status)
|
||||||
|
if order.broker_order_id in self._order_history:
|
||||||
|
self._order_history[order.broker_order_id].append(
|
||||||
|
OrderStateChange(
|
||||||
|
order_id=order.broker_order_id,
|
||||||
|
from_status=old_status,
|
||||||
|
to_status=order.status,
|
||||||
|
event=event,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._fire_event(order, event)
|
||||||
|
|
||||||
|
def _status_to_event(self, status: OrderStatus) -> OrderEvent:
|
||||||
|
"""Convert order status to event type."""
|
||||||
|
mapping = {
|
||||||
|
OrderStatus.PENDING_NEW: OrderEvent.SUBMITTED,
|
||||||
|
OrderStatus.NEW: OrderEvent.ACCEPTED,
|
||||||
|
OrderStatus.PARTIALLY_FILLED: OrderEvent.PARTIALLY_FILLED,
|
||||||
|
OrderStatus.FILLED: OrderEvent.FILLED,
|
||||||
|
OrderStatus.PENDING_CANCEL: OrderEvent.PENDING_CANCEL,
|
||||||
|
OrderStatus.CANCELLED: OrderEvent.CANCELLED,
|
||||||
|
OrderStatus.REJECTED: OrderEvent.REJECTED,
|
||||||
|
OrderStatus.EXPIRED: OrderEvent.EXPIRED,
|
||||||
|
OrderStatus.REPLACED: OrderEvent.REPLACED,
|
||||||
|
}
|
||||||
|
return mapping.get(status, OrderEvent.ERROR)
|
||||||
|
|
||||||
|
def get_order(self, order_id: str) -> Optional[Order]:
|
||||||
|
"""Get tracked order by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
order_id: Order identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Order if found, None otherwise
|
||||||
|
"""
|
||||||
|
return self._orders.get(order_id)
|
||||||
|
|
||||||
|
def get_orders(
|
||||||
|
self,
|
||||||
|
status: Optional[OrderStatus] = None,
|
||||||
|
symbol: Optional[str] = None,
|
||||||
|
side: Optional[OrderSide] = None,
|
||||||
|
) -> List[Order]:
|
||||||
|
"""Get tracked orders with optional filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status: Filter by status
|
||||||
|
symbol: Filter by symbol
|
||||||
|
side: Filter by side
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching orders
|
||||||
|
"""
|
||||||
|
orders = list(self._orders.values())
|
||||||
|
|
||||||
|
if status:
|
||||||
|
orders = [o for o in orders if o.status == status]
|
||||||
|
if symbol:
|
||||||
|
orders = [o for o in orders if o.symbol == symbol]
|
||||||
|
if side:
|
||||||
|
orders = [o for o in orders if o.side == side]
|
||||||
|
|
||||||
|
return orders
|
||||||
|
|
||||||
|
def get_open_orders(self) -> List[Order]:
|
||||||
|
"""Get all open (non-terminal) orders.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of open orders
|
||||||
|
"""
|
||||||
|
return [o for o in self._orders.values() if self.is_open(o.status)]
|
||||||
|
|
||||||
|
def get_order_history(self, order_id: str) -> List[OrderStateChange]:
|
||||||
|
"""Get state change history for an order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
order_id: Order identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of state changes
|
||||||
|
"""
|
||||||
|
return self._order_history.get(order_id, [])
|
||||||
|
|
||||||
|
def clear_completed_orders(self) -> int:
|
||||||
|
"""Remove all terminal (completed) orders from tracking.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of orders removed
|
||||||
|
"""
|
||||||
|
to_remove = [
|
||||||
|
order_id
|
||||||
|
for order_id, order in self._orders.items()
|
||||||
|
if self.is_terminal(order.status)
|
||||||
|
]
|
||||||
|
|
||||||
|
for order_id in to_remove:
|
||||||
|
del self._orders[order_id]
|
||||||
|
self._order_history.pop(order_id, None)
|
||||||
|
|
||||||
|
return len(to_remove)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def order_count(self) -> int:
|
||||||
|
"""Get number of tracked orders."""
|
||||||
|
return len(self._orders)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def open_order_count(self) -> int:
|
||||||
|
"""Get number of open orders."""
|
||||||
|
return len(self.get_open_orders())
|
||||||
|
|
||||||
|
|
||||||
|
# Export
|
||||||
|
__all__ = [
|
||||||
|
"OrderManager",
|
||||||
|
"OrderEvent",
|
||||||
|
"OrderValidationResult",
|
||||||
|
"OrderStateChange",
|
||||||
|
"OrderEventCallback",
|
||||||
|
"VALID_TRANSITIONS",
|
||||||
|
"TERMINAL_STATES",
|
||||||
|
"OPEN_STATES",
|
||||||
|
]
|
||||||
Loading…
Reference in New Issue