# TradingAgents/tests/brokers/test_base_broker.py
#
# 444 lines
# 15 KiB
# Python

"""
Comprehensive tests for base broker interface.
Tests order data structures, enumerations, convenience methods,
and abstract interface compliance.
"""
import pytest
from decimal import Decimal
from datetime import datetime
from abc import ABC
from tradingagents.brokers.base import (
BaseBroker,
BrokerOrder,
BrokerPosition,
BrokerAccount,
OrderSide,
OrderType,
OrderStatus,
BrokerError,
ConnectionError,
OrderError,
InsufficientFundsError,
)
class TestOrderEnumerations:
    """Check the string values carried by the order-related enums."""

    def test_order_side_values(self):
        """Each OrderSide member carries its lowercase side name."""
        expected_values = (
            (OrderSide.BUY, "buy"),
            (OrderSide.SELL, "sell"),
        )
        for member, expected in expected_values:
            assert member.value == expected

    def test_order_type_values(self):
        """Each OrderType member carries its lowercase type name."""
        expected_values = (
            (OrderType.MARKET, "market"),
            (OrderType.LIMIT, "limit"),
            (OrderType.STOP, "stop"),
            (OrderType.STOP_LIMIT, "stop_limit"),
        )
        for member, expected in expected_values:
            assert member.value == expected

    def test_order_status_values(self):
        """Each OrderStatus member carries its lowercase status name."""
        expected_values = (
            (OrderStatus.PENDING, "pending"),
            (OrderStatus.SUBMITTED, "submitted"),
            (OrderStatus.FILLED, "filled"),
            (OrderStatus.PARTIALLY_FILLED, "partially_filled"),
            (OrderStatus.CANCELLED, "cancelled"),
            (OrderStatus.REJECTED, "rejected"),
        )
        for member, expected in expected_values:
            assert member.value == expected
class TestBrokerOrder:
    """Exercise construction of BrokerOrder for each order type."""

    def test_create_market_buy_order(self):
        """A market buy keeps its inputs and reports the expected defaults."""
        shares = Decimal("100")
        market_buy = BrokerOrder(
            symbol="AAPL",
            side=OrderSide.BUY,
            quantity=shares,
            order_type=OrderType.MARKET,
        )

        # Values passed to the constructor.
        assert market_buy.symbol == "AAPL"
        assert market_buy.side == OrderSide.BUY
        assert market_buy.quantity == shares
        assert market_buy.order_type == OrderType.MARKET

        # Values that were not passed and therefore come from defaults.
        assert market_buy.status == OrderStatus.PENDING
        assert market_buy.time_in_force == "day"
        assert market_buy.order_id is None
        assert market_buy.filled_qty == Decimal("0")

    def test_create_limit_sell_order(self):
        """A limit sell records its symbol, side and limit price."""
        asking_price = Decimal("250.50")
        limit_sell = BrokerOrder(
            symbol="TSLA",
            side=OrderSide.SELL,
            quantity=Decimal("50"),
            order_type=OrderType.LIMIT,
            limit_price=asking_price,
        )

        assert limit_sell.symbol == "TSLA"
        assert limit_sell.side == OrderSide.SELL
        assert limit_sell.limit_price == asking_price

    def test_create_stop_loss_order(self):
        """A stop order records its stop price and order type."""
        trigger_price = Decimal("800.00")
        stop_loss = BrokerOrder(
            symbol="NVDA",
            side=OrderSide.SELL,
            quantity=Decimal("25"),
            order_type=OrderType.STOP,
            stop_price=trigger_price,
        )

        assert stop_loss.stop_price == trigger_price
        assert stop_loss.order_type == OrderType.STOP

    def test_create_stop_limit_order(self):
        """A stop-limit order records both its stop and limit prices."""
        trigger_price = Decimal("140.00")
        ceiling_price = Decimal("142.00")
        stop_limit = BrokerOrder(
            symbol="AMD",
            side=OrderSide.BUY,
            quantity=Decimal("100"),
            order_type=OrderType.STOP_LIMIT,
            stop_price=trigger_price,
            limit_price=ceiling_price,
        )

        assert stop_limit.stop_price == trigger_price
        assert stop_limit.limit_price == ceiling_price

    def test_order_with_custom_time_in_force(self):
        """An explicit time_in_force is stored on the order."""
        gtc_order = BrokerOrder(
            symbol="AAPL",
            side=OrderSide.BUY,
            quantity=Decimal("100"),
            order_type=OrderType.MARKET,
            time_in_force="gtc",
        )

        assert gtc_order.time_in_force == "gtc"

    def test_order_with_filled_data(self):
        """Fill-related fields passed at construction are stored unchanged."""
        fill_details = {
            "order_id": "order-123",
            "status": OrderStatus.FILLED,
            "filled_qty": Decimal("100"),
            "filled_price": Decimal("150.25"),
            "filled_at": datetime.now(),
        }
        filled_order = BrokerOrder(
            symbol="AAPL",
            side=OrderSide.BUY,
            quantity=Decimal("100"),
            order_type=OrderType.MARKET,
            **fill_details,
        )

        # Every fill field must read back exactly as it was supplied.
        for field_name, expected in fill_details.items():
            assert getattr(filled_order, field_name) == expected
