# 968 lines, 32 KiB, Python
"""Tests for Broker Router module.

Issue #23: [EXEC-22] Broker router - route by asset class
"""
|
|
|
|
import pytest
|
|
from datetime import datetime
|
|
from decimal import Decimal
|
|
from typing import Dict, List, Optional
|
|
|
|
from tradingagents.execution import (
|
|
# Enums
|
|
AssetClass,
|
|
OrderSide,
|
|
OrderType,
|
|
TimeInForce,
|
|
OrderStatus,
|
|
PositionSide,
|
|
# Data Classes
|
|
OrderRequest,
|
|
Order,
|
|
Position,
|
|
AccountInfo,
|
|
Quote,
|
|
AssetInfo,
|
|
# Exceptions
|
|
BrokerError,
|
|
OrderError,
|
|
PositionError,
|
|
# Base Class
|
|
BrokerBase,
|
|
# Router
|
|
BrokerRouter,
|
|
BrokerRegistration,
|
|
RoutingDecision,
|
|
SymbolClassifier,
|
|
RoutingError,
|
|
NoBrokerError,
|
|
BrokerNotFoundError,
|
|
DuplicateBrokerError,
|
|
)
|
|
|
|
|
|
# =============================================================================
|
|
# Mock Broker for Testing
|
|
# =============================================================================
|
|
|
|
|
|
class MockBroker(BrokerBase):
    """In-memory broker stand-in for exercising the router in tests.

    Orders and positions are kept in plain dictionaries and every method
    returns canned or locally stored data. Broker order ids have the form
    ``"<name>-ORD-<n>"`` with ``n`` taken from a per-instance counter.

    NOTE(review): ``self._name`` and ``self._connected`` are not assigned in
    this class; presumably ``BrokerBase.__init__`` provides them -- confirm.
    """

    def __init__(
        self,
        name: str,
        asset_classes: List[AssetClass],
        paper_trading: bool = True,
    ):
        """Initialise the base broker and the empty in-memory stores.

        Args:
            name: Broker name, forwarded to ``BrokerBase``.
            asset_classes: Forwarded to ``BrokerBase`` as
                ``supported_asset_classes``.
            paper_trading: Forwarded to ``BrokerBase``.
        """
        super().__init__(
            name=name,
            supported_asset_classes=asset_classes,
            paper_trading=paper_trading,
        )
        self._orders: Dict[str, Order] = {}
        self._positions: Dict[str, Position] = {}
        self._order_counter = 0

    async def connect(self) -> bool:
        """Mark the broker as connected and report success."""
        self._connected = True
        return True

    async def disconnect(self) -> None:
        """Mark the broker as disconnected."""
        self._connected = False

    async def is_market_open(self) -> bool:
        """Always report the market as open."""
        return True

    async def get_account(self) -> AccountInfo:
        """Return a fixed account snapshot.

        The equity (150000) and buying power (200000) values are relied on
        by the router's aggregate-account tests.
        """
        return AccountInfo(
            account_id=f"{self._name}_account",
            account_type="margin",
            status="active",
            cash=Decimal("100000"),
            portfolio_value=Decimal("150000"),
            buying_power=Decimal("200000"),
            equity=Decimal("150000"),
        )

    async def submit_order(self, request: OrderRequest) -> Order:
        """Store and return a new order built from ``request``.

        Args:
            request: Order request whose fields are copied onto the order.

        Returns:
            The stored order, with status ``OrderStatus.NEW``.
        """
        self._order_counter += 1
        order = Order(
            broker_order_id=f"{self._name}-ORD-{self._order_counter}",
            client_order_id=request.client_order_id,
            symbol=request.symbol,
            side=request.side,
            quantity=request.quantity,
            order_type=request.order_type,
            status=OrderStatus.NEW,
            limit_price=request.limit_price,
            stop_price=request.stop_price,
            time_in_force=request.time_in_force,
            created_at=datetime.now(),
        )
        self._orders[order.broker_order_id] = order
        return order

    async def cancel_order(self, order_id: str) -> Order:
        """Mark a stored order as cancelled and return it.

        Raises:
            OrderError: If ``order_id`` is not a known order.
        """
        if order_id not in self._orders:
            raise OrderError(f"Order {order_id} not found")
        order = self._orders[order_id]
        order.status = OrderStatus.CANCELLED
        return order

    async def replace_order(
        self,
        order_id: str,
        quantity: Optional[Decimal] = None,
        limit_price: Optional[Decimal] = None,
        stop_price: Optional[Decimal] = None,
        time_in_force: Optional[TimeInForce] = None,
    ) -> Order:
        """Create a new order that copies an existing one with overrides.

        The original order is left in the store untouched; the replacement
        is stored under a freshly generated broker order id.

        Args:
            order_id: Id of the order to copy from.
            quantity: New quantity, or ``None`` to keep the old value.
            limit_price: New limit price, or ``None`` to keep the old value.
            stop_price: New stop price, or ``None`` to keep the old value.
            time_in_force: New time in force, or ``None`` to keep the old
                value.

        Returns:
            The newly stored replacement order, with status
            ``OrderStatus.NEW``.

        Raises:
            OrderError: If ``order_id`` is not a known order.
        """
        old_order = self._orders.get(order_id)
        if old_order is None:
            raise OrderError(f"Order {order_id} not found")

        self._order_counter += 1
        # Overrides are compared against None explicitly so that falsy but
        # explicit values (for example Decimal("0")) are not silently
        # replaced by the old order's value.
        new_order = Order(
            broker_order_id=f"{self._name}-ORD-{self._order_counter}",
            client_order_id=old_order.client_order_id,
            symbol=old_order.symbol,
            side=old_order.side,
            quantity=old_order.quantity if quantity is None else quantity,
            order_type=old_order.order_type,
            status=OrderStatus.NEW,
            limit_price=(
                old_order.limit_price if limit_price is None else limit_price
            ),
            stop_price=(
                old_order.stop_price if stop_price is None else stop_price
            ),
            time_in_force=(
                old_order.time_in_force
                if time_in_force is None
                else time_in_force
            ),
        )
        self._orders[new_order.broker_order_id] = new_order
        return new_order

    async def get_order(self, order_id: str) -> Order:
        """Return the stored order with the given id.

        Raises:
            OrderError: If ``order_id`` is not a known order.
        """
        if order_id not in self._orders:
            raise OrderError(f"Order {order_id} not found")
        return self._orders[order_id]

    async def get_orders(
        self,
        status: Optional[OrderStatus] = None,
        limit: int = 100,
        symbols: Optional[List[str]] = None,
    ) -> List[Order]:
        """Return stored orders, optionally filtered, capped at ``limit``.

        Args:
            status: When given, keep only orders with this status.
            limit: Maximum number of orders returned (applied as a slice).
            symbols: When non-empty, keep only orders for these symbols.
                An empty list is treated the same as ``None`` (no filter).
        """
        orders = list(self._orders.values())
        if status is not None:
            orders = [o for o in orders if o.status == status]
        if symbols:
            orders = [o for o in orders if o.symbol in symbols]
        return orders[:limit]

    async def get_positions(self) -> List[Position]:
        """Return all stored positions."""
        return list(self._positions.values())

    async def get_position(self, symbol: str) -> Optional[Position]:
        """Return the stored position for ``symbol``, or ``None``."""
        return self._positions.get(symbol)

    async def get_quote(self, symbol: str) -> Quote:
        """Return a fixed quote for ``symbol``."""
        return Quote(
            symbol=symbol,
            bid_price=Decimal("100.00"),
            ask_price=Decimal("100.05"),
            last_price=Decimal("100.02"),
            volume=1000000,
        )

    async def get_asset(self, symbol: str) -> AssetInfo:
        """Return tradable asset info named ``"<symbol> Asset"``."""
        return AssetInfo(
            symbol=symbol,
            name=f"{symbol} Asset",
            tradable=True,
        )

    def add_position(self, position: Position) -> None:
        """Test helper: store ``position`` keyed by its symbol."""
        self._positions[position.symbol] = position
