"""
Comprehensive tests for the base broker interface.

Tests order data structures, enumerations, convenience methods,
and abstract interface compliance.
"""

import inspect
from abc import ABC
from datetime import datetime
from decimal import Decimal

import pytest

from tradingagents.brokers.base import (
    BaseBroker,
    BrokerOrder,
    BrokerPosition,
    BrokerAccount,
    OrderSide,
    OrderType,
    OrderStatus,
    BrokerError,
    ConnectionError,
    OrderError,
    InsufficientFundsError,
)


class TestOrderEnumerations:
    """Check the string values behind each order-related enumeration."""

    def test_order_side_values(self):
        """OrderSide members carry the lowercase side names."""
        cases = (
            (OrderSide.BUY, "buy"),
            (OrderSide.SELL, "sell"),
        )
        for member, expected in cases:
            assert member.value == expected

    def test_order_type_values(self):
        """OrderType members carry the lowercase order-type names."""
        cases = (
            (OrderType.MARKET, "market"),
            (OrderType.LIMIT, "limit"),
            (OrderType.STOP, "stop"),
            (OrderType.STOP_LIMIT, "stop_limit"),
        )
        for member, expected in cases:
            assert member.value == expected

    def test_order_status_values(self):
        """OrderStatus members carry the snake_case lifecycle names."""
        cases = (
            (OrderStatus.PENDING, "pending"),
            (OrderStatus.SUBMITTED, "submitted"),
            (OrderStatus.FILLED, "filled"),
            (OrderStatus.PARTIALLY_FILLED, "partially_filled"),
            (OrderStatus.CANCELLED, "cancelled"),
            (OrderStatus.REJECTED, "rejected"),
        )
        for member, expected in cases:
            assert member.value == expected


class TestBrokerOrder:
    """Test construction and defaults of the BrokerOrder dataclass."""

    def test_create_market_buy_order(self):
        """A market buy order keeps its inputs and gets the expected defaults."""
        order = BrokerOrder(
            symbol="AAPL",
            side=OrderSide.BUY,
            quantity=Decimal("100"),
            order_type=OrderType.MARKET,
        )

        assert order.symbol == "AAPL"
        assert order.side == OrderSide.BUY
        assert order.quantity == Decimal("100")
        assert order.order_type == OrderType.MARKET
        # Defaults for the fields the caller did not supply.
        assert order.status == OrderStatus.PENDING
        assert order.time_in_force == "day"
        assert order.order_id is None
        assert order.filled_qty == Decimal("0")

    def test_create_limit_sell_order(self):
        """A limit sell order stores every field it was given, including limit_price."""
        order = BrokerOrder(
            symbol="TSLA",
            side=OrderSide.SELL,
            quantity=Decimal("50"),
            order_type=OrderType.LIMIT,
            limit_price=Decimal("250.50"),
        )

        assert order.symbol == "TSLA"
        assert order.side == OrderSide.SELL
        assert order.quantity == Decimal("50")
        assert order.order_type == OrderType.LIMIT
        assert order.limit_price == Decimal("250.50")

    def test_create_stop_loss_order(self):
        """A stop (stop-loss) order stores every field it was given, including stop_price."""
        order = BrokerOrder(
            symbol="NVDA",
            side=OrderSide.SELL,
            quantity=Decimal("25"),
            order_type=OrderType.STOP,
            stop_price=Decimal("800.00"),
        )

        assert order.symbol == "NVDA"
        assert order.side == OrderSide.SELL
        assert order.quantity == Decimal("25")
        assert order.order_type == OrderType.STOP
        assert order.stop_price == Decimal("800.00")

    def test_create_stop_limit_order(self):
        """A stop-limit order keeps both its stop and limit prices."""
        order = BrokerOrder(
            symbol="AMD",
            side=OrderSide.BUY,
            quantity=Decimal("100"),
            order_type=OrderType.STOP_LIMIT,
            stop_price=Decimal("140.00"),
            limit_price=Decimal("142.00"),
        )

        assert order.symbol == "AMD"
        assert order.side == OrderSide.BUY
        assert order.quantity == Decimal("100")
        assert order.order_type == OrderType.STOP_LIMIT
        assert order.stop_price == Decimal("140.00")
        assert order.limit_price == Decimal("142.00")

    def test_order_with_custom_time_in_force(self):
        """An explicit time_in_force overrides the "day" default."""
        order = BrokerOrder(
            symbol="AAPL",
            side=OrderSide.BUY,
            quantity=Decimal("100"),
            order_type=OrderType.MARKET,
            time_in_force="gtc",
        )

        assert order.time_in_force == "gtc"

    def test_order_with_filled_data(self):
        """Fill-related fields (id, status, qty, price, timestamp) round-trip unchanged."""
        # A fixed timestamp keeps the test deterministic; datetime.now() added
        # wall-clock dependence without testing anything extra.
        filled_at = datetime(2024, 1, 15, 14, 30, 0)
        order = BrokerOrder(
            symbol="AAPL",
            side=OrderSide.BUY,
            quantity=Decimal("100"),
            order_type=OrderType.MARKET,
            order_id="order-123",
            status=OrderStatus.FILLED,
            filled_qty=Decimal("100"),
            filled_price=Decimal("150.25"),
            filled_at=filled_at,
        )

        assert order.order_id == "order-123"
        assert order.status == OrderStatus.FILLED
        assert order.filled_qty == Decimal("100")
        assert order.filled_price == Decimal("150.25")
        assert order.filled_at == filled_at