class TestBrokerPosition:
    """Exercise construction of BrokerPosition."""

    def test_create_position(self):
        """Every field passed to BrokerPosition reads back unchanged."""
        position_fields = {
            "symbol": "AAPL",
            "quantity": Decimal("100"),
            "avg_entry_price": Decimal("150.00"),
            "current_price": Decimal("155.00"),
            "market_value": Decimal("15500.00"),
            "unrealized_pnl": Decimal("500.00"),
            "unrealized_pnl_percent": Decimal("0.0333"),
            "cost_basis": Decimal("15000.00"),
        }
        holding = BrokerPosition(**position_fields)

        for field_name, expected in position_fields.items():
            assert getattr(holding, field_name) == expected

    def test_position_with_loss(self):
        """A position built with negative P&L figures reports them as negative."""
        losing_holding = BrokerPosition(
            symbol="TSLA",
            quantity=Decimal("50"),
            avg_entry_price=Decimal("250.00"),
            current_price=Decimal("240.00"),
            market_value=Decimal("12000.00"),
            unrealized_pnl=Decimal("-500.00"),
            unrealized_pnl_percent=Decimal("-0.04"),
            cost_basis=Decimal("12500.00"),
        )

        assert losing_holding.unrealized_pnl < 0
        assert losing_holding.unrealized_pnl_percent < 0
class TestBrokerAccount:
    """Exercise construction of BrokerAccount."""

    def test_create_account(self):
        """An account built with every field supplied reads back those values."""
        cash_balance = Decimal("50000.00")
        purchasing_power = Decimal("200000.00")
        total_value = Decimal("75000.00")
        margin_account = BrokerAccount(
            account_number="ACC123456",
            cash=cash_balance,
            buying_power=purchasing_power,
            portfolio_value=total_value,
            equity=Decimal("75000.00"),
            last_equity=Decimal("74500.00"),
            multiplier=Decimal("4"),
            currency="USD",
            pattern_day_trader=False,
        )

        assert margin_account.account_number == "ACC123456"
        assert margin_account.cash == cash_balance
        assert margin_account.buying_power == purchasing_power
        assert margin_account.portfolio_value == total_value
        assert margin_account.currency == "USD"
        assert margin_account.pattern_day_trader is False

    def test_account_defaults(self):
        """Omitting currency and pattern_day_trader yields "USD" and False."""
        flat_balance = Decimal("50000.00")
        cash_account = BrokerAccount(
            account_number="ACC123456",
            cash=flat_balance,
            buying_power=flat_balance,
            portfolio_value=flat_balance,
            equity=flat_balance,
            last_equity=flat_balance,
            multiplier=Decimal("1"),
        )

        # Neither field was passed above, so these are the defaults.
        assert cash_account.currency == "USD"
        assert cash_account.pattern_day_trader is False

    def test_account_with_pdt_status(self):
        """The pattern-day-trader flag and multiplier are stored as given."""
        leverage = Decimal("4")
        pdt_account = BrokerAccount(
            account_number="ACC123456",
            cash=Decimal("30000.00"),
            buying_power=Decimal("120000.00"),
            portfolio_value=Decimal("50000.00"),
            equity=Decimal("50000.00"),
            last_equity=Decimal("49000.00"),
            multiplier=leverage,
            pattern_day_trader=True,
        )

        assert pdt_account.pattern_day_trader is True
        assert pdt_account.multiplier == leverage