|
|
|
|
|
|
# =============================================================================
|
|
# SymbolClassifier Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestSymbolClassifier:
    """Checks of SymbolClassifier's symbol-to-asset-class mapping."""

    def test_classify_equity(self):
        """Plain stock tickers are classified as equities."""
        classifier = SymbolClassifier()

        for ticker in ("AAPL", "MSFT", "GOOGL"):
            assert classifier.classify(ticker) == AssetClass.EQUITY

    def test_classify_etf(self):
        """Well-known fund tickers are classified as ETFs."""
        classifier = SymbolClassifier()

        for ticker in ("SPY", "QQQ", "VTI"):
            assert classifier.classify(ticker) == AssetClass.ETF

    def test_classify_crypto(self):
        """Crypto pairs and bare coin symbols are classified as crypto."""
        classifier = SymbolClassifier()

        for ticker in ("BTCUSD", "ETHUSD", "BTC"):
            assert classifier.classify(ticker) == AssetClass.CRYPTO

    def test_classify_future(self):
        """Root + month-code + year symbols are classified as futures."""
        classifier = SymbolClassifier()

        contracts = (
            "ESZ24",  # S&P 500 Dec 2024
            "CLF25",  # Crude Oil Jan 2025
            "GCG24",  # Gold Feb 2024
        )
        for contract in contracts:
            assert classifier.classify(contract) == AssetClass.FUTURE

    def test_custom_mapping(self):
        """A symbol added via add_mapping classifies as the mapped class."""
        classifier = SymbolClassifier()

        classifier.add_mapping("CUSTOM", AssetClass.BOND)

        mapped_class = classifier.classify("CUSTOM")
        assert mapped_class == AssetClass.BOND

    def test_custom_mapping_overrides_default(self):
        """A custom mapping wins over the built-in classification."""
        classifier = SymbolClassifier()

        # Built-in behaviour first: SPY is an ETF.
        default_class = classifier.classify("SPY")
        assert default_class == AssetClass.ETF

        # Remap it, then clear the cache so the new mapping is picked up.
        classifier.add_mapping("SPY", AssetClass.EQUITY)
        classifier.clear_cache()

        overridden_class = classifier.classify("SPY")
        assert overridden_class == AssetClass.EQUITY

    def test_cache(self):
        """Classifying the same symbol twice gives the same answer."""
        classifier = SymbolClassifier()

        first = classifier.classify("AAPL")  # initial classification
        second = classifier.classify("AAPL")  # repeat lookup

        assert first == second
        assert second == AssetClass.EQUITY

    def test_clear_cache(self):
        """Classification keeps working after the cache is cleared."""
        classifier = SymbolClassifier()

        # Populate the cache, then drop it.
        classifier.classify("AAPL")
        classifier.clear_cache()

        after_clear = classifier.classify("AAPL")
        assert after_clear == AssetClass.EQUITY