class TestBrokerPosition:
    """Exercise construction of the BrokerPosition dataclass."""

    def test_create_position(self):
        """Every constructor argument can be read back unchanged."""
        expected_fields = {
            "symbol": "AAPL",
            "quantity": Decimal("100"),
            "avg_entry_price": Decimal("150.00"),
            "current_price": Decimal("155.00"),
            "market_value": Decimal("15500.00"),
            "unrealized_pnl": Decimal("500.00"),
            "unrealized_pnl_percent": Decimal("0.0333"),
            "cost_basis": Decimal("15000.00"),
        }
        position = BrokerPosition(**expected_fields)

        for field_name, expected in expected_fields.items():
            assert getattr(position, field_name) == expected

    def test_position_with_loss(self):
        """A losing position reports negative absolute and percentage P&L."""
        position = BrokerPosition(
            symbol="TSLA",
            quantity=Decimal("50"),
            avg_entry_price=Decimal("250.00"),
            current_price=Decimal("240.00"),
            market_value=Decimal("12000.00"),
            unrealized_pnl=Decimal("-500.00"),
            unrealized_pnl_percent=Decimal("-0.04"),
            cost_basis=Decimal("12500.00"),
        )

        for pnl in (position.unrealized_pnl, position.unrealized_pnl_percent):
            assert pnl < 0


class TestBrokerAccount:
    """Test construction and defaults of the BrokerAccount dataclass."""

    def test_create_account(self):
        """Every field passed to the constructor is stored unchanged."""
        account = BrokerAccount(
            account_number="ACC123456",
            cash=Decimal("50000.00"),
            buying_power=Decimal("200000.00"),
            portfolio_value=Decimal("75000.00"),
            equity=Decimal("75000.00"),
            last_equity=Decimal("74500.00"),
            multiplier=Decimal("4"),
            currency="USD",
            pattern_day_trader=False,
        )

        assert account.account_number == "ACC123456"
        assert account.cash == Decimal("50000.00")
        assert account.buying_power == Decimal("200000.00")
        assert account.portfolio_value == Decimal("75000.00")
        # These three were passed in but previously never verified.
        assert account.equity == Decimal("75000.00")
        assert account.last_equity == Decimal("74500.00")
        assert account.multiplier == Decimal("4")
        assert account.currency == "USD"
        assert account.pattern_day_trader is False

    def test_account_defaults(self):
        """Omitting currency and pattern_day_trader yields "USD" and False."""
        account = BrokerAccount(
            account_number="ACC123456",
            cash=Decimal("50000.00"),
            buying_power=Decimal("50000.00"),
            portfolio_value=Decimal("50000.00"),
            equity=Decimal("50000.00"),
            last_equity=Decimal("50000.00"),
            multiplier=Decimal("1"),
        )

        # Default values
        assert account.currency == "USD"
        assert account.pattern_day_trader is False

    def test_account_with_pdt_status(self):
        """An account flagged as a pattern day trader keeps that flag and multiplier."""
        account = BrokerAccount(
            account_number="ACC123456",
            cash=Decimal("30000.00"),
            buying_power=Decimal("120000.00"),
            portfolio_value=Decimal("50000.00"),
            equity=Decimal("50000.00"),
            last_equity=Decimal("49000.00"),
            multiplier=Decimal("4"),
            pattern_day_trader=True,
        )

        assert account.pattern_day_trader is True
        assert account.multiplier == Decimal("4")


