# File-viewer header residue from the original listing: "230 lines, 7.6 KiB, Python".
"""
|
|
Tests for the Position class.
|
|
"""
|
|
|
|
import unittest
from datetime import datetime, timedelta
from decimal import Decimal

from tradingagents.portfolio import Position
from tradingagents.portfolio.exceptions import (
    InvalidPositionError,
    InvalidPriceError,
    InvalidQuantityError,
)


class TestPosition(unittest.TestCase):
    """Test cases for the Position class.

    The tests exercise construction and input validation, market value and
    unrealized P&L calculations, quantity and cost-basis updates, stop-loss
    and take-profit triggers, and the ``to_dict`` / ``from_dict`` round trip.
    """

    @staticmethod
    def _long_aapl(**extra):
        """Build the shared long fixture: 100 AAPL at a 150.00 cost basis.

        Args:
            **extra: Additional keyword arguments forwarded unchanged to
                ``Position`` (e.g. ``stop_loss``, ``take_profit``, ``sector``).

        Returns:
            A freshly constructed ``Position``.
        """
        return Position(
            ticker='AAPL',
            quantity=Decimal('100'),
            cost_basis=Decimal('150.00'),
            **extra
        )

    @staticmethod
    def _short_tsla():
        """Build the shared short fixture: -50 TSLA at a 200.00 cost basis."""
        return Position(
            ticker='TSLA',
            quantity=Decimal('-50'),
            cost_basis=Decimal('200.00'),
        )

    def test_create_long_position(self):
        """Test creating a long position."""
        position = self._long_aapl()

        self.assertEqual(position.ticker, 'AAPL')
        self.assertEqual(position.quantity, Decimal('100'))
        self.assertEqual(position.cost_basis, Decimal('150.00'))
        self.assertTrue(position.is_long)
        self.assertFalse(position.is_short)

    def test_create_short_position(self):
        """Test creating a short position (negative quantity)."""
        position = self._short_tsla()

        self.assertEqual(position.quantity, Decimal('-50'))
        self.assertFalse(position.is_long)
        self.assertTrue(position.is_short)

    def test_invalid_ticker(self):
        """Test that a path-like ticker raises InvalidPositionError."""
        with self.assertRaises(InvalidPositionError):
            Position(
                ticker='../etc/passwd',
                quantity=Decimal('100'),
                cost_basis=Decimal('150.00'),
            )

    def test_zero_quantity_rejected(self):
        """Test that a zero quantity raises InvalidQuantityError."""
        with self.assertRaises(InvalidQuantityError):
            Position(
                ticker='AAPL',
                quantity=Decimal('0'),
                cost_basis=Decimal('150.00'),
            )

    def test_negative_cost_basis_rejected(self):
        """Test that a negative cost basis raises InvalidPriceError."""
        with self.assertRaises(InvalidPriceError):
            Position(
                ticker='AAPL',
                quantity=Decimal('100'),
                cost_basis=Decimal('-150.00'),
            )

    def test_market_value(self):
        """Test market value calculation."""
        position = self._long_aapl()

        market_value = position.market_value(Decimal('160.00'))
        self.assertEqual(market_value, Decimal('16000.00'))  # 160 * 100

    def test_total_cost(self):
        """Test total cost calculation."""
        position = self._long_aapl()

        self.assertEqual(position.total_cost(), Decimal('15000.00'))  # 150 * 100

    def test_unrealized_pnl_long_profit(self):
        """Test unrealized P&L for profitable long position."""
        position = self._long_aapl()

        pnl = position.unrealized_pnl(Decimal('160.00'))
        self.assertEqual(pnl, Decimal('1000.00'))  # (160 - 150) * 100

    def test_unrealized_pnl_long_loss(self):
        """Test unrealized P&L for losing long position."""
        position = self._long_aapl()

        pnl = position.unrealized_pnl(Decimal('140.00'))
        self.assertEqual(pnl, Decimal('-1000.00'))  # (140 - 150) * 100

    def test_unrealized_pnl_short_profit(self):
        """Test unrealized P&L for profitable short position."""
        position = self._short_tsla()

        pnl = position.unrealized_pnl(Decimal('180.00'))
        self.assertEqual(pnl, Decimal('1000.00'))  # (200 - 180) * 50

    def test_unrealized_pnl_percent(self):
        """Test unrealized P&L percentage calculation."""
        position = self._long_aapl()

        pnl_pct = position.unrealized_pnl_percent(Decimal('165.00'))
        self.assertEqual(pnl_pct, Decimal('0.1'))  # 10% gain

    def test_update_quantity(self):
        """Test updating position quantity (100 updated by 50 gives 150)."""
        position = self._long_aapl()

        position.update_quantity(Decimal('50'))
        self.assertEqual(position.quantity, Decimal('150'))

    def test_update_quantity_cannot_reach_zero(self):
        """Test that update_quantity cannot result in zero."""
        position = self._long_aapl()

        with self.assertRaises(InvalidQuantityError):
            position.update_quantity(Decimal('-100'))

    def test_update_cost_basis(self):
        """Test weighted average cost basis calculation."""
        position = self._long_aapl()

        # Add 50 shares at $160
        position.update_cost_basis(Decimal('50'), Decimal('160.00'))

        # New cost basis should be (100*150 + 50*160) / 150 = 153.33...
        expected = (
            Decimal('100') * Decimal('150.00')
            + Decimal('50') * Decimal('160.00')
        ) / Decimal('150')
        # The quotient is a repeating decimal, so compare to two places only.
        self.assertAlmostEqual(
            float(position.cost_basis), float(expected), places=2
        )

    def test_stop_loss_trigger_long(self):
        """Test stop-loss trigger for long position."""
        position = self._long_aapl(stop_loss=Decimal('145.00'))

        # Above the stop: no trigger.
        self.assertFalse(position.should_trigger_stop_loss(Decimal('150.00')))
        self.assertFalse(position.should_trigger_stop_loss(Decimal('146.00')))
        # At or below the stop: trigger.
        self.assertTrue(position.should_trigger_stop_loss(Decimal('145.00')))
        self.assertTrue(position.should_trigger_stop_loss(Decimal('140.00')))

    def test_take_profit_trigger_long(self):
        """Test take-profit trigger for long position."""
        position = self._long_aapl(take_profit=Decimal('160.00'))

        # Below the target: no trigger.
        self.assertFalse(position.should_trigger_take_profit(Decimal('150.00')))
        self.assertFalse(position.should_trigger_take_profit(Decimal('159.00')))
        # At or above the target: trigger.
        self.assertTrue(position.should_trigger_take_profit(Decimal('160.00')))
        self.assertTrue(position.should_trigger_take_profit(Decimal('165.00')))

    def test_to_dict_and_from_dict(self):
        """Test serialization and deserialization."""
        position = self._long_aapl(
            sector='Technology',
            stop_loss=Decimal('145.00'),
            take_profit=Decimal('160.00'),
        )

        # Round-trip through the dict representation.
        restored = Position.from_dict(position.to_dict())

        # subTest reports every mismatching field instead of stopping at the
        # first failure.
        for field in (
            'ticker',
            'quantity',
            'cost_basis',
            'sector',
            'stop_loss',
            'take_profit',
        ):
            with self.subTest(field=field):
                self.assertEqual(
                    getattr(restored, field), getattr(position, field)
                )


if __name__ == '__main__':
    # Run the test cases in this module when the file is executed directly.
    unittest.main()