|
|
|
|
|
|
# =============================================================================
|
|
# BrokerRegistration Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestBrokerRegistration:
    """Checks of the BrokerRegistration dataclass."""

    def test_create_registration(self):
        """Explicitly supplied fields are stored on the registration."""
        mock = MockBroker("TestBroker", [AssetClass.EQUITY])
        classes = {AssetClass.EQUITY, AssetClass.ETF}

        registration = BrokerRegistration(
            broker=mock,
            asset_classes=classes,
            priority=10,
            is_primary=True,
        )

        assert registration.broker == mock
        assert AssetClass.EQUITY in registration.asset_classes
        assert registration.priority == 10
        assert registration.is_primary is True
        assert registration.enabled is True

    def test_default_values(self):
        """Fields that are not supplied take their default values."""
        mock = MockBroker("TestBroker", [AssetClass.EQUITY])

        registration = BrokerRegistration(
            broker=mock,
            asset_classes={AssetClass.EQUITY},
        )

        assert registration.priority == 0
        assert registration.is_primary is False
        assert registration.enabled is True
        assert registration.registered_at is not None
|
|
|
|
|
|
# =============================================================================
|
|
# RoutingDecision Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestRoutingDecision:
    """Checks of the RoutingDecision dataclass."""

    def test_create_decision(self):
        """Constructor arguments are readable back as attributes."""
        fields = {
            "symbol": "AAPL",
            "asset_class": AssetClass.EQUITY,
            "broker_name": "Alpaca",
            "reason": "Primary broker for equity",
            "alternatives": ["IBKR"],
        }

        routing = RoutingDecision(**fields)

        assert routing.symbol == "AAPL"
        assert routing.asset_class == AssetClass.EQUITY
        assert routing.broker_name == "Alpaca"
        assert "IBKR" in routing.alternatives