class TestBrokerExceptions:
    """Exercise the broker exception classes and their common base."""

    def test_broker_error(self):
        """BrokerError can be raised and carries its message."""
        message = "Test error"
        with pytest.raises(BrokerError, match=message):
            raise BrokerError(message)

    def test_connection_error(self):
        """The broker ConnectionError (imported from the broker module, not the builtin) is a BrokerError."""
        message = "Connection failed"
        with pytest.raises(ConnectionError, match=message):
            raise ConnectionError(message)

        # Handlers catching BrokerError must also see connection failures.
        assert issubclass(ConnectionError, BrokerError)

    def test_order_error(self):
        """OrderError carries its message and is a BrokerError."""
        message = "Order failed"
        with pytest.raises(OrderError, match=message):
            raise OrderError(message)

        assert issubclass(OrderError, BrokerError)

    def test_insufficient_funds_error(self):
        """InsufficientFundsError carries its message and is a BrokerError."""
        message = "Insufficient funds"
        with pytest.raises(InsufficientFundsError, match=message):
            raise InsufficientFundsError(message)

        assert issubclass(InsufficientFundsError, BrokerError)


class TestBaseBrokerInterface:
    """Test the BaseBroker abstract interface."""

    def test_base_broker_is_abstract(self):
        """BaseBroker is an ABC with abstract methods and cannot be instantiated directly."""
        # issubclass (instead of inspecting __bases__) still holds if ABC is
        # inherited indirectly through an intermediate base class.
        assert issubclass(BaseBroker, ABC)
        assert inspect.isabstract(BaseBroker)
        # ABCMeta refuses to instantiate a class with unimplemented abstract
        # methods; this is the behavior the docstring actually promises.
        with pytest.raises(TypeError, match="abstract"):
            BaseBroker(paper_trading=True)

    def test_base_broker_paper_trading_flag(self):
        """BaseBroker stores the paper_trading flag passed to its constructor."""

        # Minimal concrete implementation: every interface method is stubbed.
        class ConcreteBroker(BaseBroker):
            def connect(self):
                return True

            def disconnect(self):
                pass

            def get_account(self):
                pass

            def get_positions(self):
                pass

            def get_position(self, symbol):
                pass

            def submit_order(self, order):
                pass

            def cancel_order(self, order_id):
                pass

            def get_order(self, order_id):
                pass

            def get_orders(self, status=None, limit=50):
                pass

            def get_current_price(self, symbol):
                pass

        broker = ConcreteBroker(paper_trading=True)
        assert broker.paper_trading is True

        broker = ConcreteBroker(paper_trading=False)
        assert broker.paper_trading is False


