682 lines
22 KiB
Python
682 lines
22 KiB
Python
"""
|
|
Core portfolio management for the TradingAgents framework.
|
|
|
|
This module provides the main Portfolio class for managing positions,
|
|
executing orders, tracking P&L, and calculating risk metrics.
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from decimal import Decimal
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
import threading
|
|
import logging
|
|
|
|
from tradingagents.security import validate_ticker
|
|
from .position import Position
|
|
from .orders import (
|
|
Order, MarketOrder, LimitOrder, StopLossOrder, TakeProfitOrder,
|
|
OrderStatus, create_order_from_dict
|
|
)
|
|
from .risk import RiskManager, RiskLimits
|
|
from .analytics import PerformanceAnalytics, TradeRecord, PerformanceMetrics
|
|
from .persistence import PortfolioPersistence
|
|
from .exceptions import (
|
|
InsufficientFundsError,
|
|
InsufficientSharesError,
|
|
InvalidOrderError,
|
|
PositionNotFoundError,
|
|
RiskLimitExceededError,
|
|
ValidationError,
|
|
PersistenceError,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Portfolio:
|
|
"""
|
|
Main portfolio management class.
|
|
|
|
This class manages a portfolio of positions, handles order execution,
|
|
tracks cash and P&L, enforces risk limits, and provides performance
|
|
analytics.
|
|
|
|
Thread-safe for concurrent operations.
|
|
|
|
Attributes:
|
|
initial_capital: Initial portfolio capital
|
|
cash: Current cash balance
|
|
positions: Dictionary of current positions (ticker -> Position)
|
|
commission_rate: Commission rate as a fraction (e.g., 0.001 for 0.1%)
|
|
risk_manager: Risk management component
|
|
analytics: Performance analytics component
|
|
persistence: Persistence component
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
initial_capital: Decimal,
|
|
commission_rate: Decimal = Decimal('0.001'),
|
|
risk_limits: Optional[RiskLimits] = None,
|
|
persist_dir: Optional[str] = None
|
|
):
|
|
"""
|
|
Initialize a new portfolio.
|
|
|
|
Args:
|
|
initial_capital: Starting capital
|
|
commission_rate: Commission rate as a fraction (default 0.1%)
|
|
risk_limits: Risk limits configuration (uses defaults if None)
|
|
persist_dir: Directory for persistence (default ./portfolio_data)
|
|
|
|
Raises:
|
|
ValidationError: If inputs are invalid
|
|
"""
|
|
# Validate inputs
|
|
if not isinstance(initial_capital, Decimal):
|
|
try:
|
|
initial_capital = Decimal(str(initial_capital))
|
|
except (ValueError, TypeError) as e:
|
|
raise ValidationError(f"Invalid initial capital: {e}")
|
|
|
|
if initial_capital <= 0:
|
|
raise ValidationError("Initial capital must be positive")
|
|
|
|
if not isinstance(commission_rate, Decimal):
|
|
try:
|
|
commission_rate = Decimal(str(commission_rate))
|
|
except (ValueError, TypeError) as e:
|
|
raise ValidationError(f"Invalid commission rate: {e}")
|
|
|
|
if commission_rate < 0 or commission_rate > 1:
|
|
raise ValidationError("Commission rate must be between 0 and 1")
|
|
|
|
# Initialize core attributes
|
|
self.initial_capital = initial_capital
|
|
self.cash = initial_capital
|
|
self.commission_rate = commission_rate
|
|
self.positions: Dict[str, Position] = {}
|
|
|
|
# Trade tracking
|
|
self.trade_history: List[TradeRecord] = []
|
|
self.closed_positions: Dict[str, List[Position]] = {}
|
|
self.pending_orders: List[Order] = []
|
|
|
|
# Equity curve tracking
|
|
self.equity_curve: List[Tuple[datetime, Decimal]] = [
|
|
(datetime.now(), initial_capital)
|
|
]
|
|
|
|
# Peak tracking for drawdown
|
|
self.peak_value = initial_capital
|
|
|
|
# Components
|
|
self.risk_manager = RiskManager(risk_limits)
|
|
self.analytics = PerformanceAnalytics()
|
|
self.persistence = PortfolioPersistence(persist_dir)
|
|
|
|
# Thread safety
|
|
self._lock = threading.RLock()
|
|
|
|
logger.info(
|
|
f"Initialized portfolio with capital={initial_capital}, "
|
|
f"commission={commission_rate}"
|
|
)
|
|
|
|
def execute_order(
|
|
self,
|
|
order: Order,
|
|
current_price: Decimal,
|
|
check_risk: bool = True
|
|
) -> None:
|
|
"""
|
|
Execute an order at the current price.
|
|
|
|
Args:
|
|
order: Order to execute
|
|
current_price: Current market price
|
|
check_risk: Whether to check risk limits (default True)
|
|
|
|
Raises:
|
|
InvalidOrderError: If order cannot be executed
|
|
InsufficientFundsError: If insufficient cash for buy order
|
|
InsufficientSharesError: If insufficient shares for sell order
|
|
RiskLimitExceededError: If trade would exceed risk limits
|
|
ValidationError: If inputs are invalid
|
|
"""
|
|
with self._lock:
|
|
# Validate price
|
|
if not isinstance(current_price, Decimal):
|
|
try:
|
|
current_price = Decimal(str(current_price))
|
|
except (ValueError, TypeError) as e:
|
|
raise ValidationError(f"Invalid current price: {e}")
|
|
|
|
if current_price <= 0:
|
|
raise ValidationError("Current price must be positive")
|
|
|
|
# Check if order can execute at current price
|
|
if not order.can_execute(current_price):
|
|
raise InvalidOrderError(
|
|
f"Order cannot execute at current price {current_price}"
|
|
)
|
|
|
|
# Calculate order value and commission
|
|
order_value = abs(order.quantity) * current_price
|
|
commission = order_value * self.commission_rate
|
|
|
|
# Execute based on order side
|
|
if order.is_buy:
|
|
self._execute_buy_order(
|
|
order, current_price, order_value, commission, check_risk
|
|
)
|
|
else:
|
|
self._execute_sell_order(
|
|
order, current_price, order_value, commission, check_risk
|
|
)
|
|
|
|
# Mark order as executed
|
|
order.mark_executed(abs(order.quantity), current_price)
|
|
|
|
# Update equity curve
|
|
self._update_equity_curve(current_price)
|
|
|
|
logger.info(
|
|
f"Executed {order.side.value} order: {order.ticker} "
|
|
f"qty={abs(order.quantity)} price={current_price} "
|
|
f"commission={commission}"
|
|
)
|
|
|
|
def _execute_buy_order(
|
|
self,
|
|
order: Order,
|
|
current_price: Decimal,
|
|
order_value: Decimal,
|
|
commission: Decimal,
|
|
check_risk: bool
|
|
) -> None:
|
|
"""Execute a buy order."""
|
|
total_cost = order_value + commission
|
|
|
|
# Check sufficient funds
|
|
if total_cost > self.cash:
|
|
raise InsufficientFundsError(
|
|
f"Insufficient funds: need {total_cost}, have {self.cash}"
|
|
)
|
|
|
|
# Risk checks
|
|
if check_risk:
|
|
# Check position size limit
|
|
portfolio_value = self.total_value()
|
|
new_position_value = order_value
|
|
|
|
if order.ticker in self.positions:
|
|
current_position_value = self.positions[order.ticker].market_value(current_price)
|
|
new_position_value += current_position_value
|
|
|
|
self.risk_manager.check_position_size_limit(
|
|
new_position_value, portfolio_value, order.ticker
|
|
)
|
|
|
|
# Check cash reserve
|
|
new_cash = self.cash - total_cost
|
|
self.risk_manager.check_cash_reserve(new_cash, portfolio_value)
|
|
|
|
# Update or create position
|
|
if order.ticker in self.positions:
|
|
# Add to existing position
|
|
position = self.positions[order.ticker]
|
|
position.update_cost_basis(order.quantity, current_price)
|
|
position.update_quantity(order.quantity)
|
|
else:
|
|
# Create new position
|
|
self.positions[order.ticker] = Position(
|
|
ticker=order.ticker,
|
|
quantity=order.quantity,
|
|
cost_basis=current_price,
|
|
metadata=order.metadata
|
|
)
|
|
|
|
# Deduct cash
|
|
self.cash -= total_cost
|
|
|
|
def _execute_sell_order(
|
|
self,
|
|
order: Order,
|
|
current_price: Decimal,
|
|
order_value: Decimal,
|
|
commission: Decimal,
|
|
check_risk: bool
|
|
) -> None:
|
|
"""Execute a sell order."""
|
|
# Check if position exists
|
|
if order.ticker not in self.positions:
|
|
raise PositionNotFoundError(
|
|
f"No position in {order.ticker} to sell"
|
|
)
|
|
|
|
position = self.positions[order.ticker]
|
|
sell_quantity = abs(order.quantity)
|
|
|
|
# Check sufficient shares
|
|
if sell_quantity > abs(position.quantity):
|
|
raise InsufficientSharesError(
|
|
f"Insufficient shares: trying to sell {sell_quantity}, "
|
|
f"have {abs(position.quantity)}"
|
|
)
|
|
|
|
# Calculate P&L for this sale
|
|
cost_basis_value = sell_quantity * position.cost_basis
|
|
sale_proceeds = order_value - commission
|
|
pnl = sale_proceeds - cost_basis_value
|
|
pnl_percent = pnl / cost_basis_value if cost_basis_value > 0 else Decimal('0')
|
|
|
|
# Check if closing entire position
|
|
if sell_quantity == abs(position.quantity):
|
|
# Record completed trade
|
|
trade_record = TradeRecord(
|
|
ticker=order.ticker,
|
|
entry_date=position.opened_at,
|
|
exit_date=datetime.now(),
|
|
entry_price=position.cost_basis,
|
|
exit_price=current_price,
|
|
quantity=position.quantity,
|
|
pnl=pnl,
|
|
pnl_percent=pnl_percent,
|
|
commission=commission,
|
|
holding_period=(datetime.now() - position.opened_at).days,
|
|
is_win=pnl > 0
|
|
)
|
|
self.trade_history.append(trade_record)
|
|
|
|
# Move to closed positions
|
|
if order.ticker not in self.closed_positions:
|
|
self.closed_positions[order.ticker] = []
|
|
self.closed_positions[order.ticker].append(position)
|
|
|
|
# Remove from active positions
|
|
del self.positions[order.ticker]
|
|
|
|
else:
|
|
# Partially close position
|
|
position.update_quantity(-sell_quantity)
|
|
|
|
# Add proceeds to cash
|
|
self.cash += sale_proceeds
|
|
|
|
def get_position(self, ticker: str) -> Optional[Position]:
|
|
"""
|
|
Get a position by ticker.
|
|
|
|
Args:
|
|
ticker: Ticker symbol
|
|
|
|
Returns:
|
|
Position object or None if not found
|
|
|
|
Raises:
|
|
ValidationError: If ticker is invalid
|
|
"""
|
|
with self._lock:
|
|
try:
|
|
ticker = validate_ticker(ticker)
|
|
except ValueError as e:
|
|
raise ValidationError(f"Invalid ticker: {e}")
|
|
|
|
return self.positions.get(ticker)
|
|
|
|
def get_all_positions(self) -> Dict[str, Position]:
|
|
"""
|
|
Get all current positions.
|
|
|
|
Returns:
|
|
Dictionary mapping ticker to Position
|
|
"""
|
|
with self._lock:
|
|
return self.positions.copy()
|
|
|
|
def total_value(self, prices: Optional[Dict[str, Decimal]] = None) -> Decimal:
|
|
"""
|
|
Calculate total portfolio value.
|
|
|
|
Args:
|
|
prices: Optional dict of current prices (ticker -> price)
|
|
If None, uses cost basis for positions
|
|
|
|
Returns:
|
|
Total portfolio value (cash + positions)
|
|
|
|
Raises:
|
|
ValidationError: If prices are invalid
|
|
"""
|
|
with self._lock:
|
|
total = self.cash
|
|
|
|
for ticker, position in self.positions.items():
|
|
if prices and ticker in prices:
|
|
price = prices[ticker]
|
|
if not isinstance(price, Decimal):
|
|
price = Decimal(str(price))
|
|
if price <= 0:
|
|
raise ValidationError(f"Invalid price for {ticker}: {price}")
|
|
total += position.market_value(price)
|
|
else:
|
|
# Use cost basis if no price provided
|
|
total += position.total_cost()
|
|
|
|
return total
|
|
|
|
def unrealized_pnl(self, prices: Dict[str, Decimal]) -> Decimal:
|
|
"""
|
|
Calculate total unrealized P&L.
|
|
|
|
Args:
|
|
prices: Dictionary of current prices (ticker -> price)
|
|
|
|
Returns:
|
|
Total unrealized P&L
|
|
|
|
Raises:
|
|
ValidationError: If prices are invalid
|
|
"""
|
|
with self._lock:
|
|
total_pnl = Decimal('0')
|
|
|
|
for ticker, position in self.positions.items():
|
|
if ticker in prices:
|
|
price = prices[ticker]
|
|
if not isinstance(price, Decimal):
|
|
price = Decimal(str(price))
|
|
total_pnl += position.unrealized_pnl(price)
|
|
|
|
return total_pnl
|
|
|
|
def realized_pnl(self) -> Decimal:
|
|
"""
|
|
Calculate total realized P&L from closed trades.
|
|
|
|
Returns:
|
|
Total realized P&L
|
|
"""
|
|
with self._lock:
|
|
return sum(trade.pnl for trade in self.trade_history)
|
|
|
|
def get_performance_metrics(
|
|
self,
|
|
risk_free_rate: Decimal = Decimal('0.02')
|
|
) -> PerformanceMetrics:
|
|
"""
|
|
Get comprehensive performance metrics.
|
|
|
|
Args:
|
|
risk_free_rate: Annual risk-free rate (default 2%)
|
|
|
|
Returns:
|
|
PerformanceMetrics object
|
|
|
|
Raises:
|
|
ValidationError: If risk_free_rate is invalid
|
|
"""
|
|
with self._lock:
|
|
return self.analytics.generate_performance_metrics(
|
|
self.equity_curve,
|
|
self.trade_history,
|
|
self.initial_capital,
|
|
risk_free_rate
|
|
)
|
|
|
|
def get_equity_curve(self) -> List[Tuple[datetime, Decimal]]:
|
|
"""
|
|
Get the equity curve.
|
|
|
|
Returns:
|
|
List of (datetime, value) tuples
|
|
"""
|
|
with self._lock:
|
|
return self.equity_curve.copy()
|
|
|
|
def _update_equity_curve(
|
|
self,
|
|
current_price: Optional[Decimal] = None,
|
|
prices: Optional[Dict[str, Decimal]] = None
|
|
) -> None:
|
|
"""
|
|
Update the equity curve with current portfolio value.
|
|
|
|
Args:
|
|
current_price: Single price to use for all positions
|
|
prices: Dictionary of prices per ticker
|
|
"""
|
|
if prices is None and current_price is None:
|
|
# Use cost basis
|
|
value = self.total_value()
|
|
elif prices is not None:
|
|
value = self.total_value(prices)
|
|
else:
|
|
# Use single price for all positions
|
|
price_dict = {ticker: current_price for ticker in self.positions.keys()}
|
|
value = self.total_value(price_dict)
|
|
|
|
self.equity_curve.append((datetime.now(), value))
|
|
|
|
# Update peak value
|
|
if value > self.peak_value:
|
|
self.peak_value = value
|
|
|
|
def check_stop_loss_triggers(
|
|
self,
|
|
prices: Dict[str, Decimal]
|
|
) -> List[Order]:
|
|
"""
|
|
Check if any positions should trigger stop-loss orders.
|
|
|
|
Args:
|
|
prices: Dictionary of current prices
|
|
|
|
Returns:
|
|
List of stop-loss orders that should be executed
|
|
"""
|
|
with self._lock:
|
|
stop_loss_orders = []
|
|
|
|
for ticker, position in self.positions.items():
|
|
if ticker not in prices:
|
|
continue
|
|
|
|
price = prices[ticker]
|
|
if not isinstance(price, Decimal):
|
|
price = Decimal(str(price))
|
|
|
|
if position.should_trigger_stop_loss(price):
|
|
# Create stop-loss order to close position
|
|
order = StopLossOrder(
|
|
ticker=ticker,
|
|
quantity=-position.quantity, # Opposite sign to close
|
|
stop_price=position.stop_loss
|
|
)
|
|
stop_loss_orders.append(order)
|
|
|
|
logger.warning(
|
|
f"Stop-loss triggered for {ticker} at {price} "
|
|
f"(stop={position.stop_loss})"
|
|
)
|
|
|
|
return stop_loss_orders
|
|
|
|
def check_take_profit_triggers(
|
|
self,
|
|
prices: Dict[str, Decimal]
|
|
) -> List[Order]:
|
|
"""
|
|
Check if any positions should trigger take-profit orders.
|
|
|
|
Args:
|
|
prices: Dictionary of current prices
|
|
|
|
Returns:
|
|
List of take-profit orders that should be executed
|
|
"""
|
|
with self._lock:
|
|
take_profit_orders = []
|
|
|
|
for ticker, position in self.positions.items():
|
|
if ticker not in prices:
|
|
continue
|
|
|
|
price = prices[ticker]
|
|
if not isinstance(price, Decimal):
|
|
price = Decimal(str(price))
|
|
|
|
if position.should_trigger_take_profit(price):
|
|
# Create take-profit order to close position
|
|
order = TakeProfitOrder(
|
|
ticker=ticker,
|
|
quantity=-position.quantity, # Opposite sign to close
|
|
target_price=position.take_profit
|
|
)
|
|
take_profit_orders.append(order)
|
|
|
|
logger.info(
|
|
f"Take-profit triggered for {ticker} at {price} "
|
|
f"(target={position.take_profit})"
|
|
)
|
|
|
|
return take_profit_orders
|
|
|
|
def save(self, filename: str = 'portfolio_state.json') -> None:
|
|
"""
|
|
Save portfolio state to a file.
|
|
|
|
Args:
|
|
filename: Name of the file to save to
|
|
|
|
Raises:
|
|
PersistenceError: If save fails
|
|
"""
|
|
with self._lock:
|
|
portfolio_data = self.to_dict()
|
|
self.persistence.save_to_json(portfolio_data, filename)
|
|
logger.info(f"Saved portfolio to {filename}")
|
|
|
|
@classmethod
|
|
def load(cls, filename: str = 'portfolio_state.json', persist_dir: Optional[str] = None) -> 'Portfolio':
|
|
"""
|
|
Load portfolio state from a file.
|
|
|
|
Args:
|
|
filename: Name of the file to load from
|
|
persist_dir: Directory containing the file
|
|
|
|
Returns:
|
|
Portfolio instance
|
|
|
|
Raises:
|
|
PersistenceError: If load fails
|
|
"""
|
|
persistence = PortfolioPersistence(persist_dir)
|
|
portfolio_data = persistence.load_from_json(filename)
|
|
|
|
# Create portfolio with loaded data
|
|
portfolio = cls(
|
|
initial_capital=portfolio_data['initial_capital'],
|
|
commission_rate=portfolio_data['commission_rate'],
|
|
persist_dir=persist_dir
|
|
)
|
|
|
|
# Restore state
|
|
portfolio.cash = portfolio_data['cash']
|
|
|
|
# Restore positions
|
|
for ticker, pos_data in portfolio_data.get('positions', {}).items():
|
|
portfolio.positions[ticker] = Position.from_dict(pos_data)
|
|
|
|
# Restore trade history
|
|
for trade_data in portfolio_data.get('trade_history', []):
|
|
trade = TradeRecord(
|
|
ticker=trade_data['ticker'],
|
|
entry_date=datetime.fromisoformat(trade_data['entry_date']),
|
|
exit_date=datetime.fromisoformat(trade_data['exit_date']),
|
|
entry_price=Decimal(trade_data['entry_price']),
|
|
exit_price=Decimal(trade_data['exit_price']),
|
|
quantity=Decimal(trade_data['quantity']),
|
|
pnl=Decimal(trade_data['pnl']),
|
|
pnl_percent=Decimal(trade_data['pnl_percent']),
|
|
commission=Decimal(trade_data['commission']),
|
|
holding_period=trade_data['holding_period'],
|
|
is_win=trade_data['is_win']
|
|
)
|
|
portfolio.trade_history.append(trade)
|
|
|
|
# Restore equity curve
|
|
for point in portfolio_data.get('equity_curve', []):
|
|
portfolio.equity_curve.append((
|
|
datetime.fromisoformat(point[0]),
|
|
Decimal(point[1])
|
|
))
|
|
|
|
# Restore peak value
|
|
portfolio.peak_value = portfolio_data.get('peak_value', portfolio.initial_capital)
|
|
|
|
logger.info(f"Loaded portfolio from {filename}")
|
|
|
|
return portfolio
|
|
|
|
def to_dict(self) -> Dict[str, Any]:
|
|
"""
|
|
Convert portfolio to dictionary for serialization.
|
|
|
|
Returns:
|
|
Dictionary representation of the portfolio
|
|
"""
|
|
with self._lock:
|
|
return {
|
|
'initial_capital': str(self.initial_capital),
|
|
'cash': str(self.cash),
|
|
'commission_rate': str(self.commission_rate),
|
|
'positions': {
|
|
ticker: position.to_dict()
|
|
for ticker, position in self.positions.items()
|
|
},
|
|
'trade_history': [
|
|
trade.to_dict() for trade in self.trade_history
|
|
],
|
|
'equity_curve': [
|
|
(dt.isoformat(), str(value))
|
|
for dt, value in self.equity_curve
|
|
],
|
|
'peak_value': str(self.peak_value),
|
|
'timestamp': datetime.now().isoformat(),
|
|
}
|
|
|
|
def summary(self) -> Dict[str, Any]:
|
|
"""
|
|
Get a summary of the portfolio.
|
|
|
|
Returns:
|
|
Dictionary with portfolio summary
|
|
"""
|
|
with self._lock:
|
|
total_val = self.total_value()
|
|
realized = self.realized_pnl()
|
|
|
|
return {
|
|
'total_value': str(total_val),
|
|
'cash': str(self.cash),
|
|
'invested': str(total_val - self.cash),
|
|
'num_positions': len(self.positions),
|
|
'realized_pnl': str(realized),
|
|
'total_return': str((total_val - self.initial_capital) / self.initial_capital),
|
|
'num_trades': len(self.trade_history),
|
|
'positions': list(self.positions.keys()),
|
|
}
|
|
|
|
def __repr__(self) -> str:
|
|
"""String representation of the portfolio."""
|
|
with self._lock:
|
|
total_val = self.total_value()
|
|
return (
|
|
f"Portfolio(value={total_val}, cash={self.cash}, "
|
|
f"positions={len(self.positions)})"
|
|
)
|