|
|
|
|
|
|
# =============================================================================
|
|
# BrokerRouter Tests - Registration
|
|
# =============================================================================
|
|
|
|
|
|
class TestBrokerRouterRegistration:
    """Checks of registering, unregistering and looking up brokers."""

    def test_register_broker(self):
        """A registered broker appears by name with its asset classes."""
        router = BrokerRouter()
        alpaca = MockBroker("Alpaca", [AssetClass.EQUITY, AssetClass.ETF])

        router.register(alpaca)

        assert "Alpaca" in router.registered_brokers
        assert AssetClass.EQUITY in router.supported_asset_classes

    def test_register_with_specific_classes(self):
        """Registration can be limited to a subset of a broker's classes."""
        router = BrokerRouter()
        broker_classes = [AssetClass.EQUITY, AssetClass.ETF, AssetClass.CRYPTO]
        alpaca = MockBroker("Alpaca", broker_classes)

        # Register for equity only, although the broker supports more.
        router.register(alpaca, asset_classes=[AssetClass.EQUITY])

        assert AssetClass.EQUITY in router.supported_asset_classes
        # ETF and CRYPTO not registered even though broker supports them
        assert AssetClass.CRYPTO not in router.supported_asset_classes

    def test_register_duplicate_raises(self):
        """Registering the same broker twice raises DuplicateBrokerError."""
        router = BrokerRouter()
        alpaca = MockBroker("Alpaca", [AssetClass.EQUITY])
        router.register(alpaca)

        expected = pytest.raises(DuplicateBrokerError, match="already registered")
        with expected:
            router.register(alpaca)

    def test_unregister_broker(self):
        """An unregistered broker no longer appears in the router."""
        router = BrokerRouter()
        router.register(MockBroker("Alpaca", [AssetClass.EQUITY]))

        router.unregister("Alpaca")

        assert "Alpaca" not in router.registered_brokers

    def test_unregister_nonexistent_raises(self):
        """Unregistering an unknown name raises BrokerNotFoundError."""
        router = BrokerRouter()

        expected = pytest.raises(BrokerNotFoundError, match="not registered")
        with expected:
            router.unregister("NonExistent")

    def test_get_broker(self):
        """get_broker returns the broker registered under that name."""
        router = BrokerRouter()
        alpaca = MockBroker("Alpaca", [AssetClass.EQUITY])
        router.register(alpaca)

        found = router.get_broker("Alpaca")

        assert found == alpaca

    def test_get_broker_not_found(self):
        """get_broker raises BrokerNotFoundError for an unknown name."""
        router = BrokerRouter()

        with pytest.raises(BrokerNotFoundError):
            router.get_broker("NonExistent")