class TestBrokerExceptions:
    """Check the broker exception classes and how they relate to BrokerError."""

    def test_broker_error(self):
        """BrokerError can be raised and carries its message."""
        message = "Test error"
        with pytest.raises(BrokerError, match=message):
            raise BrokerError(message)

    def test_connection_error(self):
        """ConnectionError carries its message and is caught as BrokerError."""
        message = "Connection failed"

        with pytest.raises(ConnectionError, match=message):
            raise ConnectionError(message)

        # Catching the base class must also catch the specific error.
        with pytest.raises(BrokerError):
            raise ConnectionError(message)

    def test_order_error(self):
        """OrderError carries its message and is caught as BrokerError."""
        message = "Order failed"

        with pytest.raises(OrderError, match=message):
            raise OrderError(message)

        # Catching the base class must also catch the specific error.
        with pytest.raises(BrokerError):
            raise OrderError(message)

    def test_insufficient_funds_error(self):
        """InsufficientFundsError carries its message and is caught as BrokerError."""
        message = "Insufficient funds"

        with pytest.raises(InsufficientFundsError, match=message):
            raise InsufficientFundsError(message)

        # Catching the base class must also catch the specific error.
        with pytest.raises(BrokerError):
            raise InsufficientFundsError(message)
class TestBaseBrokerInterface:
    """Test BaseBroker abstract interface."""

    def test_base_broker_is_abstract(self):
        """Test that BaseBroker cannot be instantiated directly."""
        # issubclass also accepts ABC arriving through an intermediate base,
        # which a direct ``__bases__`` membership check would reject.
        assert issubclass(BaseBroker, ABC)

        # Deriving from ABC does not by itself block instantiation; only
        # unimplemented abstract methods do. Attempt the construction so the
        # test verifies what its name promises.
        # NOTE(review): assumes BaseBroker declares at least one
        # @abstractmethod (the concrete subclass below overrides ten
        # methods) -- confirm against tradingagents.brokers.base.
        with pytest.raises(TypeError, match="abstract"):
            BaseBroker(paper_trading=True)

    def test_base_broker_paper_trading_flag(self):
        """Test that BaseBroker stores paper_trading flag."""

        # Minimal concrete implementation so the base __init__ can run.
        class ConcreteBroker(BaseBroker):
            def connect(self):
                return True

            def disconnect(self):
                pass

            def get_account(self):
                pass

            def get_positions(self):
                pass

            def get_position(self, symbol):
                pass

            def submit_order(self, order):
                pass

            def cancel_order(self, order_id):
                pass

            def get_order(self, order_id):
                pass

            def get_orders(self, status=None, limit=50):
                pass

            def get_current_price(self, symbol):
                pass

        broker = ConcreteBroker(paper_trading=True)
        assert broker.paper_trading is True

        broker = ConcreteBroker(paper_trading=False)
        assert broker.paper_trading is False
