feat(execution): add broker router for multi-broker asset class routing - Issue #23 (57 tests)
This commit is contained in:
parent
e4ef947c3b
commit
850346a47a
|
|
@ -0,0 +1,967 @@
|
|||
"""Tests for Broker Router module.
|
||||
|
||||
Issue #23: [EXEC-22] Broker router - route by asset class
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from tradingagents.execution import (
|
||||
# Enums
|
||||
AssetClass,
|
||||
OrderSide,
|
||||
OrderType,
|
||||
TimeInForce,
|
||||
OrderStatus,
|
||||
PositionSide,
|
||||
# Data Classes
|
||||
OrderRequest,
|
||||
Order,
|
||||
Position,
|
||||
AccountInfo,
|
||||
Quote,
|
||||
AssetInfo,
|
||||
# Exceptions
|
||||
BrokerError,
|
||||
OrderError,
|
||||
PositionError,
|
||||
# Base Class
|
||||
BrokerBase,
|
||||
# Router
|
||||
BrokerRouter,
|
||||
BrokerRegistration,
|
||||
RoutingDecision,
|
||||
SymbolClassifier,
|
||||
RoutingError,
|
||||
NoBrokerError,
|
||||
BrokerNotFoundError,
|
||||
DuplicateBrokerError,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mock Broker for Testing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MockBroker(BrokerBase):
|
||||
"""Mock broker for testing router functionality."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
asset_classes: List[AssetClass],
|
||||
paper_trading: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
supported_asset_classes=asset_classes,
|
||||
paper_trading=paper_trading,
|
||||
)
|
||||
self._orders: Dict[str, Order] = {}
|
||||
self._positions: Dict[str, Position] = {}
|
||||
self._order_counter = 0
|
||||
|
||||
async def connect(self) -> bool:
|
||||
self._connected = True
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._connected = False
|
||||
|
||||
async def is_market_open(self) -> bool:
|
||||
return True
|
||||
|
||||
async def get_account(self) -> AccountInfo:
|
||||
return AccountInfo(
|
||||
account_id=f"{self._name}_account",
|
||||
account_type="margin",
|
||||
status="active",
|
||||
cash=Decimal("100000"),
|
||||
portfolio_value=Decimal("150000"),
|
||||
buying_power=Decimal("200000"),
|
||||
equity=Decimal("150000"),
|
||||
)
|
||||
|
||||
async def submit_order(self, request: OrderRequest) -> Order:
|
||||
self._order_counter += 1
|
||||
order = Order(
|
||||
broker_order_id=f"{self._name}-ORD-{self._order_counter}",
|
||||
client_order_id=request.client_order_id,
|
||||
symbol=request.symbol,
|
||||
side=request.side,
|
||||
quantity=request.quantity,
|
||||
order_type=request.order_type,
|
||||
status=OrderStatus.NEW,
|
||||
limit_price=request.limit_price,
|
||||
stop_price=request.stop_price,
|
||||
time_in_force=request.time_in_force,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
self._orders[order.broker_order_id] = order
|
||||
return order
|
||||
|
||||
async def cancel_order(self, order_id: str) -> Order:
|
||||
if order_id not in self._orders:
|
||||
raise OrderError(f"Order {order_id} not found")
|
||||
order = self._orders[order_id]
|
||||
order.status = OrderStatus.CANCELLED
|
||||
return order
|
||||
|
||||
async def replace_order(
|
||||
self,
|
||||
order_id: str,
|
||||
quantity: Optional[Decimal] = None,
|
||||
limit_price: Optional[Decimal] = None,
|
||||
stop_price: Optional[Decimal] = None,
|
||||
time_in_force: Optional[TimeInForce] = None,
|
||||
) -> Order:
|
||||
old_order = self._orders.get(order_id)
|
||||
if not old_order:
|
||||
raise OrderError(f"Order {order_id} not found")
|
||||
|
||||
self._order_counter += 1
|
||||
new_order = Order(
|
||||
broker_order_id=f"{self._name}-ORD-{self._order_counter}",
|
||||
client_order_id=old_order.client_order_id,
|
||||
symbol=old_order.symbol,
|
||||
side=old_order.side,
|
||||
quantity=quantity or old_order.quantity,
|
||||
order_type=old_order.order_type,
|
||||
status=OrderStatus.NEW,
|
||||
limit_price=limit_price or old_order.limit_price,
|
||||
stop_price=stop_price or old_order.stop_price,
|
||||
time_in_force=time_in_force or old_order.time_in_force,
|
||||
)
|
||||
self._orders[new_order.broker_order_id] = new_order
|
||||
return new_order
|
||||
|
||||
async def get_order(self, order_id: str) -> Order:
|
||||
if order_id not in self._orders:
|
||||
raise OrderError(f"Order {order_id} not found")
|
||||
return self._orders[order_id]
|
||||
|
||||
async def get_orders(
|
||||
self,
|
||||
status: Optional[OrderStatus] = None,
|
||||
limit: int = 100,
|
||||
symbols: Optional[List[str]] = None,
|
||||
) -> List[Order]:
|
||||
orders = list(self._orders.values())
|
||||
if status:
|
||||
orders = [o for o in orders if o.status == status]
|
||||
if symbols:
|
||||
orders = [o for o in orders if o.symbol in symbols]
|
||||
return orders[:limit]
|
||||
|
||||
async def get_positions(self) -> List[Position]:
|
||||
return list(self._positions.values())
|
||||
|
||||
async def get_position(self, symbol: str) -> Optional[Position]:
|
||||
return self._positions.get(symbol)
|
||||
|
||||
async def get_quote(self, symbol: str) -> Quote:
|
||||
return Quote(
|
||||
symbol=symbol,
|
||||
bid_price=Decimal("100.00"),
|
||||
ask_price=Decimal("100.05"),
|
||||
last_price=Decimal("100.02"),
|
||||
volume=1000000,
|
||||
)
|
||||
|
||||
async def get_asset(self, symbol: str) -> AssetInfo:
|
||||
return AssetInfo(
|
||||
symbol=symbol,
|
||||
name=f"{symbol} Asset",
|
||||
tradable=True,
|
||||
)
|
||||
|
||||
def add_position(self, position: Position) -> None:
|
||||
"""Helper to add test positions."""
|
||||
self._positions[position.symbol] = position
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SymbolClassifier Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSymbolClassifier:
|
||||
"""Tests for SymbolClassifier."""
|
||||
|
||||
def test_classify_equity(self):
|
||||
"""Test equity classification."""
|
||||
classifier = SymbolClassifier()
|
||||
|
||||
assert classifier.classify("AAPL") == AssetClass.EQUITY
|
||||
assert classifier.classify("MSFT") == AssetClass.EQUITY
|
||||
assert classifier.classify("GOOGL") == AssetClass.EQUITY
|
||||
|
||||
def test_classify_etf(self):
|
||||
"""Test ETF classification."""
|
||||
classifier = SymbolClassifier()
|
||||
|
||||
assert classifier.classify("SPY") == AssetClass.ETF
|
||||
assert classifier.classify("QQQ") == AssetClass.ETF
|
||||
assert classifier.classify("VTI") == AssetClass.ETF
|
||||
|
||||
def test_classify_crypto(self):
|
||||
"""Test crypto classification."""
|
||||
classifier = SymbolClassifier()
|
||||
|
||||
assert classifier.classify("BTCUSD") == AssetClass.CRYPTO
|
||||
assert classifier.classify("ETHUSD") == AssetClass.CRYPTO
|
||||
assert classifier.classify("BTC") == AssetClass.CRYPTO
|
||||
|
||||
def test_classify_future(self):
|
||||
"""Test futures classification."""
|
||||
classifier = SymbolClassifier()
|
||||
|
||||
assert classifier.classify("ESZ24") == AssetClass.FUTURE # S&P 500 Dec 2024
|
||||
assert classifier.classify("CLF25") == AssetClass.FUTURE # Crude Oil Jan 2025
|
||||
assert classifier.classify("GCG24") == AssetClass.FUTURE # Gold Feb 2024
|
||||
|
||||
def test_custom_mapping(self):
|
||||
"""Test custom symbol mapping."""
|
||||
classifier = SymbolClassifier()
|
||||
|
||||
# Add custom mapping
|
||||
classifier.add_mapping("CUSTOM", AssetClass.BOND)
|
||||
|
||||
assert classifier.classify("CUSTOM") == AssetClass.BOND
|
||||
|
||||
def test_custom_mapping_overrides_default(self):
|
||||
"""Test that custom mappings override defaults."""
|
||||
classifier = SymbolClassifier()
|
||||
|
||||
# SPY is normally an ETF
|
||||
assert classifier.classify("SPY") == AssetClass.ETF
|
||||
|
||||
# Override with custom mapping
|
||||
classifier.add_mapping("SPY", AssetClass.EQUITY)
|
||||
classifier.clear_cache() # Clear cache to pick up new mapping
|
||||
|
||||
assert classifier.classify("SPY") == AssetClass.EQUITY
|
||||
|
||||
def test_cache(self):
|
||||
"""Test classification caching."""
|
||||
classifier = SymbolClassifier()
|
||||
|
||||
# First call should classify
|
||||
result1 = classifier.classify("AAPL")
|
||||
|
||||
# Second call should use cache
|
||||
result2 = classifier.classify("AAPL")
|
||||
|
||||
assert result1 == result2 == AssetClass.EQUITY
|
||||
|
||||
def test_clear_cache(self):
|
||||
"""Test clearing classification cache."""
|
||||
classifier = SymbolClassifier()
|
||||
|
||||
classifier.classify("AAPL") # Populate cache
|
||||
|
||||
classifier.clear_cache()
|
||||
|
||||
# Should still work after cache clear
|
||||
assert classifier.classify("AAPL") == AssetClass.EQUITY
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BrokerRegistration Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBrokerRegistration:
|
||||
"""Tests for BrokerRegistration dataclass."""
|
||||
|
||||
def test_create_registration(self):
|
||||
"""Test creating a broker registration."""
|
||||
broker = MockBroker("TestBroker", [AssetClass.EQUITY])
|
||||
|
||||
reg = BrokerRegistration(
|
||||
broker=broker,
|
||||
asset_classes={AssetClass.EQUITY, AssetClass.ETF},
|
||||
priority=10,
|
||||
is_primary=True,
|
||||
)
|
||||
|
||||
assert reg.broker == broker
|
||||
assert AssetClass.EQUITY in reg.asset_classes
|
||||
assert reg.priority == 10
|
||||
assert reg.is_primary is True
|
||||
assert reg.enabled is True
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default registration values."""
|
||||
broker = MockBroker("TestBroker", [AssetClass.EQUITY])
|
||||
|
||||
reg = BrokerRegistration(
|
||||
broker=broker,
|
||||
asset_classes={AssetClass.EQUITY},
|
||||
)
|
||||
|
||||
assert reg.priority == 0
|
||||
assert reg.is_primary is False
|
||||
assert reg.enabled is True
|
||||
assert reg.registered_at is not None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RoutingDecision Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRoutingDecision:
|
||||
"""Tests for RoutingDecision dataclass."""
|
||||
|
||||
def test_create_decision(self):
|
||||
"""Test creating a routing decision."""
|
||||
decision = RoutingDecision(
|
||||
symbol="AAPL",
|
||||
asset_class=AssetClass.EQUITY,
|
||||
broker_name="Alpaca",
|
||||
reason="Primary broker for equity",
|
||||
alternatives=["IBKR"],
|
||||
)
|
||||
|
||||
assert decision.symbol == "AAPL"
|
||||
assert decision.asset_class == AssetClass.EQUITY
|
||||
assert decision.broker_name == "Alpaca"
|
||||
assert "IBKR" in decision.alternatives
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BrokerRouter Tests - Registration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBrokerRouterRegistration:
|
||||
"""Tests for broker registration."""
|
||||
|
||||
def test_register_broker(self):
|
||||
"""Test registering a broker."""
|
||||
router = BrokerRouter()
|
||||
broker = MockBroker("Alpaca", [AssetClass.EQUITY, AssetClass.ETF])
|
||||
|
||||
router.register(broker)
|
||||
|
||||
assert "Alpaca" in router.registered_brokers
|
||||
assert AssetClass.EQUITY in router.supported_asset_classes
|
||||
|
||||
def test_register_with_specific_classes(self):
|
||||
"""Test registering with specific asset classes."""
|
||||
router = BrokerRouter()
|
||||
broker = MockBroker("Alpaca", [AssetClass.EQUITY, AssetClass.ETF, AssetClass.CRYPTO])
|
||||
|
||||
# Only register for equity
|
||||
router.register(broker, asset_classes=[AssetClass.EQUITY])
|
||||
|
||||
assert AssetClass.EQUITY in router.supported_asset_classes
|
||||
# ETF and CRYPTO not registered even though broker supports them
|
||||
assert AssetClass.CRYPTO not in router.supported_asset_classes
|
||||
|
||||
def test_register_duplicate_raises(self):
|
||||
"""Test registering duplicate broker raises error."""
|
||||
router = BrokerRouter()
|
||||
broker = MockBroker("Alpaca", [AssetClass.EQUITY])
|
||||
|
||||
router.register(broker)
|
||||
|
||||
with pytest.raises(DuplicateBrokerError, match="already registered"):
|
||||
router.register(broker)
|
||||
|
||||
def test_unregister_broker(self):
|
||||
"""Test unregistering a broker."""
|
||||
router = BrokerRouter()
|
||||
broker = MockBroker("Alpaca", [AssetClass.EQUITY])
|
||||
|
||||
router.register(broker)
|
||||
router.unregister("Alpaca")
|
||||
|
||||
assert "Alpaca" not in router.registered_brokers
|
||||
|
||||
def test_unregister_nonexistent_raises(self):
|
||||
"""Test unregistering non-existent broker raises error."""
|
||||
router = BrokerRouter()
|
||||
|
||||
with pytest.raises(BrokerNotFoundError, match="not registered"):
|
||||
router.unregister("NonExistent")
|
||||
|
||||
def test_get_broker(self):
|
||||
"""Test getting a broker by name."""
|
||||
router = BrokerRouter()
|
||||
broker = MockBroker("Alpaca", [AssetClass.EQUITY])
|
||||
|
||||
router.register(broker)
|
||||
|
||||
result = router.get_broker("Alpaca")
|
||||
assert result == broker
|
||||
|
||||
def test_get_broker_not_found(self):
|
||||
"""Test getting non-existent broker raises error."""
|
||||
router = BrokerRouter()
|
||||
|
||||
with pytest.raises(BrokerNotFoundError):
|
||||
router.get_broker("NonExistent")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BrokerRouter Tests - Routing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBrokerRouterRouting:
|
||||
"""Tests for order routing."""
|
||||
|
||||
@pytest.fixture
|
||||
def router_with_brokers(self):
|
||||
"""Create a router with multiple brokers."""
|
||||
router = BrokerRouter()
|
||||
|
||||
equity_broker = MockBroker("EquityBroker", [AssetClass.EQUITY, AssetClass.ETF])
|
||||
crypto_broker = MockBroker("CryptoBroker", [AssetClass.CRYPTO])
|
||||
futures_broker = MockBroker("FuturesBroker", [AssetClass.FUTURE, AssetClass.OPTION])
|
||||
|
||||
router.register(equity_broker, [AssetClass.EQUITY, AssetClass.ETF])
|
||||
router.register(crypto_broker, [AssetClass.CRYPTO])
|
||||
router.register(futures_broker, [AssetClass.FUTURE, AssetClass.OPTION])
|
||||
|
||||
return router
|
||||
|
||||
def test_route_equity(self, router_with_brokers):
|
||||
"""Test routing equity symbol."""
|
||||
broker, decision = router_with_brokers.route("AAPL")
|
||||
|
||||
assert broker.name == "EquityBroker"
|
||||
assert decision.asset_class == AssetClass.EQUITY
|
||||
|
||||
def test_route_etf(self, router_with_brokers):
|
||||
"""Test routing ETF symbol."""
|
||||
broker, decision = router_with_brokers.route("SPY")
|
||||
|
||||
assert broker.name == "EquityBroker"
|
||||
assert decision.asset_class == AssetClass.ETF
|
||||
|
||||
def test_route_crypto(self, router_with_brokers):
|
||||
"""Test routing crypto symbol."""
|
||||
broker, decision = router_with_brokers.route("BTCUSD")
|
||||
|
||||
assert broker.name == "CryptoBroker"
|
||||
assert decision.asset_class == AssetClass.CRYPTO
|
||||
|
||||
def test_route_futures(self, router_with_brokers):
|
||||
"""Test routing futures symbol."""
|
||||
broker, decision = router_with_brokers.route("ESZ24")
|
||||
|
||||
assert broker.name == "FuturesBroker"
|
||||
assert decision.asset_class == AssetClass.FUTURE
|
||||
|
||||
def test_route_no_broker_raises(self):
|
||||
"""Test routing when no broker available."""
|
||||
router = BrokerRouter()
|
||||
|
||||
with pytest.raises(NoBrokerError, match="No broker available"):
|
||||
router.route("AAPL")
|
||||
|
||||
def test_fallback_broker(self):
|
||||
"""Test fallback broker usage."""
|
||||
router = BrokerRouter()
|
||||
equity_broker = MockBroker("EquityBroker", [AssetClass.EQUITY])
|
||||
|
||||
router.register(equity_broker)
|
||||
router.set_fallback("EquityBroker")
|
||||
|
||||
# Route an unknown asset class (FOREX) - should use fallback
|
||||
router.add_symbol_mapping("EURUSD", AssetClass.FOREX)
|
||||
broker, decision = router.route("EURUSD")
|
||||
|
||||
assert broker.name == "EquityBroker"
|
||||
assert "Fallback" in decision.reason
|
||||
|
||||
def test_disabled_broker_skipped(self, router_with_brokers):
|
||||
"""Test that disabled brokers are skipped."""
|
||||
router_with_brokers.disable_broker("EquityBroker")
|
||||
|
||||
with pytest.raises(NoBrokerError):
|
||||
router_with_brokers.route("AAPL")
|
||||
|
||||
def test_enable_broker(self, router_with_brokers):
|
||||
"""Test enabling a broker."""
|
||||
router_with_brokers.disable_broker("EquityBroker")
|
||||
router_with_brokers.enable_broker("EquityBroker")
|
||||
|
||||
broker, _ = router_with_brokers.route("AAPL")
|
||||
assert broker.name == "EquityBroker"
|
||||
|
||||
def test_priority_routing(self):
|
||||
"""Test priority-based broker selection."""
|
||||
router = BrokerRouter()
|
||||
|
||||
low_priority = MockBroker("LowPriority", [AssetClass.EQUITY])
|
||||
high_priority = MockBroker("HighPriority", [AssetClass.EQUITY])
|
||||
|
||||
router.register(low_priority, priority=1)
|
||||
router.register(high_priority, priority=10)
|
||||
|
||||
broker, _ = router.route("AAPL")
|
||||
assert broker.name == "HighPriority"
|
||||
|
||||
def test_primary_broker_preference(self):
|
||||
"""Test primary broker is preferred."""
|
||||
router = BrokerRouter()
|
||||
|
||||
secondary = MockBroker("Secondary", [AssetClass.EQUITY])
|
||||
primary = MockBroker("Primary", [AssetClass.EQUITY])
|
||||
|
||||
router.register(secondary, priority=10)
|
||||
router.register(primary, priority=1, primary=True)
|
||||
|
||||
broker, _ = router.route("AAPL")
|
||||
assert broker.name == "Primary"
|
||||
|
||||
def test_routing_history(self, router_with_brokers):
|
||||
"""Test routing history tracking."""
|
||||
router_with_brokers.route("AAPL")
|
||||
router_with_brokers.route("SPY")
|
||||
router_with_brokers.route("BTCUSD")
|
||||
|
||||
history = router_with_brokers.get_routing_history()
|
||||
|
||||
assert len(history) == 3
|
||||
# Most recent first
|
||||
assert history[0].symbol == "BTCUSD"
|
||||
assert history[1].symbol == "SPY"
|
||||
assert history[2].symbol == "AAPL"
|
||||
|
||||
def test_routing_history_filter(self, router_with_brokers):
|
||||
"""Test filtering routing history by symbol."""
|
||||
router_with_brokers.route("AAPL")
|
||||
router_with_brokers.route("SPY")
|
||||
router_with_brokers.route("AAPL")
|
||||
|
||||
history = router_with_brokers.get_routing_history(symbol="AAPL")
|
||||
|
||||
assert len(history) == 2
|
||||
assert all(d.symbol == "AAPL" for d in history)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BrokerRouter Tests - Order Management
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBrokerRouterOrders:
|
||||
"""Tests for order management through router."""
|
||||
|
||||
async def _create_connected_router(self):
|
||||
"""Create a connected router."""
|
||||
router = BrokerRouter()
|
||||
|
||||
equity_broker = MockBroker("EquityBroker", [AssetClass.EQUITY, AssetClass.ETF])
|
||||
crypto_broker = MockBroker("CryptoBroker", [AssetClass.CRYPTO])
|
||||
|
||||
router.register(equity_broker)
|
||||
router.register(crypto_broker)
|
||||
|
||||
await router.connect_all()
|
||||
return router
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_order_auto_route(self):
|
||||
"""Test submitting order with auto-routing."""
|
||||
router = await self._create_connected_router()
|
||||
request = OrderRequest.market("AAPL", OrderSide.BUY, 100)
|
||||
|
||||
order = await router.submit_order(request)
|
||||
|
||||
assert order.symbol == "AAPL"
|
||||
assert "EquityBroker" in order.broker_order_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_order_specific_broker(self):
|
||||
"""Test submitting order to specific broker."""
|
||||
router = await self._create_connected_router()
|
||||
request = OrderRequest.market("AAPL", OrderSide.BUY, 100)
|
||||
|
||||
order = await router.submit_order(request, broker_name="EquityBroker")
|
||||
|
||||
assert "EquityBroker" in order.broker_order_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_order(self):
|
||||
"""Test cancelling an order."""
|
||||
router = await self._create_connected_router()
|
||||
request = OrderRequest.market("AAPL", OrderSide.BUY, 100)
|
||||
order = await router.submit_order(request)
|
||||
|
||||
cancelled = await router.cancel_order(
|
||||
order.broker_order_id,
|
||||
broker_name="EquityBroker"
|
||||
)
|
||||
|
||||
assert cancelled.status == OrderStatus.CANCELLED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_orders_single_broker(self):
|
||||
"""Test getting orders from single broker."""
|
||||
router = await self._create_connected_router()
|
||||
await router.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, 100))
|
||||
await router.submit_order(OrderRequest.market("BTCUSD", OrderSide.BUY, 1))
|
||||
|
||||
orders = await router.get_orders(broker_name="EquityBroker")
|
||||
|
||||
assert "EquityBroker" in orders
|
||||
assert len(orders["EquityBroker"]) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_orders_all_brokers(self):
|
||||
"""Test getting orders from all brokers."""
|
||||
router = await self._create_connected_router()
|
||||
await router.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, 100))
|
||||
await router.submit_order(OrderRequest.market("BTCUSD", OrderSide.BUY, 1))
|
||||
|
||||
orders = await router.get_orders()
|
||||
|
||||
assert len(orders) == 2 # Two brokers
|
||||
assert "EquityBroker" in orders
|
||||
assert "CryptoBroker" in orders
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_all_orders(self):
|
||||
"""Test cancelling all orders."""
|
||||
router = await self._create_connected_router()
|
||||
await router.submit_order(OrderRequest.market("AAPL", OrderSide.BUY, 100))
|
||||
await router.submit_order(OrderRequest.market("MSFT", OrderSide.BUY, 50))
|
||||
|
||||
cancelled = await router.cancel_all_orders()
|
||||
|
||||
assert "EquityBroker" in cancelled
|
||||
assert len(cancelled["EquityBroker"]) == 2
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BrokerRouter Tests - Position Management
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBrokerRouterPositions:
|
||||
"""Tests for position management through router."""
|
||||
|
||||
async def _create_router_with_positions(self):
|
||||
"""Create router with positions."""
|
||||
router = BrokerRouter()
|
||||
|
||||
equity_broker = MockBroker("EquityBroker", [AssetClass.EQUITY])
|
||||
crypto_broker = MockBroker("CryptoBroker", [AssetClass.CRYPTO])
|
||||
|
||||
router.register(equity_broker)
|
||||
router.register(crypto_broker)
|
||||
|
||||
await router.connect_all()
|
||||
|
||||
# Add some positions
|
||||
equity_broker.add_position(Position(
|
||||
symbol="AAPL",
|
||||
quantity=Decimal("100"),
|
||||
side=PositionSide.LONG,
|
||||
avg_entry_price=Decimal("150"),
|
||||
current_price=Decimal("160"),
|
||||
market_value=Decimal("16000"),
|
||||
cost_basis=Decimal("15000"),
|
||||
unrealized_pnl=Decimal("1000"),
|
||||
unrealized_pnl_percent=Decimal("6.67"),
|
||||
))
|
||||
|
||||
crypto_broker.add_position(Position(
|
||||
symbol="BTCUSD",
|
||||
quantity=Decimal("1"),
|
||||
side=PositionSide.LONG,
|
||||
avg_entry_price=Decimal("40000"),
|
||||
current_price=Decimal("45000"),
|
||||
market_value=Decimal("45000"),
|
||||
cost_basis=Decimal("40000"),
|
||||
unrealized_pnl=Decimal("5000"),
|
||||
unrealized_pnl_percent=Decimal("12.5"),
|
||||
))
|
||||
|
||||
return router
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_positions_all(self):
|
||||
"""Test getting positions from all brokers."""
|
||||
router = await self._create_router_with_positions()
|
||||
positions = await router.get_positions()
|
||||
|
||||
assert len(positions) == 2
|
||||
assert "EquityBroker" in positions
|
||||
assert "CryptoBroker" in positions
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_positions(self):
|
||||
"""Test getting aggregated positions."""
|
||||
router = await self._create_router_with_positions()
|
||||
positions = await router.get_all_positions()
|
||||
|
||||
assert len(positions) == 2
|
||||
symbols = {p[1].symbol for p in positions}
|
||||
assert "AAPL" in symbols
|
||||
assert "BTCUSD" in symbols
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_position(self):
|
||||
"""Test getting specific position."""
|
||||
router = await self._create_router_with_positions()
|
||||
result = await router.get_position("AAPL")
|
||||
|
||||
assert result is not None
|
||||
broker_name, position = result
|
||||
assert broker_name == "EquityBroker"
|
||||
assert position.symbol == "AAPL"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_position_not_found(self):
|
||||
"""Test getting non-existent position."""
|
||||
router = await self._create_router_with_positions()
|
||||
result = await router.get_position("NONEXISTENT")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_position(self):
|
||||
"""Test closing a position."""
|
||||
router = await self._create_router_with_positions()
|
||||
order = await router.close_position("AAPL")
|
||||
|
||||
assert order.symbol == "AAPL"
|
||||
assert order.side == OrderSide.SELL
|
||||
assert order.quantity == Decimal("100")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_position_not_found(self):
|
||||
"""Test closing non-existent position."""
|
||||
router = await self._create_router_with_positions()
|
||||
with pytest.raises(PositionError, match="No position found"):
|
||||
await router.close_position("NONEXISTENT")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BrokerRouter Tests - Account Management
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBrokerRouterAccounts:
|
||||
"""Tests for account management through router."""
|
||||
|
||||
async def _create_connected_router(self):
|
||||
"""Create connected router."""
|
||||
router = BrokerRouter()
|
||||
|
||||
broker1 = MockBroker("Broker1", [AssetClass.EQUITY])
|
||||
broker2 = MockBroker("Broker2", [AssetClass.CRYPTO])
|
||||
|
||||
router.register(broker1)
|
||||
router.register(broker2)
|
||||
|
||||
await router.connect_all()
|
||||
return router
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_accounts(self):
|
||||
"""Test getting all accounts."""
|
||||
router = await self._create_connected_router()
|
||||
accounts = await router.get_accounts()
|
||||
|
||||
assert len(accounts) == 2
|
||||
assert "Broker1" in accounts
|
||||
assert "Broker2" in accounts
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_total_equity(self):
|
||||
"""Test getting total equity."""
|
||||
router = await self._create_connected_router()
|
||||
equity = await router.get_total_equity()
|
||||
|
||||
# Each mock broker has 150000 equity
|
||||
assert equity == Decimal("300000")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_total_buying_power(self):
|
||||
"""Test getting total buying power."""
|
||||
router = await self._create_connected_router()
|
||||
power = await router.get_total_buying_power()
|
||||
|
||||
# Each mock broker has 200000 buying power
|
||||
assert power == Decimal("400000")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BrokerRouter Tests - Market Data
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBrokerRouterMarketData:
|
||||
"""Tests for market data through router."""
|
||||
|
||||
async def _create_connected_router(self):
|
||||
"""Create connected router."""
|
||||
router = BrokerRouter()
|
||||
|
||||
equity_broker = MockBroker("EquityBroker", [AssetClass.EQUITY])
|
||||
crypto_broker = MockBroker("CryptoBroker", [AssetClass.CRYPTO])
|
||||
|
||||
router.register(equity_broker)
|
||||
router.register(crypto_broker)
|
||||
|
||||
await router.connect_all()
|
||||
return router
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_quote(self):
|
||||
"""Test getting a quote."""
|
||||
router = await self._create_connected_router()
|
||||
quote = await router.get_quote("AAPL")
|
||||
|
||||
assert quote.symbol == "AAPL"
|
||||
assert quote.bid_price is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_quotes(self):
|
||||
"""Test getting multiple quotes."""
|
||||
router = await self._create_connected_router()
|
||||
quotes = await router.get_quotes(["AAPL", "BTCUSD"])
|
||||
|
||||
assert len(quotes) == 2
|
||||
assert "AAPL" in quotes
|
||||
assert "BTCUSD" in quotes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_asset(self):
|
||||
"""Test getting asset info."""
|
||||
router = await self._create_connected_router()
|
||||
asset = await router.get_asset("AAPL")
|
||||
|
||||
assert asset.symbol == "AAPL"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BrokerRouter Tests - Connection Management
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBrokerRouterConnection:
|
||||
"""Tests for connection management."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_all(self):
|
||||
"""Test connecting all brokers."""
|
||||
router = BrokerRouter()
|
||||
|
||||
broker1 = MockBroker("Broker1", [AssetClass.EQUITY])
|
||||
broker2 = MockBroker("Broker2", [AssetClass.CRYPTO])
|
||||
|
||||
router.register(broker1)
|
||||
router.register(broker2)
|
||||
|
||||
results = await router.connect_all()
|
||||
|
||||
assert results["Broker1"] is True
|
||||
assert results["Broker2"] is True
|
||||
assert broker1.is_connected
|
||||
assert broker2.is_connected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_all(self):
|
||||
"""Test disconnecting all brokers."""
|
||||
router = BrokerRouter()
|
||||
|
||||
broker = MockBroker("Broker", [AssetClass.EQUITY])
|
||||
router.register(broker)
|
||||
|
||||
await router.connect_all()
|
||||
await router.disconnect_all()
|
||||
|
||||
assert not broker.is_connected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_market_open(self):
|
||||
"""Test checking market status."""
|
||||
router = BrokerRouter()
|
||||
|
||||
broker = MockBroker("Broker", [AssetClass.EQUITY])
|
||||
router.register(broker)
|
||||
await router.connect_all()
|
||||
|
||||
is_open = await router.is_market_open()
|
||||
assert is_open is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BrokerRouter Tests - Status
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBrokerRouterStatus:
|
||||
"""Tests for status reporting."""
|
||||
|
||||
def test_get_broker_status(self):
|
||||
"""Test getting broker status."""
|
||||
router = BrokerRouter()
|
||||
|
||||
broker = MockBroker("TestBroker", [AssetClass.EQUITY, AssetClass.ETF])
|
||||
router.register(broker, priority=5, primary=True)
|
||||
|
||||
status = router.get_broker_status()
|
||||
|
||||
assert "TestBroker" in status
|
||||
broker_status = status["TestBroker"]
|
||||
assert broker_status["connected"] is False
|
||||
assert broker_status["enabled"] is True
|
||||
assert broker_status["priority"] == 5
|
||||
assert broker_status["is_primary"] is True
|
||||
assert "equity" in broker_status["asset_classes"]
|
||||
|
||||
def test_repr(self):
|
||||
"""Test string representation."""
|
||||
router = BrokerRouter()
|
||||
|
||||
broker1 = MockBroker("Broker1", [AssetClass.EQUITY])
|
||||
broker2 = MockBroker("Broker2", [AssetClass.CRYPTO])
|
||||
|
||||
router.register(broker1)
|
||||
router.register(broker2)
|
||||
|
||||
repr_str = repr(router)
|
||||
assert "BrokerRouter" in repr_str
|
||||
assert "Broker1" in repr_str
|
||||
assert "Broker2" in repr_str
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Router Exception Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRouterExceptions:
|
||||
"""Tests for router exceptions."""
|
||||
|
||||
def test_routing_error(self):
|
||||
"""Test RoutingError exception."""
|
||||
error = RoutingError("Routing failed")
|
||||
assert isinstance(error, BrokerError)
|
||||
|
||||
def test_no_broker_error(self):
|
||||
"""Test NoBrokerError exception."""
|
||||
error = NoBrokerError("No broker for asset class")
|
||||
assert isinstance(error, RoutingError)
|
||||
|
||||
def test_broker_not_found_error(self):
|
||||
"""Test BrokerNotFoundError exception."""
|
||||
error = BrokerNotFoundError("Broker not found")
|
||||
assert isinstance(error, RoutingError)
|
||||
|
||||
def test_duplicate_broker_error(self):
|
||||
"""Test DuplicateBrokerError exception."""
|
||||
error = DuplicateBrokerError("Broker already registered")
|
||||
assert isinstance(error, RoutingError)
|
||||
|
|
@ -4,9 +4,11 @@ This module provides a unified interface for interacting with various brokers
|
|||
(Alpaca, IBKR, Paper) and managing order execution.
|
||||
|
||||
Issue #22: [EXEC-21] Broker base interface - abstract broker class
|
||||
Issue #23: [EXEC-22] Broker router - route by asset class
|
||||
|
||||
Submodules:
|
||||
broker_base: Abstract base class for broker implementations
|
||||
broker_router: Router for multi-broker setups
|
||||
|
||||
Classes:
|
||||
Enums:
|
||||
|
|
@ -86,6 +88,19 @@ from .broker_base import (
|
|||
BrokerBase,
|
||||
)
|
||||
|
||||
from .broker_router import (
|
||||
# Router Classes
|
||||
BrokerRouter,
|
||||
BrokerRegistration,
|
||||
RoutingDecision,
|
||||
SymbolClassifier,
|
||||
# Router Exceptions
|
||||
RoutingError,
|
||||
NoBrokerError,
|
||||
BrokerNotFoundError,
|
||||
DuplicateBrokerError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Enums
|
||||
"AssetClass",
|
||||
|
|
@ -112,4 +127,13 @@ __all__ = [
|
|||
"RateLimitError",
|
||||
# Abstract Base Class
|
||||
"BrokerBase",
|
||||
# Router
|
||||
"BrokerRouter",
|
||||
"BrokerRegistration",
|
||||
"RoutingDecision",
|
||||
"SymbolClassifier",
|
||||
"RoutingError",
|
||||
"NoBrokerError",
|
||||
"BrokerNotFoundError",
|
||||
"DuplicateBrokerError",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,962 @@
|
|||
"""Broker Router for routing orders by asset class.
|
||||
|
||||
This module provides a router that directs orders to the appropriate broker
|
||||
based on the asset class being traded. This enables multi-broker setups
|
||||
where different brokers handle different asset classes.
|
||||
|
||||
Issue #23: [EXEC-22] Broker router - route by asset class
|
||||
|
||||
Example:
|
||||
>>> from tradingagents.execution import BrokerRouter, AssetClass
|
||||
>>>
|
||||
>>> router = BrokerRouter()
|
||||
>>> router.register(alpaca_broker, [AssetClass.EQUITY, AssetClass.CRYPTO])
|
||||
>>> router.register(ibkr_broker, [AssetClass.FUTURE, AssetClass.OPTION])
|
||||
>>>
|
||||
>>> # Orders automatically routed to correct broker
|
||||
>>> await router.submit_order(order_request) # Routes based on symbol
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||
import asyncio
|
||||
|
||||
from .broker_base import (
|
||||
AssetClass,
|
||||
AssetInfo,
|
||||
BrokerBase,
|
||||
BrokerError,
|
||||
Order,
|
||||
OrderRequest,
|
||||
OrderStatus,
|
||||
OrderSide,
|
||||
Position,
|
||||
Quote,
|
||||
AccountInfo,
|
||||
TimeInForce,
|
||||
PositionError,
|
||||
ConnectionError as BrokerConnectionError,
|
||||
)
|
||||
|
||||
|
||||
class RoutingError(BrokerError):
|
||||
"""Error in broker routing."""
|
||||
pass
|
||||
|
||||
|
||||
class NoBrokerError(RoutingError):
|
||||
"""No broker available for the requested asset class."""
|
||||
pass
|
||||
|
||||
|
||||
class BrokerNotFoundError(RoutingError):
|
||||
"""Specified broker not found."""
|
||||
pass
|
||||
|
||||
|
||||
class DuplicateBrokerError(RoutingError):
|
||||
"""Broker already registered."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrokerRegistration:
|
||||
"""Registration info for a broker.
|
||||
|
||||
Attributes:
|
||||
broker: The broker instance
|
||||
asset_classes: Asset classes this broker handles
|
||||
priority: Priority for routing (higher = preferred)
|
||||
is_primary: Whether this is the primary broker for its classes
|
||||
enabled: Whether this broker is currently enabled
|
||||
registered_at: When the broker was registered
|
||||
"""
|
||||
broker: BrokerBase
|
||||
asset_classes: Set[AssetClass]
|
||||
priority: int = 0
|
||||
is_primary: bool = False
|
||||
enabled: bool = True
|
||||
registered_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoutingDecision:
|
||||
"""Record of a routing decision.
|
||||
|
||||
Attributes:
|
||||
symbol: Symbol being routed
|
||||
asset_class: Detected asset class
|
||||
broker_name: Name of selected broker
|
||||
reason: Reason for the routing decision
|
||||
alternatives: Alternative brokers that could handle this
|
||||
timestamp: When decision was made
|
||||
"""
|
||||
symbol: str
|
||||
asset_class: AssetClass
|
||||
broker_name: str
|
||||
reason: str
|
||||
alternatives: List[str] = field(default_factory=list)
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class SymbolClassifier:
|
||||
"""Classifies symbols into asset classes.
|
||||
|
||||
This provides default classification based on symbol patterns.
|
||||
Can be customized with explicit mappings or external data sources.
|
||||
"""
|
||||
|
||||
# Default patterns for common exchanges/symbols
|
||||
CRYPTO_SUFFIXES = {"USD", "USDT", "BTC", "ETH"}
|
||||
ETF_SUFFIXES = {"ETF"}
|
||||
|
||||
# Known ETFs (partial list)
|
||||
KNOWN_ETFS = {
|
||||
"SPY", "QQQ", "IWM", "DIA", "VOO", "VTI", "VEA", "VWO",
|
||||
"XLF", "XLE", "XLK", "XLV", "XLP", "XLY", "XLI", "XLB", "XLU",
|
||||
"GLD", "SLV", "USO", "UNG", "TLT", "IEF", "SHY", "BND",
|
||||
"EEM", "EFA", "IEMG", "IEFA", "AGG", "LQD", "HYG", "JNK",
|
||||
}
|
||||
|
||||
# Known crypto tickers
|
||||
KNOWN_CRYPTO = {
|
||||
"BTC", "ETH", "BTCUSD", "ETHUSD", "SOLUSD", "DOGEUSD",
|
||||
"AVAXUSD", "LINKUSD", "UNIUSD", "MATICUSD", "ADAUSD",
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize classifier with custom mappings."""
|
||||
self._custom_mappings: Dict[str, AssetClass] = {}
|
||||
self._symbol_cache: Dict[str, AssetClass] = {}
|
||||
|
||||
def add_mapping(self, symbol: str, asset_class: AssetClass) -> None:
|
||||
"""Add a custom symbol-to-class mapping.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
asset_class: Asset class for the symbol
|
||||
"""
|
||||
self._custom_mappings[symbol.upper()] = asset_class
|
||||
# Invalidate cache for this symbol
|
||||
self._symbol_cache.pop(symbol.upper(), None)
|
||||
|
||||
def classify(self, symbol: str) -> AssetClass:
|
||||
"""Classify a symbol into an asset class.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
Detected asset class (defaults to EQUITY)
|
||||
"""
|
||||
symbol = symbol.upper()
|
||||
|
||||
# Check cache first
|
||||
if symbol in self._symbol_cache:
|
||||
return self._symbol_cache[symbol]
|
||||
|
||||
# Check custom mappings
|
||||
if symbol in self._custom_mappings:
|
||||
result = self._custom_mappings[symbol]
|
||||
self._symbol_cache[symbol] = result
|
||||
return result
|
||||
|
||||
# Check known crypto
|
||||
if symbol in self.KNOWN_CRYPTO:
|
||||
result = AssetClass.CRYPTO
|
||||
self._symbol_cache[symbol] = result
|
||||
return result
|
||||
|
||||
# Check crypto patterns (ends with crypto currency)
|
||||
for suffix in self.CRYPTO_SUFFIXES:
|
||||
if symbol.endswith(suffix) and len(symbol) > len(suffix):
|
||||
result = AssetClass.CRYPTO
|
||||
self._symbol_cache[symbol] = result
|
||||
return result
|
||||
|
||||
# Check known ETFs
|
||||
if symbol in self.KNOWN_ETFS:
|
||||
result = AssetClass.ETF
|
||||
self._symbol_cache[symbol] = result
|
||||
return result
|
||||
|
||||
# Check futures patterns (month/year codes)
|
||||
# E.g., ESZ24 (S&P 500 Dec 2024), CLF25 (Crude Oil Jan 2025)
|
||||
if len(symbol) >= 4 and symbol[-2:].isdigit():
|
||||
if symbol[-3] in "FGHJKMNQUVXZ": # Future month codes
|
||||
result = AssetClass.FUTURE
|
||||
self._symbol_cache[symbol] = result
|
||||
return result
|
||||
|
||||
# Check options patterns (contain numbers and special chars)
|
||||
if "/" in symbol or ("C" in symbol and symbol[-1].isdigit()):
|
||||
result = AssetClass.OPTION
|
||||
self._symbol_cache[symbol] = result
|
||||
return result
|
||||
|
||||
# Default to equity
|
||||
result = AssetClass.EQUITY
|
||||
self._symbol_cache[symbol] = result
|
||||
return result
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the classification cache."""
|
||||
self._symbol_cache.clear()
|
||||
|
||||
|
||||
class BrokerRouter:
|
||||
"""Routes orders to appropriate brokers based on asset class.
|
||||
|
||||
The router maintains a registry of brokers and their supported asset
|
||||
classes. When an order comes in, it classifies the symbol and routes
|
||||
to the appropriate broker.
|
||||
|
||||
Features:
|
||||
- Multi-broker support with priority-based selection
|
||||
- Automatic symbol classification
|
||||
- Custom routing rules
|
||||
- Fallback broker support
|
||||
- Aggregated position views
|
||||
|
||||
Example:
|
||||
>>> router = BrokerRouter()
|
||||
>>>
|
||||
>>> # Register brokers
|
||||
>>> router.register(alpaca, [AssetClass.EQUITY, AssetClass.CRYPTO])
|
||||
>>> router.register(ibkr, [AssetClass.FUTURE], primary=True)
|
||||
>>>
|
||||
>>> # Submit order (auto-routed)
|
||||
>>> order = await router.submit_order(
|
||||
... OrderRequest.market("AAPL", OrderSide.BUY, 100)
|
||||
... )
|
||||
>>>
|
||||
>>> # Get aggregated positions across all brokers
|
||||
>>> positions = await router.get_all_positions()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
classifier: Optional[SymbolClassifier] = None,
|
||||
default_asset_class: AssetClass = AssetClass.EQUITY,
|
||||
):
|
||||
"""Initialize the broker router.
|
||||
|
||||
Args:
|
||||
classifier: Symbol classifier (creates default if None)
|
||||
default_asset_class: Default class for unknown symbols
|
||||
"""
|
||||
self._brokers: Dict[str, BrokerRegistration] = {}
|
||||
self._class_to_brokers: Dict[AssetClass, List[str]] = {}
|
||||
self._classifier = classifier or SymbolClassifier()
|
||||
self._default_asset_class = default_asset_class
|
||||
self._fallback_broker: Optional[str] = None
|
||||
self._routing_history: List[RoutingDecision] = []
|
||||
self._max_history = 1000
|
||||
|
||||
@property
|
||||
def registered_brokers(self) -> List[str]:
|
||||
"""Get list of registered broker names."""
|
||||
return list(self._brokers.keys())
|
||||
|
||||
@property
|
||||
def supported_asset_classes(self) -> Set[AssetClass]:
|
||||
"""Get all supported asset classes across all brokers."""
|
||||
classes = set()
|
||||
for reg in self._brokers.values():
|
||||
if reg.enabled:
|
||||
classes.update(reg.asset_classes)
|
||||
return classes
|
||||
|
||||
def register(
|
||||
self,
|
||||
broker: BrokerBase,
|
||||
asset_classes: Optional[List[AssetClass]] = None,
|
||||
priority: int = 0,
|
||||
primary: bool = False,
|
||||
) -> None:
|
||||
"""Register a broker for specific asset classes.
|
||||
|
||||
Args:
|
||||
broker: Broker instance to register
|
||||
asset_classes: Asset classes this broker handles (uses broker's
|
||||
supported classes if None)
|
||||
priority: Priority for routing (higher = preferred)
|
||||
primary: Whether this should be the primary broker for its classes
|
||||
|
||||
Raises:
|
||||
DuplicateBrokerError: If broker is already registered
|
||||
"""
|
||||
if broker.name in self._brokers:
|
||||
raise DuplicateBrokerError(f"Broker '{broker.name}' is already registered")
|
||||
|
||||
# Use broker's supported classes if not specified
|
||||
classes = set(asset_classes) if asset_classes else set(broker.supported_asset_classes)
|
||||
|
||||
registration = BrokerRegistration(
|
||||
broker=broker,
|
||||
asset_classes=classes,
|
||||
priority=priority,
|
||||
is_primary=primary,
|
||||
)
|
||||
|
||||
self._brokers[broker.name] = registration
|
||||
|
||||
# Update class-to-broker mapping
|
||||
for asset_class in classes:
|
||||
if asset_class not in self._class_to_brokers:
|
||||
self._class_to_brokers[asset_class] = []
|
||||
self._class_to_brokers[asset_class].append(broker.name)
|
||||
|
||||
# Sort by priority (highest first)
|
||||
self._class_to_brokers[asset_class].sort(
|
||||
key=lambda n: (
|
||||
self._brokers[n].is_primary,
|
||||
self._brokers[n].priority,
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
def unregister(self, broker_name: str) -> None:
|
||||
"""Unregister a broker.
|
||||
|
||||
Args:
|
||||
broker_name: Name of broker to unregister
|
||||
|
||||
Raises:
|
||||
BrokerNotFoundError: If broker is not registered
|
||||
"""
|
||||
if broker_name not in self._brokers:
|
||||
raise BrokerNotFoundError(f"Broker '{broker_name}' is not registered")
|
||||
|
||||
registration = self._brokers[broker_name]
|
||||
|
||||
# Remove from class mapping
|
||||
for asset_class in registration.asset_classes:
|
||||
if asset_class in self._class_to_brokers:
|
||||
self._class_to_brokers[asset_class].remove(broker_name)
|
||||
if not self._class_to_brokers[asset_class]:
|
||||
del self._class_to_brokers[asset_class]
|
||||
|
||||
del self._brokers[broker_name]
|
||||
|
||||
# Clear fallback if it was this broker
|
||||
if self._fallback_broker == broker_name:
|
||||
self._fallback_broker = None
|
||||
|
||||
def set_fallback(self, broker_name: str) -> None:
|
||||
"""Set a fallback broker for unclassified symbols.
|
||||
|
||||
Args:
|
||||
broker_name: Name of broker to use as fallback
|
||||
|
||||
Raises:
|
||||
BrokerNotFoundError: If broker is not registered
|
||||
"""
|
||||
if broker_name not in self._brokers:
|
||||
raise BrokerNotFoundError(f"Broker '{broker_name}' is not registered")
|
||||
self._fallback_broker = broker_name
|
||||
|
||||
def enable_broker(self, broker_name: str) -> None:
|
||||
"""Enable a broker for routing.
|
||||
|
||||
Args:
|
||||
broker_name: Name of broker to enable
|
||||
|
||||
Raises:
|
||||
BrokerNotFoundError: If broker is not registered
|
||||
"""
|
||||
if broker_name not in self._brokers:
|
||||
raise BrokerNotFoundError(f"Broker '{broker_name}' is not registered")
|
||||
self._brokers[broker_name].enabled = True
|
||||
|
||||
def disable_broker(self, broker_name: str) -> None:
|
||||
"""Disable a broker from routing.
|
||||
|
||||
Args:
|
||||
broker_name: Name of broker to disable
|
||||
|
||||
Raises:
|
||||
BrokerNotFoundError: If broker is not registered
|
||||
"""
|
||||
if broker_name not in self._brokers:
|
||||
raise BrokerNotFoundError(f"Broker '{broker_name}' is not registered")
|
||||
self._brokers[broker_name].enabled = False
|
||||
|
||||
def get_broker(self, broker_name: str) -> BrokerBase:
|
||||
"""Get a broker by name.
|
||||
|
||||
Args:
|
||||
broker_name: Name of broker
|
||||
|
||||
Returns:
|
||||
BrokerBase instance
|
||||
|
||||
Raises:
|
||||
BrokerNotFoundError: If broker is not registered
|
||||
"""
|
||||
if broker_name not in self._brokers:
|
||||
raise BrokerNotFoundError(f"Broker '{broker_name}' is not registered")
|
||||
return self._brokers[broker_name].broker
|
||||
|
||||
def route(self, symbol: str) -> Tuple[BrokerBase, RoutingDecision]:
|
||||
"""Route a symbol to the appropriate broker.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
Tuple of (broker, routing_decision)
|
||||
|
||||
Raises:
|
||||
NoBrokerError: If no broker can handle this symbol
|
||||
"""
|
||||
# Classify the symbol
|
||||
asset_class = self._classifier.classify(symbol)
|
||||
|
||||
# Find brokers for this class
|
||||
broker_names = self._class_to_brokers.get(asset_class, [])
|
||||
|
||||
# Filter to enabled brokers
|
||||
enabled_names = [
|
||||
n for n in broker_names
|
||||
if self._brokers[n].enabled
|
||||
]
|
||||
|
||||
# Select the best broker
|
||||
if enabled_names:
|
||||
broker_name = enabled_names[0]
|
||||
reason = f"Primary broker for {asset_class.value}"
|
||||
alternatives = enabled_names[1:]
|
||||
elif self._fallback_broker and self._brokers[self._fallback_broker].enabled:
|
||||
broker_name = self._fallback_broker
|
||||
reason = f"Fallback broker (no broker for {asset_class.value})"
|
||||
alternatives = []
|
||||
else:
|
||||
raise NoBrokerError(
|
||||
f"No broker available for symbol '{symbol}' "
|
||||
f"(asset class: {asset_class.value})"
|
||||
)
|
||||
|
||||
decision = RoutingDecision(
|
||||
symbol=symbol,
|
||||
asset_class=asset_class,
|
||||
broker_name=broker_name,
|
||||
reason=reason,
|
||||
alternatives=alternatives,
|
||||
)
|
||||
|
||||
# Record routing history
|
||||
self._routing_history.append(decision)
|
||||
if len(self._routing_history) > self._max_history:
|
||||
self._routing_history = self._routing_history[-self._max_history:]
|
||||
|
||||
return self._brokers[broker_name].broker, decision
|
||||
|
||||
def add_symbol_mapping(self, symbol: str, asset_class: AssetClass) -> None:
|
||||
"""Add a custom symbol-to-asset-class mapping.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
asset_class: Asset class for the symbol
|
||||
"""
|
||||
self._classifier.add_mapping(symbol, asset_class)
|
||||
|
||||
# ==========================================================================
|
||||
# Connection Management
|
||||
# ==========================================================================
|
||||
|
||||
async def connect_all(self) -> Dict[str, bool]:
|
||||
"""Connect all registered brokers.
|
||||
|
||||
Returns:
|
||||
Dict mapping broker name to connection success
|
||||
"""
|
||||
results = {}
|
||||
for name, reg in self._brokers.items():
|
||||
try:
|
||||
results[name] = await reg.broker.connect()
|
||||
except Exception as e:
|
||||
results[name] = False
|
||||
return results
|
||||
|
||||
async def disconnect_all(self) -> None:
|
||||
"""Disconnect all registered brokers."""
|
||||
for reg in self._brokers.values():
|
||||
try:
|
||||
await reg.broker.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def is_market_open(self, broker_name: Optional[str] = None) -> bool:
|
||||
"""Check if market is open.
|
||||
|
||||
Args:
|
||||
broker_name: Specific broker to check (checks first enabled if None)
|
||||
|
||||
Returns:
|
||||
True if market is open
|
||||
"""
|
||||
if broker_name:
|
||||
broker = self.get_broker(broker_name)
|
||||
return await broker.is_market_open()
|
||||
|
||||
# Check first enabled broker
|
||||
for reg in self._brokers.values():
|
||||
if reg.enabled and reg.broker.is_connected:
|
||||
return await reg.broker.is_market_open()
|
||||
|
||||
return False
|
||||
|
||||
# ==========================================================================
|
||||
# Order Management
|
||||
# ==========================================================================
|
||||
|
||||
async def submit_order(
|
||||
self,
|
||||
request: OrderRequest,
|
||||
broker_name: Optional[str] = None,
|
||||
) -> Order:
|
||||
"""Submit an order, routing to the appropriate broker.
|
||||
|
||||
Args:
|
||||
request: Order request
|
||||
broker_name: Optional specific broker (auto-routes if None)
|
||||
|
||||
Returns:
|
||||
Order object
|
||||
|
||||
Raises:
|
||||
NoBrokerError: If no broker can handle this symbol
|
||||
"""
|
||||
if broker_name:
|
||||
broker = self.get_broker(broker_name)
|
||||
else:
|
||||
broker, _ = self.route(request.symbol)
|
||||
|
||||
return await broker.submit_order(request)
|
||||
|
||||
async def cancel_order(
|
||||
self,
|
||||
order_id: str,
|
||||
broker_name: str,
|
||||
) -> Order:
|
||||
"""Cancel an order.
|
||||
|
||||
Args:
|
||||
order_id: Order ID to cancel
|
||||
broker_name: Name of broker holding the order
|
||||
|
||||
Returns:
|
||||
Updated order
|
||||
"""
|
||||
broker = self.get_broker(broker_name)
|
||||
return await broker.cancel_order(order_id)
|
||||
|
||||
async def replace_order(
|
||||
self,
|
||||
order_id: str,
|
||||
broker_name: str,
|
||||
quantity: Optional[Decimal] = None,
|
||||
limit_price: Optional[Decimal] = None,
|
||||
stop_price: Optional[Decimal] = None,
|
||||
time_in_force: Optional[TimeInForce] = None,
|
||||
) -> Order:
|
||||
"""Replace an order.
|
||||
|
||||
Args:
|
||||
order_id: Order ID to replace
|
||||
broker_name: Name of broker holding the order
|
||||
quantity: New quantity
|
||||
limit_price: New limit price
|
||||
stop_price: New stop price
|
||||
time_in_force: New time in force
|
||||
|
||||
Returns:
|
||||
New order
|
||||
"""
|
||||
broker = self.get_broker(broker_name)
|
||||
return await broker.replace_order(
|
||||
order_id, quantity, limit_price, stop_price, time_in_force
|
||||
)
|
||||
|
||||
async def get_order(self, order_id: str, broker_name: str) -> Order:
|
||||
"""Get an order from a specific broker.
|
||||
|
||||
Args:
|
||||
order_id: Order ID
|
||||
broker_name: Name of broker
|
||||
|
||||
Returns:
|
||||
Order object
|
||||
"""
|
||||
broker = self.get_broker(broker_name)
|
||||
return await broker.get_order(order_id)
|
||||
|
||||
async def get_orders(
|
||||
self,
|
||||
broker_name: Optional[str] = None,
|
||||
status: Optional[OrderStatus] = None,
|
||||
symbols: Optional[List[str]] = None,
|
||||
) -> Dict[str, List[Order]]:
|
||||
"""Get orders from one or all brokers.
|
||||
|
||||
Args:
|
||||
broker_name: Specific broker (all if None)
|
||||
status: Filter by status
|
||||
symbols: Filter by symbols
|
||||
|
||||
Returns:
|
||||
Dict mapping broker name to list of orders
|
||||
"""
|
||||
results = {}
|
||||
|
||||
if broker_name:
|
||||
brokers = [(broker_name, self.get_broker(broker_name))]
|
||||
else:
|
||||
brokers = [
|
||||
(name, reg.broker)
|
||||
for name, reg in self._brokers.items()
|
||||
if reg.enabled and reg.broker.is_connected
|
||||
]
|
||||
|
||||
for name, broker in brokers:
|
||||
try:
|
||||
orders = await broker.get_orders(status=status, symbols=symbols)
|
||||
results[name] = orders
|
||||
except Exception:
|
||||
results[name] = []
|
||||
|
||||
return results
|
||||
|
||||
async def cancel_all_orders(
|
||||
self,
|
||||
broker_name: Optional[str] = None,
|
||||
symbols: Optional[List[str]] = None,
|
||||
) -> Dict[str, List[Order]]:
|
||||
"""Cancel all orders across one or all brokers.
|
||||
|
||||
Args:
|
||||
broker_name: Specific broker (all if None)
|
||||
symbols: Filter by symbols
|
||||
|
||||
Returns:
|
||||
Dict mapping broker name to list of cancelled orders
|
||||
"""
|
||||
results = {}
|
||||
|
||||
if broker_name:
|
||||
brokers = [(broker_name, self.get_broker(broker_name))]
|
||||
else:
|
||||
brokers = [
|
||||
(name, reg.broker)
|
||||
for name, reg in self._brokers.items()
|
||||
if reg.enabled and reg.broker.is_connected
|
||||
]
|
||||
|
||||
for name, broker in brokers:
|
||||
try:
|
||||
cancelled = await broker.cancel_all_orders(symbols=symbols)
|
||||
results[name] = cancelled
|
||||
except Exception:
|
||||
results[name] = []
|
||||
|
||||
return results
|
||||
|
||||
# ==========================================================================
|
||||
# Position Management
|
||||
# ==========================================================================
|
||||
|
||||
async def get_positions(
|
||||
self,
|
||||
broker_name: Optional[str] = None,
|
||||
) -> Dict[str, List[Position]]:
|
||||
"""Get positions from one or all brokers.
|
||||
|
||||
Args:
|
||||
broker_name: Specific broker (all if None)
|
||||
|
||||
Returns:
|
||||
Dict mapping broker name to list of positions
|
||||
"""
|
||||
results = {}
|
||||
|
||||
if broker_name:
|
||||
brokers = [(broker_name, self.get_broker(broker_name))]
|
||||
else:
|
||||
brokers = [
|
||||
(name, reg.broker)
|
||||
for name, reg in self._brokers.items()
|
||||
if reg.enabled and reg.broker.is_connected
|
||||
]
|
||||
|
||||
for name, broker in brokers:
|
||||
try:
|
||||
positions = await broker.get_positions()
|
||||
results[name] = positions
|
||||
except Exception:
|
||||
results[name] = []
|
||||
|
||||
return results
|
||||
|
||||
async def get_all_positions(self) -> List[Tuple[str, Position]]:
|
||||
"""Get aggregated positions across all brokers.
|
||||
|
||||
Returns:
|
||||
List of (broker_name, position) tuples
|
||||
"""
|
||||
all_positions = []
|
||||
positions_by_broker = await self.get_positions()
|
||||
|
||||
for broker_name, positions in positions_by_broker.items():
|
||||
for position in positions:
|
||||
all_positions.append((broker_name, position))
|
||||
|
||||
return all_positions
|
||||
|
||||
async def get_position(
|
||||
self,
|
||||
symbol: str,
|
||||
broker_name: Optional[str] = None,
|
||||
) -> Optional[Tuple[str, Position]]:
|
||||
"""Get position for a symbol.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
broker_name: Specific broker (searches all if None)
|
||||
|
||||
Returns:
|
||||
Tuple of (broker_name, position) or None
|
||||
"""
|
||||
if broker_name:
|
||||
broker = self.get_broker(broker_name)
|
||||
position = await broker.get_position(symbol)
|
||||
if position:
|
||||
return (broker_name, position)
|
||||
return None
|
||||
|
||||
# Search all brokers
|
||||
for name, reg in self._brokers.items():
|
||||
if reg.enabled and reg.broker.is_connected:
|
||||
try:
|
||||
position = await reg.broker.get_position(symbol)
|
||||
if position:
|
||||
return (name, position)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
async def close_position(
|
||||
self,
|
||||
symbol: str,
|
||||
broker_name: Optional[str] = None,
|
||||
quantity: Optional[Decimal] = None,
|
||||
) -> Order:
|
||||
"""Close a position.
|
||||
|
||||
Args:
|
||||
symbol: Symbol to close
|
||||
broker_name: Specific broker (searches if None)
|
||||
quantity: Quantity to close (full if None)
|
||||
|
||||
Returns:
|
||||
Closing order
|
||||
|
||||
Raises:
|
||||
PositionError: If position not found
|
||||
"""
|
||||
result = await self.get_position(symbol, broker_name)
|
||||
if result is None:
|
||||
raise PositionError(f"No position found for {symbol}")
|
||||
|
||||
actual_broker_name, position = result
|
||||
broker = self.get_broker(actual_broker_name)
|
||||
return await broker.close_position(symbol, quantity)
|
||||
|
||||
async def close_all_positions(
|
||||
self,
|
||||
broker_name: Optional[str] = None,
|
||||
) -> Dict[str, List[Order]]:
|
||||
"""Close all positions across one or all brokers.
|
||||
|
||||
Args:
|
||||
broker_name: Specific broker (all if None)
|
||||
|
||||
Returns:
|
||||
Dict mapping broker name to list of closing orders
|
||||
"""
|
||||
results = {}
|
||||
|
||||
if broker_name:
|
||||
brokers = [(broker_name, self.get_broker(broker_name))]
|
||||
else:
|
||||
brokers = [
|
||||
(name, reg.broker)
|
||||
for name, reg in self._brokers.items()
|
||||
if reg.enabled and reg.broker.is_connected
|
||||
]
|
||||
|
||||
for name, broker in brokers:
|
||||
try:
|
||||
orders = await broker.close_all_positions()
|
||||
results[name] = orders
|
||||
except Exception:
|
||||
results[name] = []
|
||||
|
||||
return results
|
||||
|
||||
# ==========================================================================
|
||||
# Account Information
|
||||
# ==========================================================================
|
||||
|
||||
async def get_accounts(self) -> Dict[str, AccountInfo]:
|
||||
"""Get account information from all brokers.
|
||||
|
||||
Returns:
|
||||
Dict mapping broker name to account info
|
||||
"""
|
||||
results = {}
|
||||
|
||||
for name, reg in self._brokers.items():
|
||||
if reg.enabled and reg.broker.is_connected:
|
||||
try:
|
||||
results[name] = await reg.broker.get_account()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return results
|
||||
|
||||
async def get_total_equity(self) -> Decimal:
|
||||
"""Get total equity across all brokers.
|
||||
|
||||
Returns:
|
||||
Total equity value
|
||||
"""
|
||||
total = Decimal("0")
|
||||
accounts = await self.get_accounts()
|
||||
|
||||
for account in accounts.values():
|
||||
total += account.equity
|
||||
|
||||
return total
|
||||
|
||||
async def get_total_buying_power(self) -> Decimal:
|
||||
"""Get total buying power across all brokers.
|
||||
|
||||
Returns:
|
||||
Total buying power
|
||||
"""
|
||||
total = Decimal("0")
|
||||
accounts = await self.get_accounts()
|
||||
|
||||
for account in accounts.values():
|
||||
total += account.buying_power
|
||||
|
||||
return total
|
||||
|
||||
# ==========================================================================
|
||||
# Market Data
|
||||
# ==========================================================================
|
||||
|
||||
async def get_quote(self, symbol: str) -> Quote:
|
||||
"""Get quote for a symbol from the appropriate broker.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
Quote object
|
||||
"""
|
||||
broker, _ = self.route(symbol)
|
||||
return await broker.get_quote(symbol)
|
||||
|
||||
async def get_quotes(self, symbols: List[str]) -> Dict[str, Quote]:
|
||||
"""Get quotes for multiple symbols.
|
||||
|
||||
Args:
|
||||
symbols: List of symbols
|
||||
|
||||
Returns:
|
||||
Dict mapping symbol to quote
|
||||
"""
|
||||
# Group symbols by broker
|
||||
broker_symbols: Dict[str, List[str]] = {}
|
||||
for symbol in symbols:
|
||||
broker, _ = self.route(symbol)
|
||||
if broker.name not in broker_symbols:
|
||||
broker_symbols[broker.name] = []
|
||||
broker_symbols[broker.name].append(symbol)
|
||||
|
||||
# Fetch from each broker
|
||||
results = {}
|
||||
for broker_name, syms in broker_symbols.items():
|
||||
broker = self.get_broker(broker_name)
|
||||
broker_quotes = await broker.get_quotes(syms)
|
||||
results.update(broker_quotes)
|
||||
|
||||
return results
|
||||
|
||||
async def get_asset(self, symbol: str) -> AssetInfo:
|
||||
"""Get asset information from the appropriate broker.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
AssetInfo object
|
||||
"""
|
||||
broker, _ = self.route(symbol)
|
||||
return await broker.get_asset(symbol)
|
||||
|
||||
# ==========================================================================
|
||||
# Utility Methods
|
||||
# ==========================================================================
|
||||
|
||||
def get_routing_history(
|
||||
self,
|
||||
limit: int = 100,
|
||||
symbol: Optional[str] = None,
|
||||
) -> List[RoutingDecision]:
|
||||
"""Get routing decision history.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of decisions to return
|
||||
symbol: Optional filter by symbol
|
||||
|
||||
Returns:
|
||||
List of routing decisions (newest first)
|
||||
"""
|
||||
history = self._routing_history[::-1] # Reverse for newest first
|
||||
|
||||
if symbol:
|
||||
history = [d for d in history if d.symbol == symbol]
|
||||
|
||||
return history[:limit]
|
||||
|
||||
def get_broker_status(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get status of all registered brokers.
|
||||
|
||||
Returns:
|
||||
Dict with broker status information
|
||||
"""
|
||||
status = {}
|
||||
|
||||
for name, reg in self._brokers.items():
|
||||
status[name] = {
|
||||
"connected": reg.broker.is_connected,
|
||||
"enabled": reg.enabled,
|
||||
"paper_trading": reg.broker.is_paper_trading,
|
||||
"asset_classes": [c.value for c in reg.asset_classes],
|
||||
"priority": reg.priority,
|
||||
"is_primary": reg.is_primary,
|
||||
"registered_at": reg.registered_at.isoformat(),
|
||||
}
|
||||
|
||||
return status
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation."""
|
||||
brokers = ", ".join(self._brokers.keys())
|
||||
return f"BrokerRouter(brokers=[{brokers}])"
|
||||
Loading…
Reference in New Issue