|
|
|
|
|
|
# =============================================================================
|
|
# BrokerRouter Tests - Routing
|
|
# =============================================================================
|
|
|
|
|
|
class TestBrokerRouterRouting:
    """Checks of how the router chooses a broker for a symbol."""

    @pytest.fixture
    def router_with_brokers(self):
        """Router with an equity/ETF, a crypto and a futures/option broker."""
        router = BrokerRouter()

        broker_specs = (
            ("EquityBroker", (AssetClass.EQUITY, AssetClass.ETF)),
            ("CryptoBroker", (AssetClass.CRYPTO,)),
            ("FuturesBroker", (AssetClass.FUTURE, AssetClass.OPTION)),
        )
        # Build every broker first, then register each for exactly the
        # asset classes it was created with (separate list objects).
        brokers = [
            (MockBroker(broker_name, list(classes)), list(classes))
            for broker_name, classes in broker_specs
        ]
        for mock, classes in brokers:
            router.register(mock, classes)

        return router

    def test_route_equity(self, router_with_brokers):
        """An equity symbol is routed to the equity broker."""
        chosen, routing = router_with_brokers.route("AAPL")

        assert chosen.name == "EquityBroker"
        assert routing.asset_class == AssetClass.EQUITY

    def test_route_etf(self, router_with_brokers):
        """An ETF symbol is routed to the equity broker."""
        chosen, routing = router_with_brokers.route("SPY")

        assert chosen.name == "EquityBroker"
        assert routing.asset_class == AssetClass.ETF

    def test_route_crypto(self, router_with_brokers):
        """A crypto symbol is routed to the crypto broker."""
        chosen, routing = router_with_brokers.route("BTCUSD")

        assert chosen.name == "CryptoBroker"
        assert routing.asset_class == AssetClass.CRYPTO

    def test_route_futures(self, router_with_brokers):
        """A futures symbol is routed to the futures broker."""
        chosen, routing = router_with_brokers.route("ESZ24")

        assert chosen.name == "FuturesBroker"
        assert routing.asset_class == AssetClass.FUTURE

    def test_route_no_broker_raises(self):
        """Routing on an empty router raises NoBrokerError."""
        empty_router = BrokerRouter()

        expected = pytest.raises(NoBrokerError, match="No broker available")
        with expected:
            empty_router.route("AAPL")

    def test_fallback_broker(self):
        """The fallback broker handles an asset class nobody registered for."""
        router = BrokerRouter()
        router.register(MockBroker("EquityBroker", [AssetClass.EQUITY]))
        router.set_fallback("EquityBroker")

        # Map a symbol to FOREX, which has no registered broker, so routing
        # has to go through the fallback.
        router.add_symbol_mapping("EURUSD", AssetClass.FOREX)
        chosen, routing = router.route("EURUSD")

        assert chosen.name == "EquityBroker"
        assert "Fallback" in routing.reason

    def test_disabled_broker_skipped(self, router_with_brokers):
        """A disabled broker is not considered for routing."""
        router = router_with_brokers
        router.disable_broker("EquityBroker")

        with pytest.raises(NoBrokerError):
            router.route("AAPL")

    def test_enable_broker(self, router_with_brokers):
        """Re-enabling a disabled broker makes it routable again."""
        router = router_with_brokers
        router.disable_broker("EquityBroker")
        router.enable_broker("EquityBroker")

        routed = router.route("AAPL")
        chosen = routed[0]

        assert chosen.name == "EquityBroker"

    def test_priority_routing(self):
        """Of two eligible brokers, the higher priority one is chosen."""
        router = BrokerRouter()

        low = MockBroker("LowPriority", [AssetClass.EQUITY])
        high = MockBroker("HighPriority", [AssetClass.EQUITY])
        router.register(low, priority=1)
        router.register(high, priority=10)

        chosen, _decision = router.route("AAPL")

        assert chosen.name == "HighPriority"

    def test_primary_broker_preference(self):
        """A primary broker beats a higher-priority non-primary one."""
        router = BrokerRouter()

        non_primary = MockBroker("Secondary", [AssetClass.EQUITY])
        preferred = MockBroker("Primary", [AssetClass.EQUITY])
        router.register(non_primary, priority=10)
        router.register(preferred, priority=1, primary=True)

        chosen, _decision = router.route("AAPL")

        assert chosen.name == "Primary"

    def test_routing_history(self, router_with_brokers):
        """Routing decisions are recorded, newest first."""
        router = router_with_brokers
        for ticker in ("AAPL", "SPY", "BTCUSD"):
            router.route(ticker)

        history = router.get_routing_history()

        assert len(history) == 3
        # Most recent first
        newest_first = ("BTCUSD", "SPY", "AAPL")
        for position, expected_symbol in enumerate(newest_first):
            assert history[position].symbol == expected_symbol

    def test_routing_history_filter(self, router_with_brokers):
        """History can be restricted to a single symbol."""
        router = router_with_brokers
        for ticker in ("AAPL", "SPY", "AAPL"):
            router.route(ticker)

        history = router.get_routing_history(symbol="AAPL")

        assert len(history) == 2
        recorded_symbols = [entry.symbol for entry in history]
        assert recorded_symbols == ["AAPL"] * len(recorded_symbols)
