486 lines
17 KiB
Python
486 lines
17 KiB
Python
"""
Integration layer between the portfolio management system and TradingAgents.

This module provides functionality to connect the portfolio to the TradingAgentsGraph,
execute trading decisions from agents, and provide portfolio context to agents.
"""
|
|
|
|
from datetime import datetime
|
|
from decimal import Decimal
|
|
from typing import Dict, List, Optional, Any, Callable
|
|
import logging
|
|
|
|
from .portfolio import Portfolio
|
|
from .orders import MarketOrder, LimitOrder, OrderType
|
|
from .exceptions import (
|
|
InvalidOrderError,
|
|
InsufficientFundsError,
|
|
IntegrationError,
|
|
ValidationError,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TradingAgentsPortfolioIntegration:
    """
    Integrates portfolio management with the TradingAgents framework.

    This class connects the portfolio to TradingAgentsGraph, executes
    decisions from agents, and provides portfolio context for decision-making.
    Every call to ``execute_agent_decision`` (successful or not) is appended
    to ``execution_history`` as an audit trail.
    """

    # Fraction of total portfolio value used to size a trade when a decision
    # carries neither 'quantity' nor 'position_size_pct'.
    DEFAULT_POSITION_SIZE_PCT = Decimal('0.10')
    # Allowed absolute deviation of the target-weight sum from 1.0.
    WEIGHT_SUM_TOLERANCE = Decimal('0.01')
    # Rebalancing trades smaller than this fraction of the target value are skipped.
    REBALANCE_THRESHOLD = Decimal('0.01')

    def __init__(
        self,
        portfolio: "Portfolio",
        price_fetcher: Optional[Callable[[str], Decimal]] = None
    ):
        """
        Initialize the integration layer.

        Args:
            portfolio: Portfolio instance to manage
            price_fetcher: Optional function to fetch current prices (ticker -> price)
                          If None, prices must be provided with each operation
        """
        self.portfolio = portfolio
        self.price_fetcher = price_fetcher
        self.execution_history: List[Dict[str, Any]] = []

        logger.info("Initialized TradingAgentsPortfolioIntegration")

    @staticmethod
    def _to_decimal(value: Any) -> Decimal:
        """Return ``value`` as a Decimal, converting non-Decimals through ``str()``."""
        return value if isinstance(value, Decimal) else Decimal(str(value))

    @staticmethod
    def _failure_result(decision: Any, status: str, error: Exception) -> Dict[str, Any]:
        """Build the result dict reported for a decision that did not execute."""
        # The decision may not be a dict at all (that is one of the failures
        # reported here), so never call .get() on it blindly.
        fields = decision if isinstance(decision, dict) else {}
        return {
            'status': status,
            'action': fields.get('action'),
            'ticker': fields.get('ticker'),
            'error': str(error),
            'error_type': type(error).__name__,
        }

    def execute_agent_decision(
        self,
        decision: Dict[str, Any],
        current_prices: Optional[Dict[str, Decimal]] = None
    ) -> Dict[str, Any]:
        """
        Execute a trading decision from TradingAgents.

        Expected decision format:
        {
            'action': 'buy' | 'sell' | 'hold',
            'ticker': str,
            'quantity': int | float | Decimal (optional, uses position sizing if not provided),
            'order_type': 'market' | 'limit' (optional, default 'market'),
            'limit_price': Decimal (required if order_type is 'limit'),
            'reasoning': str (optional),
        }

        Args:
            decision: Trading decision from agent
            current_prices: Optional dict of current prices

        Returns:
            Execution result. ``status`` is 'success' when the order was handed
            to the portfolio (or the action was 'hold'), and 'failed' when the
            portfolio rejected it with InvalidOrderError or
            InsufficientFundsError. Numeric fields are rendered as strings.

        Raises:
            IntegrationError: If the decision format is invalid or any other
                unexpected error occurs; the original exception is chained
                as ``__cause__``.
        """
        try:
            # Validate decision format
            if not isinstance(decision, dict):
                raise IntegrationError("Decision must be a dictionary")

            # `or ''` / str() tolerate a missing, None or non-string action so
            # it is reported as an invalid action rather than an AttributeError.
            action = str(decision.get('action') or '').lower()
            if action not in ('buy', 'sell', 'hold'):
                raise IntegrationError(f"Invalid action: {action}")

            ticker = decision.get('ticker')
            if not ticker:
                raise IntegrationError("Ticker is required")

            # Handle 'hold' action: nothing to execute, but still audited.
            if action == 'hold':
                result = {
                    'status': 'success',
                    'action': 'hold',
                    'ticker': ticker,
                    'message': 'No action taken',
                }
                self._log_execution(decision, result)
                return result

            current_price = self._get_price(ticker, current_prices)
            quantity = self._determine_quantity(decision, ticker, current_price)
            order = self._create_order(decision, ticker, quantity)

            self.portfolio.execute_order(order, current_price)

            result = {
                'status': 'success',
                'action': action,
                'ticker': ticker,
                'quantity': str(quantity),
                'price': str(current_price),
                'order_type': decision.get('order_type') or 'market',
                # NOTE(review): this reports the portfolio's commission *rate*,
                # not the commission charged on this trade - confirm intent.
                'commission': str(self.portfolio.commission_rate),
                'reasoning': decision.get('reasoning', ''),
            }

            self._log_execution(decision, result)

            logger.info(
                f"Executed agent decision: {action} {ticker} "
                f"qty={quantity} price={current_price}"
            )

            return result

        except (InvalidOrderError, InsufficientFundsError) as e:
            # Trading errors - expected in normal operation, reported not raised
            result = self._failure_result(decision, 'failed', e)
            self._log_execution(decision, result)
            logger.warning(f"Failed to execute decision: {e}")
            return result

        except Exception as e:
            # Unexpected errors (including the validation errors raised above)
            result = self._failure_result(decision, 'error', e)
            self._log_execution(decision, result)
            logger.error(f"Error executing decision: {e}", exc_info=True)
            raise IntegrationError(f"Failed to execute decision: {e}") from e

    def get_portfolio_context(
        self,
        current_prices: Optional[Dict[str, Decimal]] = None
    ) -> Dict[str, Any]:
        """
        Get portfolio context for agent decision-making.

        Provides current portfolio state, positions, and performance metrics
        that agents can use to make informed trading decisions.

        Args:
            current_prices: Optional dict of current prices. When omitted and a
                price_fetcher is configured, prices are fetched for every
                ticker in ``portfolio.positions``; tickers whose fetch fails
                are logged and left unpriced.

        Returns:
            Dictionary with portfolio context information. Monetary values and
            ratios are rendered as strings; ``performance`` is None when there
            is no trade history or the metrics could not be computed.

        Raises:
            IntegrationError: If the context cannot be assembled.
        """
        try:
            # Get current prices for all positions
            if current_prices is None and self.price_fetcher is not None:
                current_prices = {}
                for ticker in self.portfolio.positions:
                    try:
                        # Coerce like _get_price does so downstream Decimal
                        # arithmetic does not trip over float prices.
                        current_prices[ticker] = self._to_decimal(
                            self.price_fetcher(ticker)
                        )
                    except Exception as e:
                        logger.warning(f"Failed to fetch price for {ticker}: {e}")

            # Calculate portfolio metrics
            total_value = self.portfolio.total_value(current_prices)
            unrealized_pnl = (
                self.portfolio.unrealized_pnl(current_prices)
                if current_prices
                else Decimal('0')
            )
            realized_pnl = self.portfolio.realized_pnl()

            # Position details
            positions_context = []
            for ticker, position in self.portfolio.get_all_positions().items():
                pos_context = {
                    'ticker': ticker,
                    'quantity': str(position.quantity),
                    'cost_basis': str(position.cost_basis),
                    'is_long': position.is_long,
                }

                # Price-dependent fields are only present for priced tickers.
                if current_prices and ticker in current_prices:
                    price = current_prices[ticker]
                    pos_context.update({
                        'current_price': str(price),
                        'market_value': str(position.market_value(price)),
                        'unrealized_pnl': str(position.unrealized_pnl(price)),
                        'unrealized_pnl_pct': str(position.unrealized_pnl_percent(price)),
                    })

                positions_context.append(pos_context)

            # Performance metrics (if we have enough data); best-effort only.
            performance = None
            try:
                if len(self.portfolio.trade_history) > 0:
                    metrics = self.portfolio.get_performance_metrics()
                    performance = {
                        'total_trades': metrics.total_trades,
                        'win_rate': str(metrics.win_rate),
                        'profit_factor': str(metrics.profit_factor),
                        'sharpe_ratio': str(metrics.sharpe_ratio),
                        'max_drawdown': str(metrics.max_drawdown),
                    }
            except Exception as e:
                logger.debug(f"Could not calculate performance metrics: {e}")

            cash = self.portfolio.cash
            initial_capital = self.portfolio.initial_capital
            # A zero starting capital has no meaningful return; report 0
            # instead of failing the whole context on a division by zero.
            total_return = (
                (total_value - initial_capital) / initial_capital
                if initial_capital
                else Decimal('0')
            )

            return {
                'total_value': str(total_value),
                'cash': str(cash),
                'cash_pct': str(cash / total_value if total_value > 0 else Decimal('1')),
                'invested_value': str(total_value - cash),
                'unrealized_pnl': str(unrealized_pnl),
                'realized_pnl': str(realized_pnl),
                'total_pnl': str(unrealized_pnl + realized_pnl),
                'total_return': str(total_return),
                'num_positions': len(self.portfolio.positions),
                'positions': positions_context,
                'performance': performance,
                'timestamp': datetime.now().isoformat(),
            }

        except Exception as e:
            logger.error(f"Error getting portfolio context: {e}", exc_info=True)
            raise IntegrationError(f"Failed to get portfolio context: {e}") from e

    def batch_execute_decisions(
        self,
        decisions: List[Dict[str, Any]],
        current_prices: Optional[Dict[str, Decimal]] = None
    ) -> List[Dict[str, Any]]:
        """
        Execute multiple trading decisions in batch.

        Decisions are executed in order. A decision that raises does not stop
        the batch; it contributes an entry with ``status`` 'error', the
        offending ``decision`` and the ``error`` text.

        Args:
            decisions: List of trading decisions
            current_prices: Optional dict of current prices

        Returns:
            List of execution results, one per decision, in input order
        """
        results = []

        for decision in decisions:
            try:
                results.append(self.execute_agent_decision(decision, current_prices))
            except Exception as e:
                logger.error(f"Error in batch execution: {e}")
                results.append({
                    'status': 'error',
                    'decision': decision,
                    'error': str(e),
                })

        return results

    def rebalance_portfolio(
        self,
        target_weights: Dict[str, Decimal],
        current_prices: Dict[str, Decimal]
    ) -> List[Dict[str, Any]]:
        """
        Rebalance portfolio to target weights.

        Only tickers named in ``target_weights`` are traded; positions in
        other tickers are left untouched. Tickers without an entry in
        ``current_prices`` are skipped with a warning. Sell orders are
        executed before buy orders so that their proceeds can fund the buys.

        Args:
            target_weights: Dictionary mapping ticker to target weight (as fraction).
                Non-Decimal numbers are converted to Decimal.
            current_prices: Dictionary of current prices

        Returns:
            List of execution results (sells first, then buys)

        Raises:
            IntegrationError: If the target weights do not sum to 1.0 within
                tolerance (the underlying ValidationError is chained as
                ``__cause__``) or rebalancing otherwise fails.
        """
        try:
            weights = {
                ticker: self._to_decimal(weight)
                for ticker, weight in target_weights.items()
            }

            # Validate target weights
            total_weight = sum(weights.values(), Decimal('0'))
            if abs(total_weight - Decimal('1')) > self.WEIGHT_SUM_TOLERANCE:
                raise ValidationError(
                    f"Target weights must sum to 1.0, got {total_weight}"
                )

            # Calculate current portfolio value
            current_value = self.portfolio.total_value(current_prices)

            # Calculate required trades
            decisions = []

            for ticker, weight in weights.items():
                if ticker not in current_prices:
                    # Without a price neither the holding nor the trade size
                    # can be computed, so the ticker cannot be rebalanced.
                    logger.warning(
                        f"No price for {ticker}; skipping it during rebalancing"
                    )
                    continue

                price = self._to_decimal(current_prices[ticker])
                target_value = current_value * weight

                position = self.portfolio.get_position(ticker)
                held_value = position.market_value(price) if position else Decimal('0')

                difference = target_value - held_value

                # Nothing to do when already on target (this also covers a
                # zero target with no holding, which would otherwise produce
                # a zero-quantity order) ...
                if difference == 0:
                    continue
                # ... and only trade if difference is significant (> 1% of target)
                if abs(difference) < target_value * self.REBALANCE_THRESHOLD:
                    continue

                quantity = difference / price

                decisions.append({
                    'action': 'buy' if quantity > 0 else 'sell',
                    'ticker': ticker,
                    'quantity': abs(quantity),
                    'order_type': 'market',
                    'reasoning': f'Rebalancing to target weight {weight:.2%}',
                })

            # Stable sort: sells (key False) first so they free up cash for
            # the buys; relative order within each group is preserved.
            decisions.sort(key=lambda d: d['action'] != 'sell')

            # Execute all rebalancing trades
            results = self.batch_execute_decisions(decisions, current_prices)

            logger.info(f"Completed portfolio rebalancing with {len(results)} trades")

            return results

        except Exception as e:
            logger.error(f"Error rebalancing portfolio: {e}", exc_info=True)
            raise IntegrationError(f"Failed to rebalance portfolio: {e}") from e

    def _get_price(
        self,
        ticker: str,
        current_prices: Optional[Dict[str, Decimal]] = None
    ) -> Decimal:
        """
        Get current price for a ticker.

        Looks in ``current_prices`` first and falls back to the configured
        price_fetcher. The price is always returned as a Decimal.

        Raises:
            IntegrationError: If neither source yields a price; a price_fetcher
                failure, if any, is chained as ``__cause__``.
        """
        # Try provided prices first
        if current_prices and ticker in current_prices:
            return self._to_decimal(current_prices[ticker])

        # Try price fetcher
        fetch_error = None
        if self.price_fetcher:
            try:
                return self._to_decimal(self.price_fetcher(ticker))
            except Exception as e:
                fetch_error = e
                logger.error(f"Failed to fetch price for {ticker}: {e}")

        raise IntegrationError(
            f"No price available for {ticker}. "
            "Provide current_prices or configure price_fetcher."
        ) from fetch_error

    def _determine_quantity(
        self,
        decision: Dict[str, Any],
        ticker: str,
        current_price: Decimal
    ) -> Decimal:
        """
        Determine trade quantity from decision.

        Precedence: an explicit 'quantity', then 'position_size_pct' (fraction
        of ``portfolio.total_value()``), then DEFAULT_POSITION_SIZE_PCT. A key
        that is present but None is treated as not provided.

        Returns:
            The quantity as a Decimal (sign is applied later by _create_order)
        """
        # Check if quantity is explicitly provided
        explicit = decision.get('quantity')
        if explicit is not None:
            return self._to_decimal(explicit)

        # Use position sizing if available
        size_pct = decision.get('position_size_pct')
        if size_pct is not None:
            return self.portfolio.total_value() * self._to_decimal(size_pct) / current_price

        # Default: use a fixed fraction of the portfolio
        default_pct = self.DEFAULT_POSITION_SIZE_PCT
        quantity = self.portfolio.total_value() * default_pct / current_price

        logger.warning(
            f"No quantity specified for {ticker}, "
            f"using default {default_pct:.0%} position size: {quantity}"
        )

        return quantity

    def _create_order(
        self,
        decision: Dict[str, Any],
        ticker: str,
        quantity: Decimal
    ):
        """
        Create an order from a decision.

        The quantity is made negative for 'sell' and positive for every other
        action. A missing or None 'order_type' means a market order.

        Returns:
            A MarketOrder or LimitOrder for ``ticker``

        Raises:
            IntegrationError: If a limit order has no limit_price or the order
                type is not 'market' or 'limit'.
        """
        action = str(decision.get('action') or '').lower()
        order_type = str(decision.get('order_type') or 'market').lower()

        # Adjust quantity sign based on action
        signed_quantity = -abs(quantity) if action == 'sell' else abs(quantity)

        # Create appropriate order type
        if order_type == 'market':
            return MarketOrder(ticker=ticker, quantity=signed_quantity)

        if order_type == 'limit':
            limit_price = decision.get('limit_price')
            if not limit_price:
                raise IntegrationError("limit_price required for limit orders")
            return LimitOrder(
                ticker=ticker,
                quantity=signed_quantity,
                limit_price=self._to_decimal(limit_price),
            )

        raise IntegrationError(f"Unsupported order type: {order_type}")

    def _log_execution(
        self,
        decision: Dict[str, Any],
        result: Dict[str, Any]
    ) -> None:
        """Append a timestamped decision/result pair to the audit trail."""
        self.execution_history.append({
            'timestamp': datetime.now().isoformat(),
            'decision': decision,
            'result': result,
        })

    def get_execution_history(
        self,
        limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Get execution history.

        Args:
            limit: Maximum number of entries to return. The most recent
                ``limit`` entries are returned, in chronological (oldest
                first) order. None returns everything; zero or a negative
                value returns an empty list.

        Returns:
            A new list of execution log entries
        """
        if limit is None:
            return self.execution_history.copy()
        if limit <= 0:
            # A plain [-limit:] slice would return everything for 0 and drop
            # the *oldest* entries for negative values.
            return []
        return self.execution_history[-limit:]

    def clear_execution_history(self) -> None:
        """Clear the execution history."""
        self.execution_history.clear()
        logger.info("Cleared execution history")
|