class TestBaseBrokerConvenienceMethods:
    """Test the buy/sell convenience methods provided by BaseBroker."""

    class MockBroker(BaseBroker):
        """Paper-trading broker that records submitted orders instead of sending them."""

        def __init__(self):
            super().__init__(paper_trading=True)
            self.submitted_orders = []

        def connect(self):
            return True

        def disconnect(self):
            pass

        def get_account(self):
            pass

        def get_positions(self):
            pass

        def get_position(self, symbol):
            pass

        def submit_order(self, order):
            # Record the order and mimic a broker acknowledgement: assign a
            # sequential id and mark it submitted.
            self.submitted_orders.append(order)
            order.order_id = f"order-{len(self.submitted_orders)}"
            order.status = OrderStatus.SUBMITTED
            return order

        def cancel_order(self, order_id):
            pass

        def get_order(self, order_id):
            pass

        def get_orders(self, status=None, limit=50):
            pass

        def get_current_price(self, symbol):
            pass

    @staticmethod
    def _assert_submitted_once(broker, order):
        """Check that the convenience call routed exactly this order through submit_order."""
        assert broker.submitted_orders == [order]
        # Both values are set by MockBroker.submit_order, so they prove the
        # returned order is the one that was acknowledged.
        assert order.order_id == "order-1"
        assert order.status == OrderStatus.SUBMITTED

    def test_buy_market_convenience(self):
        """buy_market submits a BUY MARKET day order."""
        broker = self.MockBroker()
        order = broker.buy_market("AAPL", Decimal("100"))

        assert order.symbol == "AAPL"
        assert order.side == OrderSide.BUY
        assert order.quantity == Decimal("100")
        assert order.order_type == OrderType.MARKET
        assert order.time_in_force == "day"
        self._assert_submitted_once(broker, order)

    def test_buy_market_custom_time_in_force(self):
        """buy_market forwards a custom time_in_force."""
        broker = self.MockBroker()
        order = broker.buy_market("AAPL", Decimal("100"), time_in_force="gtc")

        assert order.time_in_force == "gtc"
        self._assert_submitted_once(broker, order)

    def test_sell_market_convenience(self):
        """sell_market submits a SELL MARKET order."""
        broker = self.MockBroker()
        order = broker.sell_market("TSLA", Decimal("50"))

        assert order.symbol == "TSLA"
        assert order.side == OrderSide.SELL
        assert order.quantity == Decimal("50")
        assert order.order_type == OrderType.MARKET
        self._assert_submitted_once(broker, order)

    def test_buy_limit_convenience(self):
        """buy_limit submits a BUY LIMIT order at the given price."""
        broker = self.MockBroker()
        order = broker.buy_limit("NVDA", Decimal("25"), Decimal("850.00"))

        assert order.symbol == "NVDA"
        assert order.side == OrderSide.BUY
        assert order.quantity == Decimal("25")
        assert order.order_type == OrderType.LIMIT
        assert order.limit_price == Decimal("850.00")
        self._assert_submitted_once(broker, order)

    def test_sell_limit_convenience(self):
        """sell_limit submits a SELL LIMIT order at the given price."""
        broker = self.MockBroker()
        order = broker.sell_limit("AMD", Decimal("100"), Decimal("150.00"))

        assert order.symbol == "AMD"
        assert order.side == OrderSide.SELL
        assert order.quantity == Decimal("100")
        assert order.order_type == OrderType.LIMIT
        assert order.limit_price == Decimal("150.00")
        self._assert_submitted_once(broker, order)

    def test_buy_limit_with_gtc(self):
        """buy_limit forwards a GTC time_in_force alongside the limit price."""
        broker = self.MockBroker()
        order = broker.buy_limit(
            "AAPL",
            Decimal("100"),
            Decimal("145.00"),
            time_in_force="gtc",
        )

        assert order.time_in_force == "gtc"
        assert order.limit_price == Decimal("145.00")
        self._assert_submitted_once(broker, order)


@pytest.mark.parametrize(
    ("side", "expected"),
    [
        (OrderSide.BUY, "buy"),
        (OrderSide.SELL, "sell"),
    ],
)
def test_order_side_parametrized(side, expected):
    """Each OrderSide member exposes its lowercase string value."""
    actual = side.value
    assert actual == expected


@pytest.mark.parametrize(
    ("order_type", "expected"),
    [
        (OrderType.MARKET, "market"),
        (OrderType.LIMIT, "limit"),
        (OrderType.STOP, "stop"),
        (OrderType.STOP_LIMIT, "stop_limit"),
    ],
)
def test_order_type_parametrized(order_type, expected):
    """Each OrderType member exposes its lowercase string value."""
    actual = order_type.value
    assert actual == expected


@pytest.mark.parametrize(
    ("quantity", "price"),
    [
        (Decimal("1"), Decimal("100.00")),
        (Decimal("100"), Decimal("150.50")),
        (Decimal("1000"), Decimal("25.75")),
        (Decimal("0.5"), Decimal("1000.00")),  # Fractional shares
    ],
)
def test_order_with_various_quantities(quantity, price):
    """Limit orders keep the exact Decimal quantity and price, fractions included."""
    order = BrokerOrder(
        symbol="TEST",
        side=OrderSide.BUY,
        order_type=OrderType.LIMIT,
        quantity=quantity,
        limit_price=price,
    )

    assert (order.quantity, order.limit_price) == (quantity, price)