|
|
|
|
|
|
# =============================================================================
|
|
# BrokerRouter Tests - Order Management
|
|
# =============================================================================
|
|
|
|
|
|
class TestBrokerRouterOrders:
    """Checks of order submission, lookup and cancellation via the router."""

    async def _create_connected_router(self):
        """Return a connected router with an equity/ETF and a crypto broker."""
        router = BrokerRouter()

        mocks = (
            MockBroker("EquityBroker", [AssetClass.EQUITY, AssetClass.ETF]),
            MockBroker("CryptoBroker", [AssetClass.CRYPTO]),
        )
        for mock in mocks:
            router.register(mock)

        await router.connect_all()
        return router

    @pytest.mark.asyncio
    async def test_submit_order_auto_route(self):
        """An order without an explicit broker is routed by its symbol."""
        router = await self._create_connected_router()
        buy_aapl = OrderRequest.market("AAPL", OrderSide.BUY, 100)

        placed = await router.submit_order(buy_aapl)

        assert placed.symbol == "AAPL"
        assert "EquityBroker" in placed.broker_order_id

    @pytest.mark.asyncio
    async def test_submit_order_specific_broker(self):
        """An order can be sent to a broker chosen by name."""
        router = await self._create_connected_router()
        buy_aapl = OrderRequest.market("AAPL", OrderSide.BUY, 100)

        placed = await router.submit_order(buy_aapl, broker_name="EquityBroker")

        assert "EquityBroker" in placed.broker_order_id

    @pytest.mark.asyncio
    async def test_cancel_order(self):
        """Cancelling a submitted order returns it with CANCELLED status."""
        router = await self._create_connected_router()
        buy_aapl = OrderRequest.market("AAPL", OrderSide.BUY, 100)
        placed = await router.submit_order(buy_aapl)
        placed_id = placed.broker_order_id

        result = await router.cancel_order(placed_id, broker_name="EquityBroker")

        assert result.status == OrderStatus.CANCELLED

    @pytest.mark.asyncio
    async def test_get_orders_single_broker(self):
        """Asking for one broker's orders returns only that broker's orders."""
        router = await self._create_connected_router()
        buy_aapl = OrderRequest.market("AAPL", OrderSide.BUY, 100)
        await router.submit_order(buy_aapl)
        buy_btc = OrderRequest.market("BTCUSD", OrderSide.BUY, 1)
        await router.submit_order(buy_btc)

        by_broker = await router.get_orders(broker_name="EquityBroker")

        assert "EquityBroker" in by_broker
        assert len(by_broker["EquityBroker"]) == 1

    @pytest.mark.asyncio
    async def test_get_orders_all_brokers(self):
        """Without a broker name, orders come back keyed by every broker."""
        router = await self._create_connected_router()
        buy_aapl = OrderRequest.market("AAPL", OrderSide.BUY, 100)
        await router.submit_order(buy_aapl)
        buy_btc = OrderRequest.market("BTCUSD", OrderSide.BUY, 1)
        await router.submit_order(buy_btc)

        by_broker = await router.get_orders()

        assert len(by_broker) == 2  # Two brokers
        for broker_name in ("EquityBroker", "CryptoBroker"):
            assert broker_name in by_broker

    @pytest.mark.asyncio
    async def test_cancel_all_orders(self):
        """cancel_all_orders reports the cancelled orders per broker."""
        router = await self._create_connected_router()
        buy_aapl = OrderRequest.market("AAPL", OrderSide.BUY, 100)
        await router.submit_order(buy_aapl)
        buy_msft = OrderRequest.market("MSFT", OrderSide.BUY, 50)
        await router.submit_order(buy_msft)

        cancelled_by_broker = await router.cancel_all_orders()

        assert "EquityBroker" in cancelled_by_broker
        assert len(cancelled_by_broker["EquityBroker"]) == 2
