TradingAgents/tradingagents/portfolio/position.py

398 lines
13 KiB
Python

"""
Position management for the portfolio system.
This module provides the Position class for tracking individual security
positions including quantity, cost basis, market value, and P&L.
"""
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from typing import Optional, Dict, Any
import logging
from tradingagents.security import validate_ticker
from .exceptions import (
InvalidPositionError,
InvalidPriceError,
InvalidQuantityError,
ValidationError,
)
logger = logging.getLogger(__name__)
@dataclass
class Position:
"""
Represents a position in a single security.
A position tracks ownership of a specific security, including quantity,
cost basis, and provides calculations for market value and P&L.
Attributes:
ticker: The security ticker symbol
quantity: Number of shares/units owned (can be negative for short positions)
cost_basis: Average cost per share/unit
sector: Optional sector classification
opened_at: Timestamp when position was first opened
last_updated: Timestamp of last position update
stop_loss: Optional stop-loss price
take_profit: Optional take-profit price
metadata: Optional additional metadata
"""
ticker: str
quantity: Decimal
cost_basis: Decimal
sector: Optional[str] = None
opened_at: datetime = field(default_factory=datetime.now)
last_updated: datetime = field(default_factory=datetime.now)
stop_loss: Optional[Decimal] = None
take_profit: Optional[Decimal] = None
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
"""Validate position data after initialization."""
# Validate ticker
try:
self.ticker = validate_ticker(self.ticker)
except ValueError as e:
raise InvalidPositionError(f"Invalid ticker: {e}")
# Convert to Decimal if needed
if not isinstance(self.quantity, Decimal):
try:
self.quantity = Decimal(str(self.quantity))
except (ValueError, TypeError) as e:
raise InvalidQuantityError(f"Invalid quantity: {e}")
if not isinstance(self.cost_basis, Decimal):
try:
self.cost_basis = Decimal(str(self.cost_basis))
except (ValueError, TypeError) as e:
raise InvalidPriceError(f"Invalid cost basis: {e}")
# Validate quantity is not zero
if self.quantity == 0:
raise InvalidQuantityError("Position quantity cannot be zero")
# Validate cost basis is positive
if self.cost_basis <= 0:
raise InvalidPriceError("Cost basis must be positive")
# Convert optional Decimal fields
if self.stop_loss is not None and not isinstance(self.stop_loss, Decimal):
self.stop_loss = Decimal(str(self.stop_loss))
if self.take_profit is not None and not isinstance(self.take_profit, Decimal):
self.take_profit = Decimal(str(self.take_profit))
# Validate stop loss and take profit
if self.stop_loss is not None and self.stop_loss <= 0:
raise InvalidPriceError("Stop loss must be positive")
if self.take_profit is not None and self.take_profit <= 0:
raise InvalidPriceError("Take profit must be positive")
logger.info(
f"Created position: {self.ticker} "
f"quantity={self.quantity} cost_basis={self.cost_basis}"
)
@property
def is_long(self) -> bool:
"""Check if this is a long position."""
return self.quantity > 0
@property
def is_short(self) -> bool:
"""Check if this is a short position."""
return self.quantity < 0
def market_value(self, current_price: Decimal) -> Decimal:
"""
Calculate the current market value of the position.
Args:
current_price: Current market price of the security
Returns:
Market value of the position
Raises:
InvalidPriceError: If current_price is invalid
"""
if not isinstance(current_price, Decimal):
try:
current_price = Decimal(str(current_price))
except (ValueError, TypeError) as e:
raise InvalidPriceError(f"Invalid current price: {e}")
if current_price <= 0:
raise InvalidPriceError("Current price must be positive")
return self.quantity * current_price
def total_cost(self) -> Decimal:
"""
Calculate the total cost of the position.
Returns:
Total cost (quantity * cost_basis)
"""
return abs(self.quantity) * self.cost_basis
def unrealized_pnl(self, current_price: Decimal) -> Decimal:
"""
Calculate unrealized profit/loss.
For long positions: (current_price - cost_basis) * quantity
For short positions: (cost_basis - current_price) * abs(quantity)
Args:
current_price: Current market price of the security
Returns:
Unrealized profit (positive) or loss (negative)
Raises:
InvalidPriceError: If current_price is invalid
"""
if not isinstance(current_price, Decimal):
try:
current_price = Decimal(str(current_price))
except (ValueError, TypeError) as e:
raise InvalidPriceError(f"Invalid current price: {e}")
if current_price <= 0:
raise InvalidPriceError("Current price must be positive")
if self.is_long:
return (current_price - self.cost_basis) * self.quantity
else:
# For short positions
return (self.cost_basis - current_price) * abs(self.quantity)
def unrealized_pnl_percent(self, current_price: Decimal) -> Decimal:
"""
Calculate unrealized P&L as a percentage of cost basis.
Args:
current_price: Current market price of the security
Returns:
Unrealized P&L as a percentage (e.g., 0.15 for 15% gain)
Raises:
InvalidPriceError: If current_price is invalid
"""
if not isinstance(current_price, Decimal):
try:
current_price = Decimal(str(current_price))
except (ValueError, TypeError) as e:
raise InvalidPriceError(f"Invalid current price: {e}")
if current_price <= 0:
raise InvalidPriceError("Current price must be positive")
total_cost = self.total_cost()
if total_cost == 0:
return Decimal('0')
pnl = self.unrealized_pnl(current_price)
return pnl / total_cost
def update_quantity(self, quantity_delta: Decimal) -> None:
"""
Update the position quantity and cost basis.
This method handles adding to or reducing a position, including
proper cost basis calculation.
Args:
quantity_delta: Change in quantity (positive to add, negative to reduce)
Raises:
InvalidQuantityError: If the resulting quantity would be zero
"""
if not isinstance(quantity_delta, Decimal):
try:
quantity_delta = Decimal(str(quantity_delta))
except (ValueError, TypeError) as e:
raise InvalidQuantityError(f"Invalid quantity delta: {e}")
new_quantity = self.quantity + quantity_delta
if new_quantity == 0:
raise InvalidQuantityError(
"Quantity delta would result in zero position. "
"Use close_position instead."
)
# Check if we're reversing the position (going from long to short or vice versa)
if (self.is_long and new_quantity < 0) or (self.is_short and new_quantity > 0):
raise InvalidQuantityError(
"Cannot reverse position direction. Close position first."
)
self.quantity = new_quantity
self.last_updated = datetime.now()
logger.info(
f"Updated position {self.ticker}: "
f"delta={quantity_delta} new_quantity={self.quantity}"
)
def update_cost_basis(
self,
quantity_delta: Decimal,
price: Decimal
) -> None:
"""
Update cost basis when adding to a position.
Uses weighted average cost basis calculation.
Args:
quantity_delta: Additional quantity being added
price: Price of the additional shares
Raises:
InvalidQuantityError: If quantity_delta is invalid
InvalidPriceError: If price is invalid
"""
if not isinstance(quantity_delta, Decimal):
try:
quantity_delta = Decimal(str(quantity_delta))
except (ValueError, TypeError) as e:
raise InvalidQuantityError(f"Invalid quantity delta: {e}")
if not isinstance(price, Decimal):
try:
price = Decimal(str(price))
except (ValueError, TypeError) as e:
raise InvalidPriceError(f"Invalid price: {e}")
if price <= 0:
raise InvalidPriceError("Price must be positive")
# Only update cost basis when adding to the position
if (self.is_long and quantity_delta > 0) or (self.is_short and quantity_delta < 0):
current_value = abs(self.quantity) * self.cost_basis
new_value = abs(quantity_delta) * price
new_total_quantity = abs(self.quantity) + abs(quantity_delta)
self.cost_basis = (current_value + new_value) / new_total_quantity
logger.debug(
f"Updated cost basis for {self.ticker}: "
f"new_cost_basis={self.cost_basis}"
)
def should_trigger_stop_loss(self, current_price: Decimal) -> bool:
"""
Check if stop loss should be triggered.
Args:
current_price: Current market price
Returns:
True if stop loss should be triggered, False otherwise
"""
if self.stop_loss is None:
return False
if not isinstance(current_price, Decimal):
try:
current_price = Decimal(str(current_price))
except (ValueError, TypeError):
return False
if self.is_long:
return current_price <= self.stop_loss
else:
# For short positions, stop loss is triggered when price goes up
return current_price >= self.stop_loss
def should_trigger_take_profit(self, current_price: Decimal) -> bool:
"""
Check if take profit should be triggered.
Args:
current_price: Current market price
Returns:
True if take profit should be triggered, False otherwise
"""
if self.take_profit is None:
return False
if not isinstance(current_price, Decimal):
try:
current_price = Decimal(str(current_price))
except (ValueError, TypeError):
return False
if self.is_long:
return current_price >= self.take_profit
else:
# For short positions, take profit is triggered when price goes down
return current_price <= self.take_profit
def to_dict(self) -> Dict[str, Any]:
"""
Convert position to dictionary for serialization.
Returns:
Dictionary representation of the position
"""
return {
'ticker': self.ticker,
'quantity': str(self.quantity),
'cost_basis': str(self.cost_basis),
'sector': self.sector,
'opened_at': self.opened_at.isoformat(),
'last_updated': self.last_updated.isoformat(),
'stop_loss': str(self.stop_loss) if self.stop_loss else None,
'take_profit': str(self.take_profit) if self.take_profit else None,
'metadata': self.metadata,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Position':
"""
Create a Position from a dictionary.
Args:
data: Dictionary containing position data
Returns:
Position instance
Raises:
ValidationError: If data is invalid
"""
try:
return cls(
ticker=data['ticker'],
quantity=Decimal(data['quantity']),
cost_basis=Decimal(data['cost_basis']),
sector=data.get('sector'),
opened_at=datetime.fromisoformat(data['opened_at']),
last_updated=datetime.fromisoformat(data['last_updated']),
stop_loss=Decimal(data['stop_loss']) if data.get('stop_loss') else None,
take_profit=Decimal(data['take_profit']) if data.get('take_profit') else None,
metadata=data.get('metadata', {}),
)
except (KeyError, ValueError, TypeError) as e:
raise ValidationError(f"Invalid position data: {e}")
def __repr__(self) -> str:
"""String representation of the position."""
position_type = "LONG" if self.is_long else "SHORT"
return (
f"Position({self.ticker}, {position_type}, "
f"qty={self.quantity}, cost={self.cost_basis})"
)