TradingAgents/tests/portfolio/test_position.py

230 lines
7.6 KiB
Python

"""
Tests for the Position class.
"""
import unittest
from decimal import Decimal
from datetime import datetime, timedelta
from tradingagents.portfolio import Position
from tradingagents.portfolio.exceptions import (
InvalidPositionError,
InvalidPriceError,
InvalidQuantityError,
)
class TestPosition(unittest.TestCase):
    """Unit tests for the ``Position`` class.

    Covers construction and input validation, valuation helpers (market
    value, total cost, unrealized P&L), in-place updates, stop-loss /
    take-profit triggers, and ``to_dict`` / ``from_dict`` round-tripping.
    """

    def _aapl_long(self, **extra):
        """Return the baseline long position used by most tests.

        100 shares of AAPL at a 150.00 cost basis. Any ``extra`` keyword
        arguments (e.g. ``stop_loss``, ``sector``) are forwarded verbatim
        to ``Position``.
        """
        return Position(
            ticker='AAPL',
            quantity=Decimal('100'),
            cost_basis=Decimal('150.00'),
            **extra,
        )

    def _tsla_short(self):
        """Return the baseline short position: -50 TSLA at 200.00."""
        return Position(
            ticker='TSLA',
            quantity=Decimal('-50'),
            cost_basis=Decimal('200.00'),
        )

    def test_create_long_position(self):
        """Test creating a long position."""
        position = self._aapl_long()
        self.assertEqual(position.ticker, 'AAPL')
        self.assertEqual(position.quantity, Decimal('100'))
        self.assertEqual(position.cost_basis, Decimal('150.00'))
        self.assertTrue(position.is_long)
        self.assertFalse(position.is_short)

    def test_create_short_position(self):
        """Test creating a short position (negative quantity)."""
        position = self._tsla_short()
        self.assertEqual(position.quantity, Decimal('-50'))
        self.assertFalse(position.is_long)
        self.assertTrue(position.is_short)

    def test_invalid_ticker(self):
        """Test that invalid tickers (e.g. path-traversal strings) are rejected."""
        with self.assertRaises(InvalidPositionError):
            Position(
                ticker='../etc/passwd',
                quantity=Decimal('100'),
                cost_basis=Decimal('150.00'),
            )

    def test_zero_quantity_rejected(self):
        """Test that zero quantity is rejected."""
        with self.assertRaises(InvalidQuantityError):
            Position(
                ticker='AAPL',
                quantity=Decimal('0'),
                cost_basis=Decimal('150.00'),
            )

    def test_negative_cost_basis_rejected(self):
        """Test that negative cost basis is rejected."""
        with self.assertRaises(InvalidPriceError):
            Position(
                ticker='AAPL',
                quantity=Decimal('100'),
                cost_basis=Decimal('-150.00'),
            )

    def test_market_value(self):
        """Test market value calculation: quantity * current price."""
        position = self._aapl_long()
        self.assertEqual(
            position.market_value(Decimal('160.00')), Decimal('16000.00')
        )

    def test_total_cost(self):
        """Test total cost calculation: quantity * cost basis."""
        position = self._aapl_long()
        self.assertEqual(position.total_cost(), Decimal('15000.00'))

    def test_unrealized_pnl_long_profit(self):
        """Test unrealized P&L for profitable long position."""
        position = self._aapl_long()
        pnl = position.unrealized_pnl(Decimal('160.00'))
        self.assertEqual(pnl, Decimal('1000.00'))  # (160 - 150) * 100

    def test_unrealized_pnl_long_loss(self):
        """Test unrealized P&L for losing long position."""
        position = self._aapl_long()
        pnl = position.unrealized_pnl(Decimal('140.00'))
        self.assertEqual(pnl, Decimal('-1000.00'))  # (140 - 150) * 100

    def test_unrealized_pnl_short_profit(self):
        """Test unrealized P&L for profitable short position."""
        position = self._tsla_short()
        pnl = position.unrealized_pnl(Decimal('180.00'))
        self.assertEqual(pnl, Decimal('1000.00'))  # (200 - 180) * 50

    def test_unrealized_pnl_percent(self):
        """Test unrealized P&L percentage calculation (as a fraction)."""
        position = self._aapl_long()
        pnl_pct = position.unrealized_pnl_percent(Decimal('165.00'))
        self.assertEqual(pnl_pct, Decimal('0.1'))  # 10% gain

    def test_update_quantity(self):
        """Test that update_quantity adds a delta to the current quantity."""
        position = self._aapl_long()
        position.update_quantity(Decimal('50'))
        self.assertEqual(position.quantity, Decimal('150'))

    def test_update_quantity_cannot_reach_zero(self):
        """Test that update_quantity cannot result in zero."""
        position = self._aapl_long()
        with self.assertRaises(InvalidQuantityError):
            position.update_quantity(Decimal('-100'))

    def test_update_cost_basis(self):
        """Test weighted average cost basis calculation."""
        position = self._aapl_long()
        # Add 50 shares at $160.
        position.update_cost_basis(Decimal('50'), Decimal('160.00'))
        # New cost basis should be (100*150 + 50*160) / 150 = 153.33...
        expected = (
            Decimal('100') * Decimal('150.00')
            + Decimal('50') * Decimal('160.00')
        ) / Decimal('150')
        # Compare as Decimals rather than casting to float: this keeps the
        # check in exact decimal arithmetic like the rest of the suite,
        # while places=2 still tolerates a cost basis quantized to cents.
        self.assertAlmostEqual(position.cost_basis, expected, places=2)

    def test_stop_loss_trigger_long(self):
        """Test stop-loss trigger for long position."""
        position = self._aapl_long(stop_loss=Decimal('145.00'))
        cases = (
            ('150.00', False),
            ('146.00', False),
            ('145.00', True),  # the stop level itself triggers
            ('140.00', True),
        )
        for price, should_fire in cases:
            with self.subTest(price=price):
                fired = position.should_trigger_stop_loss(Decimal(price))
                # Truthiness comparison, matching assertTrue/assertFalse.
                self.assertEqual(bool(fired), should_fire)

    def test_take_profit_trigger_long(self):
        """Test take-profit trigger for long position."""
        position = self._aapl_long(take_profit=Decimal('160.00'))
        cases = (
            ('150.00', False),
            ('159.00', False),
            ('160.00', True),  # the target level itself triggers
            ('165.00', True),
        )
        for price, should_fire in cases:
            with self.subTest(price=price):
                fired = position.should_trigger_take_profit(Decimal(price))
                # Truthiness comparison, matching assertTrue/assertFalse.
                self.assertEqual(bool(fired), should_fire)

    def test_to_dict_and_from_dict(self):
        """Test that to_dict / from_dict round-trips the position's fields."""
        position = self._aapl_long(
            sector='Technology',
            stop_loss=Decimal('145.00'),
            take_profit=Decimal('160.00'),
        )
        restored = Position.from_dict(position.to_dict())
        for attr in (
            'ticker',
            'quantity',
            'cost_basis',
            'sector',
            'stop_loss',
            'take_profit',
        ):
            with self.subTest(attr=attr):
                self.assertEqual(getattr(restored, attr), getattr(position, attr))
if __name__ == '__main__':
    # Run this module's tests when executed directly as a script.
    unittest.main()