""" Core portfolio management for the TradingAgents framework. This module provides the main Portfolio class for managing positions, executing orders, tracking P&L, and calculating risk metrics. """ from dataclasses import dataclass, field from datetime import datetime from decimal import Decimal from typing import Dict, List, Optional, Tuple, Any import threading import logging from tradingagents.security import validate_ticker from .position import Position from .orders import ( Order, MarketOrder, LimitOrder, StopLossOrder, TakeProfitOrder, OrderStatus, create_order_from_dict ) from .risk import RiskManager, RiskLimits from .analytics import PerformanceAnalytics, TradeRecord, PerformanceMetrics from .persistence import PortfolioPersistence from .exceptions import ( InsufficientFundsError, InsufficientSharesError, InvalidOrderError, PositionNotFoundError, RiskLimitExceededError, ValidationError, PersistenceError, ) logger = logging.getLogger(__name__) class Portfolio: """ Main portfolio management class. This class manages a portfolio of positions, handles order execution, tracks cash and P&L, enforces risk limits, and provides performance analytics. Thread-safe for concurrent operations. Attributes: initial_capital: Initial portfolio capital cash: Current cash balance positions: Dictionary of current positions (ticker -> Position) commission_rate: Commission rate as a fraction (e.g., 0.001 for 0.1%) risk_manager: Risk management component analytics: Performance analytics component persistence: Persistence component """ def __init__( self, initial_capital: Decimal, commission_rate: Decimal = Decimal('0.001'), risk_limits: Optional[RiskLimits] = None, persist_dir: Optional[str] = None ): """ Initialize a new portfolio. Args: initial_capital: Starting capital commission_rate: Commission rate as a fraction (default 0.1%) risk_limits: Risk limits configuration (uses defaults if None) persist_dir: Directory for persistence (default ./portfolio_data) Raises: ValidationError: If inputs are invalid """ # Validate inputs if not isinstance(initial_capital, Decimal): try: initial_capital = Decimal(str(initial_capital)) except (ValueError, TypeError) as e: raise ValidationError(f"Invalid initial capital: {e}") if initial_capital <= 0: raise ValidationError("Initial capital must be positive") if not isinstance(commission_rate, Decimal): try: commission_rate = Decimal(str(commission_rate)) except (ValueError, TypeError) as e: raise ValidationError(f"Invalid commission rate: {e}") if commission_rate < 0 or commission_rate > 1: raise ValidationError("Commission rate must be between 0 and 1") # Initialize core attributes self.initial_capital = initial_capital self.cash = initial_capital self.commission_rate = commission_rate self.positions: Dict[str, Position] = {} # Trade tracking self.trade_history: List[TradeRecord] = [] self.closed_positions: Dict[str, List[Position]] = {} self.pending_orders: List[Order] = [] # Equity curve tracking self.equity_curve: List[Tuple[datetime, Decimal]] = [ (datetime.now(), initial_capital) ] # Peak tracking for drawdown self.peak_value = initial_capital # Components self.risk_manager = RiskManager(risk_limits) self.analytics = PerformanceAnalytics() self.persistence = PortfolioPersistence(persist_dir) # Thread safety self._lock = threading.RLock() logger.info( f"Initialized portfolio with capital={initial_capital}, " f"commission={commission_rate}" ) def execute_order( self, order: Order, current_price: Decimal, check_risk: bool = True ) -> None: """ Execute an order at the current price. Args: order: Order to execute current_price: Current market price check_risk: Whether to check risk limits (default True) Raises: InvalidOrderError: If order cannot be executed InsufficientFundsError: If insufficient cash for buy order InsufficientSharesError: If insufficient shares for sell order RiskLimitExceededError: If trade would exceed risk limits ValidationError: If inputs are invalid """ with self._lock: # Validate price if not isinstance(current_price, Decimal): try: current_price = Decimal(str(current_price)) except (ValueError, TypeError) as e: raise ValidationError(f"Invalid current price: {e}") if current_price <= 0: raise ValidationError("Current price must be positive") # Check if order can execute at current price if not order.can_execute(current_price): raise InvalidOrderError( f"Order cannot execute at current price {current_price}" ) # Calculate order value and commission order_value = abs(order.quantity) * current_price commission = order_value * self.commission_rate # Execute based on order side if order.is_buy: self._execute_buy_order( order, current_price, order_value, commission, check_risk ) else: self._execute_sell_order( order, current_price, order_value, commission, check_risk ) # Mark order as executed order.mark_executed(abs(order.quantity), current_price) # Update equity curve self._update_equity_curve(current_price) logger.info( f"Executed {order.side.value} order: {order.ticker} " f"qty={abs(order.quantity)} price={current_price} " f"commission={commission}" ) def _execute_buy_order( self, order: Order, current_price: Decimal, order_value: Decimal, commission: Decimal, check_risk: bool ) -> None: """Execute a buy order.""" total_cost = order_value + commission # Check sufficient funds if total_cost > self.cash: raise InsufficientFundsError( f"Insufficient funds: need {total_cost}, have {self.cash}" ) # Risk checks if check_risk: # Check position size limit portfolio_value = self.total_value() new_position_value = order_value if order.ticker in self.positions: current_position_value = self.positions[order.ticker].market_value(current_price) new_position_value += current_position_value self.risk_manager.check_position_size_limit( new_position_value, portfolio_value, order.ticker ) # Check cash reserve new_cash = self.cash - total_cost self.risk_manager.check_cash_reserve(new_cash, portfolio_value) # Update or create position if order.ticker in self.positions: # Add to existing position position = self.positions[order.ticker] position.update_cost_basis(order.quantity, current_price) position.update_quantity(order.quantity) else: # Create new position self.positions[order.ticker] = Position( ticker=order.ticker, quantity=order.quantity, cost_basis=current_price, metadata=order.metadata ) # Deduct cash self.cash -= total_cost def _execute_sell_order( self, order: Order, current_price: Decimal, order_value: Decimal, commission: Decimal, check_risk: bool ) -> None: """Execute a sell order.""" # Check if position exists if order.ticker not in self.positions: raise PositionNotFoundError( f"No position in {order.ticker} to sell" ) position = self.positions[order.ticker] sell_quantity = abs(order.quantity) # Check sufficient shares if sell_quantity > abs(position.quantity): raise InsufficientSharesError( f"Insufficient shares: trying to sell {sell_quantity}, " f"have {abs(position.quantity)}" ) # Calculate P&L for this sale cost_basis_value = sell_quantity * position.cost_basis sale_proceeds = order_value - commission pnl = sale_proceeds - cost_basis_value pnl_percent = pnl / cost_basis_value if cost_basis_value > 0 else Decimal('0') # Check if closing entire position if sell_quantity == abs(position.quantity): # Record completed trade trade_record = TradeRecord( ticker=order.ticker, entry_date=position.opened_at, exit_date=datetime.now(), entry_price=position.cost_basis, exit_price=current_price, quantity=position.quantity, pnl=pnl, pnl_percent=pnl_percent, commission=commission, holding_period=(datetime.now() - position.opened_at).days, is_win=pnl > 0 ) self.trade_history.append(trade_record) # Move to closed positions if order.ticker not in self.closed_positions: self.closed_positions[order.ticker] = [] self.closed_positions[order.ticker].append(position) # Remove from active positions del self.positions[order.ticker] else: # Partially close position position.update_quantity(-sell_quantity) # Add proceeds to cash self.cash += sale_proceeds def get_position(self, ticker: str) -> Optional[Position]: """ Get a position by ticker. Args: ticker: Ticker symbol Returns: Position object or None if not found Raises: ValidationError: If ticker is invalid """ with self._lock: try: ticker = validate_ticker(ticker) except ValueError as e: raise ValidationError(f"Invalid ticker: {e}") return self.positions.get(ticker) def get_all_positions(self) -> Dict[str, Position]: """ Get all current positions. Returns: Dictionary mapping ticker to Position """ with self._lock: return self.positions.copy() def total_value(self, prices: Optional[Dict[str, Decimal]] = None) -> Decimal: """ Calculate total portfolio value. Args: prices: Optional dict of current prices (ticker -> price) If None, uses cost basis for positions Returns: Total portfolio value (cash + positions) Raises: ValidationError: If prices are invalid """ with self._lock: total = self.cash for ticker, position in self.positions.items(): if prices and ticker in prices: price = prices[ticker] if not isinstance(price, Decimal): price = Decimal(str(price)) if price <= 0: raise ValidationError(f"Invalid price for {ticker}: {price}") total += position.market_value(price) else: # Use cost basis if no price provided total += position.total_cost() return total def unrealized_pnl(self, prices: Dict[str, Decimal]) -> Decimal: """ Calculate total unrealized P&L. Args: prices: Dictionary of current prices (ticker -> price) Returns: Total unrealized P&L Raises: ValidationError: If prices are invalid """ with self._lock: total_pnl = Decimal('0') for ticker, position in self.positions.items(): if ticker in prices: price = prices[ticker] if not isinstance(price, Decimal): price = Decimal(str(price)) total_pnl += position.unrealized_pnl(price) return total_pnl def realized_pnl(self) -> Decimal: """ Calculate total realized P&L from closed trades. Returns: Total realized P&L """ with self._lock: return sum(trade.pnl for trade in self.trade_history) def get_performance_metrics( self, risk_free_rate: Decimal = Decimal('0.02') ) -> PerformanceMetrics: """ Get comprehensive performance metrics. Args: risk_free_rate: Annual risk-free rate (default 2%) Returns: PerformanceMetrics object Raises: ValidationError: If risk_free_rate is invalid """ with self._lock: return self.analytics.generate_performance_metrics( self.equity_curve, self.trade_history, self.initial_capital, risk_free_rate ) def get_equity_curve(self) -> List[Tuple[datetime, Decimal]]: """ Get the equity curve. Returns: List of (datetime, value) tuples """ with self._lock: return self.equity_curve.copy() def _update_equity_curve( self, current_price: Optional[Decimal] = None, prices: Optional[Dict[str, Decimal]] = None ) -> None: """ Update the equity curve with current portfolio value. Args: current_price: Single price to use for all positions prices: Dictionary of prices per ticker """ if prices is None and current_price is None: # Use cost basis value = self.total_value() elif prices is not None: value = self.total_value(prices) else: # Use single price for all positions price_dict = {ticker: current_price for ticker in self.positions.keys()} value = self.total_value(price_dict) self.equity_curve.append((datetime.now(), value)) # Update peak value if value > self.peak_value: self.peak_value = value def check_stop_loss_triggers( self, prices: Dict[str, Decimal] ) -> List[Order]: """ Check if any positions should trigger stop-loss orders. Args: prices: Dictionary of current prices Returns: List of stop-loss orders that should be executed """ with self._lock: stop_loss_orders = [] for ticker, position in self.positions.items(): if ticker not in prices: continue price = prices[ticker] if not isinstance(price, Decimal): price = Decimal(str(price)) if position.should_trigger_stop_loss(price): # Create stop-loss order to close position order = StopLossOrder( ticker=ticker, quantity=-position.quantity, # Opposite sign to close stop_price=position.stop_loss ) stop_loss_orders.append(order) logger.warning( f"Stop-loss triggered for {ticker} at {price} " f"(stop={position.stop_loss})" ) return stop_loss_orders def check_take_profit_triggers( self, prices: Dict[str, Decimal] ) -> List[Order]: """ Check if any positions should trigger take-profit orders. Args: prices: Dictionary of current prices Returns: List of take-profit orders that should be executed """ with self._lock: take_profit_orders = [] for ticker, position in self.positions.items(): if ticker not in prices: continue price = prices[ticker] if not isinstance(price, Decimal): price = Decimal(str(price)) if position.should_trigger_take_profit(price): # Create take-profit order to close position order = TakeProfitOrder( ticker=ticker, quantity=-position.quantity, # Opposite sign to close target_price=position.take_profit ) take_profit_orders.append(order) logger.info( f"Take-profit triggered for {ticker} at {price} " f"(target={position.take_profit})" ) return take_profit_orders def save(self, filename: str = 'portfolio_state.json') -> None: """ Save portfolio state to a file. Args: filename: Name of the file to save to Raises: PersistenceError: If save fails """ with self._lock: portfolio_data = self.to_dict() self.persistence.save_to_json(portfolio_data, filename) logger.info(f"Saved portfolio to {filename}") @classmethod def load(cls, filename: str = 'portfolio_state.json', persist_dir: Optional[str] = None) -> 'Portfolio': """ Load portfolio state from a file. Args: filename: Name of the file to load from persist_dir: Directory containing the file Returns: Portfolio instance Raises: PersistenceError: If load fails """ persistence = PortfolioPersistence(persist_dir) portfolio_data = persistence.load_from_json(filename) # Create portfolio with loaded data portfolio = cls( initial_capital=portfolio_data['initial_capital'], commission_rate=portfolio_data['commission_rate'], persist_dir=persist_dir ) # Restore state portfolio.cash = portfolio_data['cash'] # Restore positions for ticker, pos_data in portfolio_data.get('positions', {}).items(): portfolio.positions[ticker] = Position.from_dict(pos_data) # Restore trade history for trade_data in portfolio_data.get('trade_history', []): trade = TradeRecord( ticker=trade_data['ticker'], entry_date=datetime.fromisoformat(trade_data['entry_date']), exit_date=datetime.fromisoformat(trade_data['exit_date']), entry_price=Decimal(trade_data['entry_price']), exit_price=Decimal(trade_data['exit_price']), quantity=Decimal(trade_data['quantity']), pnl=Decimal(trade_data['pnl']), pnl_percent=Decimal(trade_data['pnl_percent']), commission=Decimal(trade_data['commission']), holding_period=trade_data['holding_period'], is_win=trade_data['is_win'] ) portfolio.trade_history.append(trade) # Restore equity curve for point in portfolio_data.get('equity_curve', []): portfolio.equity_curve.append(( datetime.fromisoformat(point[0]), Decimal(point[1]) )) # Restore peak value portfolio.peak_value = portfolio_data.get('peak_value', portfolio.initial_capital) logger.info(f"Loaded portfolio from {filename}") return portfolio def to_dict(self) -> Dict[str, Any]: """ Convert portfolio to dictionary for serialization. Returns: Dictionary representation of the portfolio """ with self._lock: return { 'initial_capital': str(self.initial_capital), 'cash': str(self.cash), 'commission_rate': str(self.commission_rate), 'positions': { ticker: position.to_dict() for ticker, position in self.positions.items() }, 'trade_history': [ trade.to_dict() for trade in self.trade_history ], 'equity_curve': [ (dt.isoformat(), str(value)) for dt, value in self.equity_curve ], 'peak_value': str(self.peak_value), 'timestamp': datetime.now().isoformat(), } def summary(self) -> Dict[str, Any]: """ Get a summary of the portfolio. Returns: Dictionary with portfolio summary """ with self._lock: total_val = self.total_value() realized = self.realized_pnl() return { 'total_value': str(total_val), 'cash': str(self.cash), 'invested': str(total_val - self.cash), 'num_positions': len(self.positions), 'realized_pnl': str(realized), 'total_return': str((total_val - self.initial_capital) / self.initial_capital), 'num_trades': len(self.trade_history), 'positions': list(self.positions.keys()), } def __repr__(self) -> str: """String representation of the portfolio.""" with self._lock: total_val = self.total_value() return ( f"Portfolio(value={total_val}, cash={self.cash}, " f"positions={len(self.positions)})" )