|
|
|
|
|
|
# =============================================================================
|
|
# BrokerRouter Tests - Position Management
|
|
# =============================================================================
|
|
|
|
|
|
class TestBrokerRouterPositions:
    """Checks of position lookup and closing via the router."""

    async def _create_router_with_positions(self):
        """Return a connected router whose two brokers each hold a position.

        The equity broker holds 100 AAPL and the crypto broker holds 1 BTCUSD,
        both long.
        """
        router = BrokerRouter()

        stocks = MockBroker("EquityBroker", [AssetClass.EQUITY])
        coins = MockBroker("CryptoBroker", [AssetClass.CRYPTO])
        router.register(stocks)
        router.register(coins)

        await router.connect_all()

        # Seed one long position per broker.
        aapl_position = Position(
            symbol="AAPL",
            quantity=Decimal("100"),
            side=PositionSide.LONG,
            avg_entry_price=Decimal("150"),
            current_price=Decimal("160"),
            market_value=Decimal("16000"),
            cost_basis=Decimal("15000"),
            unrealized_pnl=Decimal("1000"),
            unrealized_pnl_percent=Decimal("6.67"),
        )
        stocks.add_position(aapl_position)

        btc_position = Position(
            symbol="BTCUSD",
            quantity=Decimal("1"),
            side=PositionSide.LONG,
            avg_entry_price=Decimal("40000"),
            current_price=Decimal("45000"),
            market_value=Decimal("45000"),
            cost_basis=Decimal("40000"),
            unrealized_pnl=Decimal("5000"),
            unrealized_pnl_percent=Decimal("12.5"),
        )
        coins.add_position(btc_position)

        return router

    @pytest.mark.asyncio
    async def test_get_positions_all(self):
        """get_positions returns an entry for each broker."""
        router = await self._create_router_with_positions()

        by_broker = await router.get_positions()

        assert len(by_broker) == 2
        for broker_name in ("EquityBroker", "CryptoBroker"):
            assert broker_name in by_broker

    @pytest.mark.asyncio
    async def test_get_all_positions(self):
        """get_all_positions returns both brokers' positions together."""
        router = await self._create_router_with_positions()

        entries = await router.get_all_positions()

        assert len(entries) == 2
        # Each entry carries the position as its second element.
        held_symbols = set()
        for entry in entries:
            held_symbols.add(entry[1].symbol)
        assert "AAPL" in held_symbols
        assert "BTCUSD" in held_symbols

    @pytest.mark.asyncio
    async def test_get_position(self):
        """get_position returns the owning broker's name and the position."""
        router = await self._create_router_with_positions()

        found = await router.get_position("AAPL")

        assert found is not None
        owner, holding = found
        assert owner == "EquityBroker"
        assert holding.symbol == "AAPL"

    @pytest.mark.asyncio
    async def test_get_position_not_found(self):
        """get_position returns None for a symbol nobody holds."""
        router = await self._create_router_with_positions()

        found = await router.get_position("NONEXISTENT")

        assert found is None

    @pytest.mark.asyncio
    async def test_close_position(self):
        """Closing a long position produces a sell for the full quantity."""
        router = await self._create_router_with_positions()

        closing_order = await router.close_position("AAPL")

        assert closing_order.symbol == "AAPL"
        assert closing_order.side == OrderSide.SELL
        assert closing_order.quantity == Decimal("100")

    @pytest.mark.asyncio
    async def test_close_position_not_found(self):
        """Closing a position that does not exist raises PositionError."""
        router = await self._create_router_with_positions()

        expected = pytest.raises(PositionError, match="No position found")
        with expected:
            await router.close_position("NONEXISTENT")
|
|
|
|
|
|
# =============================================================================
|
|
# BrokerRouter Tests - Account Management
|
|
# =============================================================================
|
|
|
|
|
|
class TestBrokerRouterAccounts:
    """Checks of account aggregation across brokers."""

    async def _create_connected_router(self):
        """Return a connected router with one equity and one crypto broker."""
        router = BrokerRouter()

        mocks = (
            MockBroker("Broker1", [AssetClass.EQUITY]),
            MockBroker("Broker2", [AssetClass.CRYPTO]),
        )
        for mock in mocks:
            router.register(mock)

        await router.connect_all()
        return router

    @pytest.mark.asyncio
    async def test_get_accounts(self):
        """get_accounts returns one account entry per broker."""
        router = await self._create_connected_router()

        by_broker = await router.get_accounts()

        assert len(by_broker) == 2
        for broker_name in ("Broker1", "Broker2"):
            assert broker_name in by_broker

    @pytest.mark.asyncio
    async def test_get_total_equity(self):
        """Total equity is the sum over both brokers."""
        router = await self._create_connected_router()

        total = await router.get_total_equity()

        # Two mock brokers, 150000 equity each.
        expected_total = Decimal("300000")
        assert total == expected_total

    @pytest.mark.asyncio
    async def test_get_total_buying_power(self):
        """Total buying power is the sum over both brokers."""
        router = await self._create_connected_router()

        total = await router.get_total_buying_power()

        # Two mock brokers, 200000 buying power each.
        expected_total = Decimal("400000")
        assert total == expected_total
|
|
|
|
|
|
# =============================================================================
|
|
# BrokerRouter Tests - Market Data
|
|
# =============================================================================
|
|
|
|
|
|
class TestBrokerRouterMarketData:
    """Checks of quote and asset lookups via the router."""

    async def _create_connected_router(self):
        """Return a connected router with one equity and one crypto broker."""
        router = BrokerRouter()

        mocks = (
            MockBroker("EquityBroker", [AssetClass.EQUITY]),
            MockBroker("CryptoBroker", [AssetClass.CRYPTO]),
        )
        for mock in mocks:
            router.register(mock)

        await router.connect_all()
        return router

    @pytest.mark.asyncio
    async def test_get_quote(self):
        """A single quote comes back for the requested symbol."""
        router = await self._create_connected_router()

        aapl_quote = await router.get_quote("AAPL")

        assert aapl_quote.symbol == "AAPL"
        assert aapl_quote.bid_price is not None

    @pytest.mark.asyncio
    async def test_get_quotes(self):
        """Quotes for several symbols come back keyed by symbol."""
        router = await self._create_connected_router()
        wanted = ["AAPL", "BTCUSD"]

        by_symbol = await router.get_quotes(wanted)

        assert len(by_symbol) == 2
        assert "AAPL" in by_symbol
        assert "BTCUSD" in by_symbol

    @pytest.mark.asyncio
    async def test_get_asset(self):
        """Asset info comes back for the requested symbol."""
        router = await self._create_connected_router()

        aapl_info = await router.get_asset("AAPL")

        assert aapl_info.symbol == "AAPL"