class TestBaseBrokerConvenienceMethods:
    """Exercise the buy/sell helper methods that BaseBroker provides."""

    class MockBroker(BaseBroker):
        """Broker stand-in that records every order passed to submit_order."""

        def __init__(self):
            super().__init__(paper_trading=True)
            self.submitted_orders = []

        def connect(self):
            return True

        def disconnect(self):
            return None

        def get_account(self):
            return None

        def get_positions(self):
            return None

        def get_position(self, symbol):
            return None

        def submit_order(self, order):
            # Record the order, then stamp it with a 1-based sequential id
            # and mark it as submitted before handing it back.
            self.submitted_orders.append(order)
            sequence_number = len(self.submitted_orders)
            order.order_id = f"order-{sequence_number}"
            order.status = OrderStatus.SUBMITTED
            return order

        def cancel_order(self, order_id):
            return None

        def get_order(self, order_id):
            return None

        def get_orders(self, status=None, limit=50):
            return None

        def get_current_price(self, symbol):
            return None

    def test_buy_market_convenience(self):
        """buy_market submits one day-valid market buy for the given quantity."""
        mock = self.MockBroker()
        shares = Decimal("100")

        placed = mock.buy_market("AAPL", shares)

        assert placed.symbol == "AAPL"
        assert placed.side == OrderSide.BUY
        assert placed.quantity == shares
        assert placed.order_type == OrderType.MARKET
        assert placed.time_in_force == "day"
        assert len(mock.submitted_orders) == 1

    def test_buy_market_custom_time_in_force(self):
        """buy_market forwards an explicit time_in_force to the order."""
        mock = self.MockBroker()

        placed = mock.buy_market("AAPL", Decimal("100"), time_in_force="gtc")

        assert placed.time_in_force == "gtc"

    def test_sell_market_convenience(self):
        """sell_market builds a market sell for the given symbol and quantity."""
        mock = self.MockBroker()
        shares = Decimal("50")

        placed = mock.sell_market("TSLA", shares)

        assert placed.symbol == "TSLA"
        assert placed.side == OrderSide.SELL
        assert placed.quantity == shares
        assert placed.order_type == OrderType.MARKET

    def test_buy_limit_convenience(self):
        """buy_limit builds a limit buy carrying the given limit price."""
        mock = self.MockBroker()
        shares = Decimal("25")
        bid = Decimal("850.00")

        placed = mock.buy_limit("NVDA", shares, bid)

        assert placed.symbol == "NVDA"
        assert placed.side == OrderSide.BUY
        assert placed.quantity == shares
        assert placed.order_type == OrderType.LIMIT
        assert placed.limit_price == bid

    def test_sell_limit_convenience(self):
        """sell_limit builds a limit sell carrying the given limit price."""
        mock = self.MockBroker()
        shares = Decimal("100")
        ask = Decimal("150.00")

        placed = mock.sell_limit("AMD", shares, ask)

        assert placed.symbol == "AMD"
        assert placed.side == OrderSide.SELL
        assert placed.quantity == shares
        assert placed.order_type == OrderType.LIMIT
        assert placed.limit_price == ask

    def test_buy_limit_with_gtc(self):
        """buy_limit forwards both the limit price and a custom time_in_force."""
        mock = self.MockBroker()
        bid = Decimal("145.00")

        placed = mock.buy_limit("AAPL", Decimal("100"), bid, time_in_force="gtc")

        assert placed.time_in_force == "gtc"
        assert placed.limit_price == bid
@pytest.mark.parametrize(
    ("side", "expected"),
    [
        (OrderSide.BUY, "buy"),
        (OrderSide.SELL, "sell"),
    ],
)
def test_order_side_parametrized(side, expected):
    """Check one OrderSide member against its expected string value."""
    actual = side.value
    assert actual == expected
@pytest.mark.parametrize(
    ("order_type", "expected"),
    [
        (OrderType.MARKET, "market"),
        (OrderType.LIMIT, "limit"),
        (OrderType.STOP, "stop"),
        (OrderType.STOP_LIMIT, "stop_limit"),
    ],
)
def test_order_type_parametrized(order_type, expected):
    """Check one OrderType member against its expected string value."""
    actual = order_type.value
    assert actual == expected
@pytest.mark.parametrize(
    ("quantity", "price"),
    [
        (Decimal("1"), Decimal("100.00")),
        (Decimal("100"), Decimal("150.50")),
        (Decimal("1000"), Decimal("25.75")),
        (Decimal("0.5"), Decimal("1000.00")),  # Fractional shares
    ],
)
def test_order_with_various_quantities(quantity, price):
    """Build a limit buy for each quantity/price pair and read both back."""
    limit_buy = BrokerOrder(
        symbol="TEST",
        side=OrderSide.BUY,
        order_type=OrderType.LIMIT,
        quantity=quantity,
        limit_price=price,
    )

    assert (limit_buy.quantity, limit_buy.limit_price) == (quantity, price)