|
|
|
|
|
|
# =============================================================================
|
|
# BrokerRouter Tests - Connection Management
|
|
# =============================================================================
|
|
|
|
|
|
class TestBrokerRouterConnection:
    """Checks of connecting, disconnecting and market-status queries."""

    @pytest.mark.asyncio
    async def test_connect_all(self):
        """connect_all connects each broker and reports success by name."""
        router = BrokerRouter()

        first = MockBroker("Broker1", [AssetClass.EQUITY])
        second = MockBroker("Broker2", [AssetClass.CRYPTO])
        router.register(first)
        router.register(second)

        outcome = await router.connect_all()

        assert outcome["Broker1"] is True
        assert outcome["Broker2"] is True
        assert first.is_connected
        assert second.is_connected

    @pytest.mark.asyncio
    async def test_disconnect_all(self):
        """disconnect_all leaves a previously connected broker disconnected."""
        router = BrokerRouter()

        mock = MockBroker("Broker", [AssetClass.EQUITY])
        router.register(mock)

        await router.connect_all()
        await router.disconnect_all()

        assert not mock.is_connected

    @pytest.mark.asyncio
    async def test_is_market_open(self):
        """The router reports the market as open with a connected mock."""
        router = BrokerRouter()

        mock = MockBroker("Broker", [AssetClass.EQUITY])
        router.register(mock)
        await router.connect_all()

        market_open = await router.is_market_open()

        assert market_open is True
|
|
|
|
|
|
# =============================================================================
|
|
# BrokerRouter Tests - Status
|
|
# =============================================================================
|
|
|
|
|
|
class TestBrokerRouterStatus:
    """Checks of the router's status report and repr."""

    def test_get_broker_status(self):
        """The status report reflects a broker's registration settings."""
        router = BrokerRouter()

        mock = MockBroker("TestBroker", [AssetClass.EQUITY, AssetClass.ETF])
        router.register(mock, priority=5, primary=True)

        report = router.get_broker_status()

        assert "TestBroker" in report
        entry = report["TestBroker"]
        # The broker was registered but never connected.
        assert entry["connected"] is False
        assert entry["enabled"] is True
        assert entry["priority"] == 5
        assert entry["is_primary"] is True
        assert "equity" in entry["asset_classes"]

    def test_repr(self):
        """repr names the router class and every registered broker."""
        router = BrokerRouter()

        mocks = (
            MockBroker("Broker1", [AssetClass.EQUITY]),
            MockBroker("Broker2", [AssetClass.CRYPTO]),
        )
        for mock in mocks:
            router.register(mock)

        text = repr(router)

        for fragment in ("BrokerRouter", "Broker1", "Broker2"):
            assert fragment in text
|
|
|
|
|
|
# =============================================================================
|
|
# Router Exception Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestRouterExceptions:
    """Checks of the router exception hierarchy."""

    def test_routing_error(self):
        """RoutingError instances are BrokerError instances."""
        raised = RoutingError("Routing failed")

        is_broker_error = isinstance(raised, BrokerError)
        assert is_broker_error

    def test_no_broker_error(self):
        """NoBrokerError instances are RoutingError instances."""
        raised = NoBrokerError("No broker for asset class")

        is_routing_error = isinstance(raised, RoutingError)
        assert is_routing_error

    def test_broker_not_found_error(self):
        """BrokerNotFoundError instances are RoutingError instances."""
        raised = BrokerNotFoundError("Broker not found")

        is_routing_error = isinstance(raised, RoutingError)
        assert is_routing_error

    def test_duplicate_broker_error(self):
        """DuplicateBrokerError instances are RoutingError instances."""
        raised = DuplicateBrokerError("Broker already registered")

        is_routing_error = isinstance(raised, RoutingError)
        assert is_routing_error
|