diff --git a/autonomous/connectors/ibkr_resilient.py b/autonomous/connectors/ibkr_resilient.py new file mode 100644 index 00000000..74275874 --- /dev/null +++ b/autonomous/connectors/ibkr_resilient.py @@ -0,0 +1,851 @@ +""" +Resilient IBKR Connector +======================== + +Enhanced IBKR connection with auto-reconnection, circuit breakers, +and comprehensive error handling. +""" + +import asyncio +import logging +from typing import Dict, List, Optional, Any, Callable +from datetime import datetime, timedelta +from dataclasses import dataclass, field +from enum import Enum +import time +from collections import deque + +try: + from ib_insync import ( + IB, Stock, Option, Future, Contract, MarketOrder, + LimitOrder, StopOrder, BracketOrder, util + ) + IBKR_AVAILABLE = True +except ImportError: + IBKR_AVAILABLE = False + +from tenacity import ( + retry, stop_after_attempt, wait_exponential, + retry_if_exception_type, before_retry, after_retry +) + +from ..core.database import DatabaseManager, Position, Order, OrderStatus + +logger = logging.getLogger(__name__) + + +class ConnectionState(Enum): + """Connection state enumeration""" + DISCONNECTED = "disconnected" + CONNECTING = "connecting" + CONNECTED = "connected" + RECONNECTING = "reconnecting" + ERROR = "error" + CLOSED = "closed" + + +@dataclass +class ConnectionHealth: + """Connection health metrics""" + state: ConnectionState = ConnectionState.DISCONNECTED + last_heartbeat: Optional[datetime] = None + reconnect_attempts: int = 0 + total_reconnects: int = 0 + errors: deque = field(default_factory=lambda: deque(maxlen=100)) + latency_ms: float = 0.0 + messages_received: int = 0 + orders_placed: int = 0 + orders_failed: int = 0 + + +class CircuitBreaker: + """Circuit breaker pattern for connection management""" + + def __init__(self, failure_threshold: int = 5, timeout: int = 60): + """ + Initialize circuit breaker + + Args: + failure_threshold: Number of failures before opening circuit + timeout: Seconds to wait before attempting to close circuit + """ + self.failure_threshold = failure_threshold + self.timeout = timeout + self.failures = 0 + self.last_failure_time = None + self.state = "closed" # closed, open, half-open + + def call(self, func: Callable, *args, **kwargs): + """Execute function with circuit breaker protection""" + if self.state == "open": + if (datetime.now() - self.last_failure_time).seconds > self.timeout: + self.state = "half-open" + else: + raise Exception("Circuit breaker is open") + + try: + result = func(*args, **kwargs) + if self.state == "half-open": + self.state = "closed" + self.failures = 0 + return result + except Exception as e: + self.failures += 1 + self.last_failure_time = datetime.now() + + if self.failures >= self.failure_threshold: + self.state = "open" + logger.error(f"Circuit breaker opened after {self.failures} failures") + + raise e + + def reset(self): + """Reset circuit breaker""" + self.failures = 0 + self.state = "closed" + self.last_failure_time = None + + +class ResilientIBKRConnector: + """ + Resilient IBKR connector with auto-reconnection and health monitoring + """ + + def __init__(self, + host: str = "127.0.0.1", + port: int = 7497, + client_id: int = 1, + db_manager: Optional[DatabaseManager] = None, + max_reconnect_attempts: int = 10, + reconnect_delay: int = 5): + """ + Initialize resilient IBKR connector + + Args: + host: TWS/Gateway host + port: Connection port + client_id: Unique client ID + db_manager: Database manager for persistence + max_reconnect_attempts: Maximum reconnection attempts + reconnect_delay: Initial delay between reconnection attempts + """ + self.host = host + self.port = port + self.client_id = client_id + self.db = db_manager + + self.ib: Optional[IB] = None + self.health = ConnectionHealth() + self.circuit_breaker = CircuitBreaker() + + self.max_reconnect_attempts = max_reconnect_attempts + self.reconnect_delay = reconnect_delay + + # Callbacks + self.on_connected_callback: Optional[Callable] = None + self.on_disconnected_callback: Optional[Callable] = None + self.on_error_callback: Optional[Callable] = None + + # Connection monitoring + self._heartbeat_task: Optional[asyncio.Task] = None + self._reconnect_task: Optional[asyncio.Task] = None + self._monitor_task: Optional[asyncio.Task] = None + + # WebSocket for real-time data + self._websocket_connected = False + self._market_data_streams: Dict[str, Any] = {} + + # === Connection Management === + + async def connect(self) -> bool: + """ + Connect to IBKR with retry logic + + Returns: + True if connected successfully + """ + if not IBKR_AVAILABLE: + logger.error("ib_insync not installed") + return False + + self.health.state = ConnectionState.CONNECTING + + try: + return await self._connect_with_retry() + except Exception as e: + logger.error(f"Failed to connect after all attempts: {e}") + self.health.state = ConnectionState.ERROR + self.health.errors.append({ + 'timestamp': datetime.now(), + 'error': str(e), + 'type': 'connection' + }) + return False + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=30), + retry=retry_if_exception_type(ConnectionError), + before=before_retry(lambda retry_state: logger.info( + f"Connection attempt {retry_state.attempt_number}" + )) + ) + async def _connect_with_retry(self) -> bool: + """Internal connection with retry logic""" + try: + self.ib = IB() + + # Connect asynchronously + await self.ib.connectAsync( + host=self.host, + port=self.port, + clientId=self.client_id, + timeout=10 + ) + + # Verify connection + if not self.ib.isConnected(): + raise ConnectionError("Connection established but not active") + + # Setup event handlers + self._setup_event_handlers() + + # Start monitoring tasks + await self._start_monitoring() + + self.health.state = ConnectionState.CONNECTED + self.health.last_heartbeat = datetime.now() + self.circuit_breaker.reset() + + logger.info(f"Connected to IBKR at {self.host}:{self.port}") + + # Request initial data + self.ib.reqAccountUpdates() + + # Trigger callback + if self.on_connected_callback: + await self.on_connected_callback() + + return True + + except Exception as e: + logger.error(f"Connection failed: {e}") + raise ConnectionError(f"Failed to connect: {e}") + + async def disconnect(self): + """Gracefully disconnect from IBKR""" + self.health.state = ConnectionState.DISCONNECTED + + # Cancel monitoring tasks + if self._heartbeat_task: + self._heartbeat_task.cancel() + if self._reconnect_task: + self._reconnect_task.cancel() + if self._monitor_task: + self._monitor_task.cancel() + + # Disconnect + if self.ib and self.ib.isConnected(): + self.ib.disconnect() + + logger.info("Disconnected from IBKR") + + if self.on_disconnected_callback: + await self.on_disconnected_callback() + + async def reconnect(self) -> bool: + """ + Reconnect to IBKR with exponential backoff + + Returns: + True if reconnected successfully + """ + if self.health.state == ConnectionState.RECONNECTING: + logger.warning("Already attempting to reconnect") + return False + + self.health.state = ConnectionState.RECONNECTING + self.health.reconnect_attempts = 0 + + delay = self.reconnect_delay + + while self.health.reconnect_attempts < self.max_reconnect_attempts: + self.health.reconnect_attempts += 1 + + logger.info(f"Reconnection attempt {self.health.reconnect_attempts}" + f"/{self.max_reconnect_attempts}") + + # Disconnect existing connection + if self.ib: + try: + self.ib.disconnect() + except: + pass + + # Attempt reconnection + if await self.connect(): + self.health.total_reconnects += 1 + logger.info("Reconnection successful") + return True + + # Exponential backoff + await asyncio.sleep(delay) + delay = min(delay * 2, 300) # Max 5 minutes + + logger.error("Failed to reconnect after all attempts") + self.health.state = ConnectionState.ERROR + return False + + # === Event Handlers === + + def _setup_event_handlers(self): + """Setup IBKR event handlers""" + if not self.ib: + return + + # Connection events + self.ib.connectedEvent += self._on_connected + self.ib.disconnectedEvent += self._on_disconnected + self.ib.errorEvent += self._on_error + + # Order events + self.ib.orderStatusEvent += self._on_order_status + self.ib.execDetailsEvent += self._on_exec_details + + # Market data events + self.ib.tickerUpdateEvent += self._on_ticker_update + + def _on_connected(self): + """Handle connection event""" + logger.info("IBKR connection established") + self.health.state = ConnectionState.CONNECTED + self.health.last_heartbeat = datetime.now() + + def _on_disconnected(self): + """Handle disconnection event""" + logger.warning("IBKR connection lost") + self.health.state = ConnectionState.DISCONNECTED + + # Trigger auto-reconnection + if not self._reconnect_task or self._reconnect_task.done(): + self._reconnect_task = asyncio.create_task(self.reconnect()) + + def _on_error(self, reqId: int, errorCode: int, errorString: str, + contract: Optional[Contract]): + """Handle error events""" + logger.error(f"IBKR Error: {errorCode} - {errorString}") + + self.health.errors.append({ + 'timestamp': datetime.now(), + 'code': errorCode, + 'message': errorString, + 'contract': contract.symbol if contract else None + }) + + # Critical errors that require reconnection + critical_errors = [504, 502, 1100, 1101, 1102] # Connection lost codes + if errorCode in critical_errors: + logger.critical(f"Critical error detected: {errorCode}") + if not self._reconnect_task or self._reconnect_task.done(): + self._reconnect_task = asyncio.create_task(self.reconnect()) + + if self.on_error_callback: + asyncio.create_task(self.on_error_callback(errorCode, errorString)) + + def _on_order_status(self, trade): + """Handle order status updates""" + if self.db: + # Update order in database + asyncio.create_task(self._update_order_status(trade)) + + def _on_exec_details(self, trade, fill): + """Handle execution details""" + logger.info(f"Order executed: {fill.contract.symbol} " + f"{fill.execution.side} {fill.execution.shares} " + f"@ {fill.execution.price}") + + if self.db: + # Save trade to database + asyncio.create_task(self._save_trade(trade, fill)) + + def _on_ticker_update(self, ticker): + """Handle ticker updates""" + # Update market data streams + if ticker.contract.symbol in self._market_data_streams: + self._market_data_streams[ticker.contract.symbol] = { + 'last': ticker.last, + 'bid': ticker.bid, + 'ask': ticker.ask, + 'volume': ticker.volume, + 'timestamp': datetime.now() + } + + # === Monitoring === + + async def _start_monitoring(self): + """Start monitoring tasks""" + self._heartbeat_task = asyncio.create_task(self._heartbeat_monitor()) + self._monitor_task = asyncio.create_task(self._connection_monitor()) + + async def _heartbeat_monitor(self): + """Monitor connection heartbeat""" + while self.health.state in [ConnectionState.CONNECTED, ConnectionState.RECONNECTING]: + try: + if self.ib and self.ib.isConnected(): + # Send heartbeat + start = time.time() + self.ib.reqCurrentTime() + self.health.latency_ms = (time.time() - start) * 1000 + self.health.last_heartbeat = datetime.now() + self.health.messages_received += 1 + else: + # Connection lost + if self.health.state == ConnectionState.CONNECTED: + logger.warning("Heartbeat failed - connection lost") + self._on_disconnected() + + await asyncio.sleep(30) # Check every 30 seconds + + except Exception as e: + logger.error(f"Heartbeat monitor error: {e}") + + async def _connection_monitor(self): + """Monitor overall connection health""" + while self.health.state != ConnectionState.CLOSED: + try: + # Check if heartbeat is stale + if self.health.last_heartbeat: + time_since_heartbeat = (datetime.now() - self.health.last_heartbeat).seconds + + if time_since_heartbeat > 120: # 2 minutes + logger.warning(f"No heartbeat for {time_since_heartbeat} seconds") + + if self.health.state == ConnectionState.CONNECTED: + # Attempt reconnection + if not self._reconnect_task or self._reconnect_task.done(): + self._reconnect_task = asyncio.create_task(self.reconnect()) + + # Log health metrics periodically + if self.health.messages_received % 100 == 0: + logger.info(f"Connection health: latency={self.health.latency_ms:.1f}ms, " + f"messages={self.health.messages_received}, " + f"orders={self.health.orders_placed}") + + await asyncio.sleep(60) # Check every minute + + except Exception as e: + logger.error(f"Connection monitor error: {e}") + + # === Enhanced Order Management === + + async def place_bracket_order(self, + ticker: str, + action: str, + quantity: int, + entry_price: float, + stop_loss: float, + take_profit: float, + idempotency_key: Optional[str] = None) -> Optional[Dict[str, Any]]: + """ + Place a bracket order (entry + stop loss + take profit) + + Args: + ticker: Stock symbol + action: 'BUY' or 'SELL' + quantity: Number of shares + entry_price: Entry limit price + stop_loss: Stop loss price + take_profit: Take profit price + idempotency_key: Unique key to prevent duplicate orders + + Returns: + Dictionary with order IDs if successful + """ + if not self._ensure_connected(): + return None + + try: + # Check idempotency + if idempotency_key and self.db: + existing = self.db.get_session().query(Order).filter_by( + idempotency_key=idempotency_key + ).first() + if existing: + logger.info(f"Order already exists: {idempotency_key}") + return {'parent_id': existing.order_id} + + # Create contract + contract = Stock(ticker, 'SMART', 'USD') + self.ib.qualifyContracts(contract) + + # Create bracket order + bracket = BracketOrder( + action=action, + quantity=quantity, + limitPrice=entry_price, + stopLossPrice=stop_loss, + takeProfitPrice=take_profit + ) + + # Place the bracket order + trades = [] + for order in bracket: + trade = self.ib.placeOrder(contract, order) + trades.append(trade) + await asyncio.sleep(0.1) + + # Wait for orders to be acknowledged + await asyncio.sleep(1) + + # Verify orders + parent_trade = trades[0] + if parent_trade.orderStatus.status in ['Submitted', 'PreSubmitted']: + logger.info(f"Bracket order placed for {ticker}") + + # Save to database + if self.db: + await self._save_bracket_order( + trades, ticker, action, quantity, + entry_price, stop_loss, take_profit, idempotency_key + ) + + self.health.orders_placed += 1 + + return { + 'parent_id': parent_trade.order.orderId, + 'stop_loss_id': trades[1].order.orderId if len(trades) > 1 else None, + 'take_profit_id': trades[2].order.orderId if len(trades) > 2 else None + } + else: + logger.error(f"Bracket order failed: {parent_trade.orderStatus.status}") + self.health.orders_failed += 1 + return None + + except Exception as e: + logger.error(f"Error placing bracket order: {e}") + self.health.orders_failed += 1 + self.health.errors.append({ + 'timestamp': datetime.now(), + 'error': str(e), + 'type': 'order_placement' + }) + return None + + async def cancel_order(self, order_id: int) -> bool: + """ + Cancel an order + + Args: + order_id: IBKR order ID + + Returns: + True if cancelled successfully + """ + if not self._ensure_connected(): + return False + + try: + # Find the order + for trade in self.ib.openTrades(): + if trade.order.orderId == order_id: + self.ib.cancelOrder(trade.order) + await asyncio.sleep(1) + + # Update database + if self.db: + self.db.update_order_status( + str(order_id), + OrderStatus.CANCELLED, + cancelled_at=datetime.now() + ) + + logger.info(f"Order {order_id} cancelled") + return True + + logger.warning(f"Order {order_id} not found") + return False + + except Exception as e: + logger.error(f"Error cancelling order: {e}") + return False + + async def modify_order(self, order_id: int, new_price: float) -> bool: + """ + Modify an existing order + + Args: + order_id: IBKR order ID + new_price: New limit price + + Returns: + True if modified successfully + """ + if not self._ensure_connected(): + return False + + try: + # Find and modify the order + for trade in self.ib.openTrades(): + if trade.order.orderId == order_id: + trade.order.lmtPrice = new_price + self.ib.placeOrder(trade.contract, trade.order) + await asyncio.sleep(1) + + logger.info(f"Order {order_id} modified to {new_price}") + return True + + logger.warning(f"Order {order_id} not found") + return False + + except Exception as e: + logger.error(f"Error modifying order: {e}") + return False + + # === Real-time Data === + + async def subscribe_market_data(self, ticker: str) -> bool: + """ + Subscribe to real-time market data + + Args: + ticker: Stock symbol + + Returns: + True if subscribed successfully + """ + if not self._ensure_connected(): + return False + + try: + contract = Stock(ticker, 'SMART', 'USD') + self.ib.qualifyContracts(contract) + + # Request market data + ticker_obj = self.ib.reqMktData( + contract, + genericTickList='', + snapshot=False, + regulatorySnapshot=False + ) + + self._market_data_streams[ticker] = ticker_obj + + logger.info(f"Subscribed to market data for {ticker}") + return True + + except Exception as e: + logger.error(f"Error subscribing to market data: {e}") + return False + + def get_market_data(self, ticker: str) -> Optional[Dict[str, Any]]: + """ + Get current market data for a ticker + + Args: + ticker: Stock symbol + + Returns: + Market data dictionary + """ + if ticker in self._market_data_streams: + ticker_obj = self._market_data_streams[ticker] + return { + 'ticker': ticker, + 'last': ticker_obj.last, + 'bid': ticker_obj.bid, + 'ask': ticker_obj.ask, + 'volume': ticker_obj.volume, + 'high': ticker_obj.high, + 'low': ticker_obj.low, + 'close': ticker_obj.close, + 'timestamp': datetime.now() + } + return None + + # === Helper Methods === + + def _ensure_connected(self) -> bool: + """Ensure connection is active""" + if not self.ib or not self.ib.isConnected(): + logger.error("Not connected to IBKR") + return False + return True + + async def _save_bracket_order(self, trades, ticker, action, quantity, + entry_price, stop_loss, take_profit, + idempotency_key): + """Save bracket order to database""" + if not self.db: + return + + try: + # Save parent order + parent_trade = trades[0] + parent_order = self.db.save_order({ + 'order_id': str(parent_trade.order.orderId), + 'idempotency_key': idempotency_key, + 'ticker': ticker, + 'action': action, + 'order_type': 'LIMIT', + 'quantity': quantity, + 'limit_price': entry_price, + 'stop_loss_price': stop_loss, + 'take_profit_price': take_profit, + 'status': OrderStatus.SUBMITTED, + 'submitted_at': datetime.now() + }) + + # Save child orders + if len(trades) > 1: + # Stop loss order + stop_trade = trades[1] + self.db.save_order({ + 'order_id': str(stop_trade.order.orderId), + 'ticker': ticker, + 'action': 'SELL' if action == 'BUY' else 'BUY', + 'order_type': 'STOP', + 'quantity': quantity, + 'stop_price': stop_loss, + 'parent_order_id': parent_order.id, + 'status': OrderStatus.PENDING + }) + + if len(trades) > 2: + # Take profit order + profit_trade = trades[2] + self.db.save_order({ + 'order_id': str(profit_trade.order.orderId), + 'ticker': ticker, + 'action': 'SELL' if action == 'BUY' else 'BUY', + 'order_type': 'LIMIT', + 'quantity': quantity, + 'limit_price': take_profit, + 'parent_order_id': parent_order.id, + 'status': OrderStatus.PENDING + }) + + except Exception as e: + logger.error(f"Error saving bracket order: {e}") + + async def _update_order_status(self, trade): + """Update order status in database""" + if not self.db: + return + + try: + status_map = { + 'Submitted': OrderStatus.SUBMITTED, + 'Filled': OrderStatus.FILLED, + 'PartiallyFilled': OrderStatus.PARTIALLY_FILLED, + 'Cancelled': OrderStatus.CANCELLED, + 'Inactive': OrderStatus.REJECTED + } + + db_status = status_map.get(trade.orderStatus.status, OrderStatus.PENDING) + + self.db.update_order_status( + order_id=str(trade.order.orderId), + status=db_status, + filled_quantity=trade.orderStatus.filled, + avg_fill_price=trade.orderStatus.avgFillPrice + ) + + except Exception as e: + logger.error(f"Error updating order status: {e}") + + async def _save_trade(self, trade, fill): + """Save executed trade to database""" + if not self.db: + return + + try: + with self.db.get_session() as session: + from ..core.database import Trade + + trade_record = Trade( + order_id=trade.order.orderId, + ticker=fill.contract.symbol, + action=fill.execution.side, + quantity=fill.execution.shares, + price=fill.execution.price, + commission=fill.commissionReport.commission if fill.commissionReport else 0, + executed_at=datetime.now() + ) + + session.add(trade_record) + session.commit() + + except Exception as e: + logger.error(f"Error saving trade: {e}") + + def get_health_metrics(self) -> Dict[str, Any]: + """Get connection health metrics""" + return { + 'state': self.health.state.value, + 'connected': self.health.state == ConnectionState.CONNECTED, + 'last_heartbeat': self.health.last_heartbeat.isoformat() if self.health.last_heartbeat else None, + 'latency_ms': self.health.latency_ms, + 'reconnect_attempts': self.health.reconnect_attempts, + 'total_reconnects': self.health.total_reconnects, + 'orders_placed': self.health.orders_placed, + 'orders_failed': self.health.orders_failed, + 'recent_errors': list(self.health.errors)[-5:] if self.health.errors else [] + } + + +# Example usage +async def main(): + """Example of using the resilient IBKR connector""" + from ..core.database import DatabaseManager + + # Initialize database + db = DatabaseManager("postgresql://trader:password@localhost/trading_db") + + # Create connector + connector = ResilientIBKRConnector( + host="127.0.0.1", + port=7497, # Paper trading + db_manager=db + ) + + # Set callbacks + async def on_connected(): + logger.info("Connected callback triggered") + + async def on_error(code, message): + logger.error(f"Error callback: {code} - {message}") + + connector.on_connected_callback = on_connected + connector.on_error_callback = on_error + + # Connect + if await connector.connect(): + # Subscribe to market data + await connector.subscribe_market_data("AAPL") + + # Place a bracket order + result = await connector.place_bracket_order( + ticker="AAPL", + action="BUY", + quantity=100, + entry_price=150.00, + stop_loss=145.00, + take_profit=160.00 + ) + + if result: + logger.info(f"Bracket order placed: {result}") + + # Monitor for 5 minutes + for _ in range(10): + await asyncio.sleep(30) + health = connector.get_health_metrics() + logger.info(f"Health: {health}") + + # Disconnect + await connector.disconnect() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + asyncio.run(main()) \ No newline at end of file diff --git a/autonomous/core/database.py b/autonomous/core/database.py new file mode 100644 index 00000000..b1b0cd17 --- /dev/null +++ b/autonomous/core/database.py @@ -0,0 +1,473 @@ +""" +Database Models and Persistence Layer +===================================== + +SQLAlchemy models for persistent storage of trading data. +Uses PostgreSQL with TimescaleDB extension for time-series optimization. +""" + +from datetime import datetime +from decimal import Decimal +from enum import Enum +from typing import Optional, Dict, Any +import uuid + +from sqlalchemy import ( + Column, String, Integer, Float, Decimal as SQLDecimal, + DateTime, Boolean, JSON, Text, ForeignKey, Index, + UniqueConstraint, CheckConstraint, create_engine +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship, sessionmaker, Session +from sqlalchemy.dialects.postgresql import UUID, JSONB +from sqlalchemy.sql import func + +Base = declarative_base() + + +class OrderStatus(str, Enum): + """Order status enumeration""" + PENDING = "pending" + SUBMITTED = "submitted" + PARTIALLY_FILLED = "partially_filled" + FILLED = "filled" + CANCELLED = "cancelled" + REJECTED = "rejected" + FAILED = "failed" + + +class SignalType(str, Enum): + """Signal type enumeration""" + CONGRESSIONAL = "congressional" + INSIDER = "insider" + TECHNICAL = "technical" + FUNDAMENTAL = "fundamental" + SENTIMENT = "sentiment" + AI_GENERATED = "ai_generated" + + +class AlertPriority(str, Enum): + """Alert priority levels""" + CRITICAL = "critical" + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + INFO = "info" + + +# === Portfolio Models === + +class Position(Base): + """Current portfolio positions""" + __tablename__ = "positions" + __table_args__ = ( + Index('idx_position_ticker', 'ticker'), + Index('idx_position_updated', 'last_updated'), + ) + + id = Column(Integer, primary_key=True) + ticker = Column(String(10), nullable=False, unique=True) + shares = Column(Integer, nullable=False) + avg_cost = Column(SQLDecimal(12, 4), nullable=False) + current_price = Column(SQLDecimal(12, 4)) + market_value = Column(SQLDecimal(12, 2)) + unrealized_pnl = Column(SQLDecimal(12, 2)) + realized_pnl = Column(SQLDecimal(12, 2)) + percent_change = Column(Float) + last_updated = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) + + # Relationships + orders = relationship("Order", back_populates="position") + snapshots = relationship("PortfolioSnapshot", back_populates="position") + + +class PortfolioSnapshot(Base): + """Historical portfolio snapshots""" + __tablename__ = "portfolio_snapshots" + __table_args__ = ( + Index('idx_snapshot_timestamp', 'timestamp'), + Index('idx_snapshot_ticker_time', 'ticker', 'timestamp'), + # TimescaleDB hypertable on timestamp + {'timescaledb_hypertable': {'time_column': 'timestamp'}} + ) + + id = Column(Integer, primary_key=True) + timestamp = Column(DateTime(timezone=True), nullable=False, default=func.now()) + position_id = Column(Integer, ForeignKey('positions.id')) + ticker = Column(String(10), nullable=False) + shares = Column(Integer, nullable=False) + price = Column(SQLDecimal(12, 4), nullable=False) + value = Column(SQLDecimal(12, 2), nullable=False) + daily_pnl = Column(SQLDecimal(12, 2)) + total_pnl = Column(SQLDecimal(12, 2)) + + # Account level metrics + total_value = Column(SQLDecimal(15, 2)) + cash_balance = Column(SQLDecimal(15, 2)) + buying_power = Column(SQLDecimal(15, 2)) + + position = relationship("Position", back_populates="snapshots") + + +# === Trading Models === + +class Order(Base): + """Order tracking with full lifecycle""" + __tablename__ = "orders" + __table_args__ = ( + Index('idx_order_ticker', 'ticker'), + Index('idx_order_status', 'status'), + Index('idx_order_created', 'created_at'), + UniqueConstraint('idempotency_key', name='uq_order_idempotency'), + ) + + id = Column(Integer, primary_key=True) + order_id = Column(String(50), unique=True) # IBKR order ID + idempotency_key = Column(UUID(as_uuid=True), default=uuid.uuid4) + + # Order details + ticker = Column(String(10), nullable=False) + position_id = Column(Integer, ForeignKey('positions.id')) + action = Column(String(10), nullable=False) # BUY/SELL + order_type = Column(String(20), nullable=False) # MARKET/LIMIT/STOP + quantity = Column(Integer, nullable=False) + limit_price = Column(SQLDecimal(12, 4)) + stop_price = Column(SQLDecimal(12, 4)) + + # Status tracking + status = Column(String(20), nullable=False, default=OrderStatus.PENDING) + filled_quantity = Column(Integer, default=0) + avg_fill_price = Column(SQLDecimal(12, 4)) + commission = Column(SQLDecimal(8, 2)) + + # Risk management + stop_loss_price = Column(SQLDecimal(12, 4)) + take_profit_price = Column(SQLDecimal(12, 4)) + parent_order_id = Column(Integer, ForeignKey('orders.id')) + + # Metadata + signal_id = Column(Integer, ForeignKey('signals.id')) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + submitted_at = Column(DateTime(timezone=True)) + filled_at = Column(DateTime(timezone=True)) + cancelled_at = Column(DateTime(timezone=True)) + notes = Column(Text) + + # Relationships + position = relationship("Position", back_populates="orders") + signal = relationship("Signal", back_populates="orders") + child_orders = relationship("Order", backref='parent_order') + + +class Trade(Base): + """Executed trades (fills)""" + __tablename__ = "trades" + __table_args__ = ( + Index('idx_trade_ticker', 'ticker'), + Index('idx_trade_executed', 'executed_at'), + ) + + id = Column(Integer, primary_key=True) + order_id = Column(Integer, ForeignKey('orders.id'), nullable=False) + ticker = Column(String(10), nullable=False) + action = Column(String(10), nullable=False) + quantity = Column(Integer, nullable=False) + price = Column(SQLDecimal(12, 4), nullable=False) + commission = Column(SQLDecimal(8, 2)) + executed_at = Column(DateTime(timezone=True), nullable=False, default=func.now()) + pnl = Column(SQLDecimal(12, 2)) + + +# === Signal Models === + +class Signal(Base): + """Trading signals generated by the system""" + __tablename__ = "signals" + __table_args__ = ( + Index('idx_signal_ticker', 'ticker'), + Index('idx_signal_type', 'signal_type'), + Index('idx_signal_created', 'created_at'), + Index('idx_signal_confidence', 'confidence'), + ) + + id = Column(Integer, primary_key=True) + ticker = Column(String(10), nullable=False) + signal_type = Column(String(20), nullable=False) + action = Column(String(10), nullable=False) # BUY/SELL/HOLD + confidence = Column(Float, nullable=False) + + # Price targets + current_price = Column(SQLDecimal(12, 4)) + entry_price_min = Column(SQLDecimal(12, 4)) + entry_price_max = Column(SQLDecimal(12, 4)) + target_price_1 = Column(SQLDecimal(12, 4)) + target_price_2 = Column(SQLDecimal(12, 4)) + stop_loss = Column(SQLDecimal(12, 4)) + + # Sizing + position_size = Column(Float) # Percentage of portfolio + risk_level = Column(String(10)) # LOW/MEDIUM/HIGH + + # Metadata + reasoning = Column(Text) + data_sources = Column(JSONB) # JSON array of sources + created_at = Column(DateTime(timezone=True), server_default=func.now()) + expires_at = Column(DateTime(timezone=True)) + + # Tracking + acted_upon = Column(Boolean, default=False) + acted_at = Column(DateTime(timezone=True)) + performance = Column(JSONB) # Track signal accuracy + + # Relationships + orders = relationship("Order", back_populates="signal") + + +# === Alternative Data Models === + +class CongressionalTrade(Base): + """Congressional trading activity""" + __tablename__ = "congressional_trades" + __table_args__ = ( + Index('idx_congress_ticker', 'ticker'), + Index('idx_congress_politician', 'politician'), + Index('idx_congress_filed', 'filing_date'), + UniqueConstraint('politician', 'ticker', 'transaction_date', + name='uq_congress_trade'), + ) + + id = Column(Integer, primary_key=True) + politician = Column(String(100), nullable=False) + ticker = Column(String(10), nullable=False) + action = Column(String(20), nullable=False) # purchase/sale + amount_range = Column(String(50)) + transaction_date = Column(DateTime(timezone=True), nullable=False) + filing_date = Column(DateTime(timezone=True), nullable=False) + party = Column(String(20)) + state = Column(String(5)) + chamber = Column(String(10)) # house/senate + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + +class InsiderTrade(Base): + """Insider trading activity""" + __tablename__ = "insider_trades" + __table_args__ = ( + Index('idx_insider_ticker', 'ticker'), + Index('idx_insider_name', 'insider_name'), + Index('idx_insider_date', 'transaction_date'), + ) + + id = Column(Integer, primary_key=True) + insider_name = Column(String(100), nullable=False) + ticker = Column(String(10), nullable=False) + action = Column(String(10), nullable=False) # Buy/Sell + shares = Column(Integer) + value = Column(SQLDecimal(15, 2)) + transaction_date = Column(DateTime(timezone=True), nullable=False) + position = Column(String(50)) # CEO/CFO/Director + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + +# === Alert Models === + +class Alert(Base): + """Alert history and tracking""" + __tablename__ = "alerts" + __table_args__ = ( + Index('idx_alert_type', 'alert_type'), + Index('idx_alert_priority', 'priority'), + Index('idx_alert_created', 'created_at'), + ) + + id = Column(Integer, primary_key=True) + title = Column(String(200), nullable=False) + message = Column(Text, nullable=False) + alert_type = Column(String(30), nullable=False) + priority = Column(String(10), nullable=False) + + # Delivery tracking + channels = Column(JSONB) # ['discord', 'telegram', 'email'] + sent_successfully = Column(Boolean, default=False) + send_attempts = Column(Integer, default=0) + error_message = Column(Text) + + # Metadata + data = Column(JSONB) + ticker = Column(String(10)) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + sent_at = Column(DateTime(timezone=True)) + + # Deduplication + hash = Column(String(64), index=True) + + +# === Performance Models === + +class PerformanceMetric(Base): + """Strategy performance tracking""" + __tablename__ = "performance_metrics" + __table_args__ = ( + Index('idx_perf_date', 'date'), + {'timescaledb_hypertable': {'time_column': 'date'}} + ) + + id = Column(Integer, primary_key=True) + date = Column(DateTime(timezone=True), nullable=False, default=func.now()) + + # Daily metrics + total_pnl = Column(SQLDecimal(12, 2)) + realized_pnl = Column(SQLDecimal(12, 2)) + unrealized_pnl = Column(SQLDecimal(12, 2)) + win_rate = Column(Float) + sharpe_ratio = Column(Float) + max_drawdown = Column(Float) + + # Trade statistics + total_trades = Column(Integer) + winning_trades = Column(Integer) + losing_trades = Column(Integer) + avg_win = Column(SQLDecimal(10, 2)) + avg_loss = Column(SQLDecimal(10, 2)) + + # Signal statistics + signals_generated = Column(Integer) + signals_acted = Column(Integer) + signal_accuracy = Column(Float) + + # Risk metrics + portfolio_beta = Column(Float) + portfolio_volatility = Column(Float) + value_at_risk = Column(SQLDecimal(12, 2)) + + +# === Database Manager === + +class DatabaseManager: + """Database connection and session management""" + + def __init__(self, connection_string: str): + """ + Initialize database manager + + Args: + connection_string: PostgreSQL connection string + """ + self.engine = create_engine( + connection_string, + pool_size=20, + max_overflow=40, + pool_pre_ping=True, # Check connections before using + echo=False + ) + self.SessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=self.engine + ) + + def init_database(self): + """Create all tables and indexes""" + Base.metadata.create_all(bind=self.engine) + + # Create TimescaleDB hypertables + with self.engine.connect() as conn: + # Convert tables to hypertables for time-series optimization + conn.execute(""" + SELECT create_hypertable('portfolio_snapshots', 'timestamp', + if_not_exists => TRUE); + """) + conn.execute(""" + SELECT create_hypertable('performance_metrics', 'date', + if_not_exists => TRUE); + """) + + def get_session(self) -> Session: + """Get a new database session""" + return self.SessionLocal() + + def save_position(self, position_data: Dict[str, Any]) -> Position: + """Save or update a position""" + with self.get_session() as session: + position = session.query(Position).filter_by( + ticker=position_data['ticker'] + ).first() + + if position: + # Update existing + for key, value in position_data.items(): + setattr(position, key, value) + else: + # Create new + position = Position(**position_data) + session.add(position) + + session.commit() + session.refresh(position) + return position + + def save_signal(self, signal_data: Dict[str, Any]) -> Signal: + """Save a new signal""" + with self.get_session() as session: + signal = Signal(**signal_data) + session.add(signal) + session.commit() + session.refresh(signal) + return signal + + def save_order(self, order_data: Dict[str, Any]) -> Order: + """Save a new order with idempotency""" + with self.get_session() as session: + # Check idempotency + if 'idempotency_key' in order_data: + existing = session.query(Order).filter_by( + idempotency_key=order_data['idempotency_key'] + ).first() + if existing: + return existing + + order = Order(**order_data) + session.add(order) + session.commit() + session.refresh(order) + return order + + def update_order_status(self, order_id: str, status: OrderStatus, + **kwargs) -> Optional[Order]: + """Update order status""" + with self.get_session() as session: + order = session.query(Order).filter_by(order_id=order_id).first() + if order: + order.status = status + for key, value in kwargs.items(): + setattr(order, key, value) + session.commit() + session.refresh(order) + return order + + def get_active_positions(self) -> list[Position]: + """Get all active positions""" + with self.get_session() as session: + return session.query(Position).filter( + Position.shares > 0 + ).all() + + def get_recent_signals(self, ticker: Optional[str] = None, + hours: int = 24) -> list[Signal]: + """Get recent signals""" + with self.get_session() as session: + query = session.query(Signal).filter( + Signal.created_at >= func.now() - timedelta(hours=hours) + ) + if ticker: + query = query.filter(Signal.ticker == ticker) + return query.order_by(Signal.confidence.desc()).all() + + +# Example usage +if __name__ == "__main__": + # Initialize database + db = DatabaseManager("postgresql://trader:password@localhost/trading_db") + db.init_database() + + print("✅ Database initialized successfully") \ No newline at end of file diff --git a/autonomous/core/order_manager.py b/autonomous/core/order_manager.py new file mode 100644 index 00000000..259ed098 --- /dev/null +++ b/autonomous/core/order_manager.py @@ -0,0 +1,703 @@ +""" +Order Management System +======================= + +Comprehensive order lifecycle management with state machine, +validation, and execution tracking. +""" + +import asyncio +import logging +import uuid +from typing import Dict, List, Optional, Any, Tuple +from datetime import datetime, timedelta +from decimal import Decimal +from enum import Enum +from dataclasses import dataclass, field +import json + +from transitions import Machine +from pydantic import BaseModel, Field, validator + +from .database import ( + DatabaseManager, Order, OrderStatus, Trade, + Signal, Position +) + +logger = logging.getLogger(__name__) + + +class OrderType(str, Enum): + """Order type enumeration""" + MARKET = "MARKET" + LIMIT = "LIMIT" + STOP = "STOP" + STOP_LIMIT = "STOP_LIMIT" + TRAILING_STOP = "TRAILING_STOP" + BRACKET = "BRACKET" + + +class OrderSide(str, Enum): + """Order side enumeration""" + BUY = "BUY" + SELL = "SELL" + + +class OrderValidationError(Exception): + """Order validation error""" + pass + + +class OrderExecutionError(Exception): + """Order execution error""" + pass + + +# === Pydantic Models for Validation === + +class OrderRequest(BaseModel): + """Order request with validation""" + ticker: str = Field(..., min_length=1, max_length=10) + side: OrderSide + quantity: int = Field(..., gt=0, le=100000) + order_type: OrderType + limit_price: Optional[Decimal] = Field(None, gt=0, le=1000000) + stop_price: Optional[Decimal] = Field(None, gt=0, le=1000000) + time_in_force: str = Field(default="DAY") # DAY, GTC, IOC, FOK + idempotency_key: Optional[str] = None + signal_id: Optional[int] = None + notes: Optional[str] = None + + # Risk management + stop_loss: Optional[Decimal] = Field(None, gt=0) + take_profit: Optional[Decimal] = Field(None, gt=0) + max_slippage: Optional[Decimal] = Field(default=Decimal("0.01")) # 1% + + @validator('ticker') + def validate_ticker(cls, v): + """Validate ticker symbol""" + # Basic validation - alphanumeric only + if not v.isalnum(): + raise ValueError(f"Invalid ticker symbol: {v}") + return v.upper() + + @validator('limit_price') + def validate_limit_price(cls, v, values): + """Validate limit price for limit orders""" + if values.get('order_type') in [OrderType.LIMIT, OrderType.STOP_LIMIT]: + if v is None: + raise ValueError("Limit price required for limit orders") + return v + + @validator('stop_price') + def validate_stop_price(cls, v, values): + """Validate stop price for stop orders""" + if values.get('order_type') in [OrderType.STOP, OrderType.STOP_LIMIT, + OrderType.TRAILING_STOP]: + if v is None: + raise ValueError("Stop price required for stop orders") + return v + + class Config: + use_enum_values = True + + +@dataclass +class OrderContext: + """Context for order execution""" + request: OrderRequest + order_id: Optional[str] = None + ibkr_order_id: Optional[int] = None + db_order: Optional[Order] = None + position: Optional[Position] = None + signal: Optional[Signal] = None + validation_errors: List[str] = field(default_factory=list) + execution_errors: List[str] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + created_at: datetime = field(default_factory=datetime.now) + + +# === Order State Machine === + +class OrderStateMachine: + """ + State machine for order lifecycle management + + States: + - pending: Initial state + - validated: Order validated + - risk_checked: Risk checks passed + - submitted: Sent to broker + - acknowledged: Broker acknowledged + - partially_filled: Partially executed + - filled: Fully executed + - cancelled: Cancelled + - rejected: Rejected by broker or risk + - failed: System failure + """ + + # State transitions + states = [ + 'pending', 'validated', 'risk_checked', 'submitted', + 'acknowledged', 'partially_filled', 'filled', + 'cancelled', 'rejected', 'failed' + ] + + # Valid transitions + transitions = [ + # Forward flow + {'trigger': 'validate', 'source': 'pending', 'dest': 'validated'}, + {'trigger': 'check_risk', 'source': 'validated', 'dest': 'risk_checked'}, + {'trigger': 'submit', 'source': 'risk_checked', 'dest': 'submitted'}, + {'trigger': 'acknowledge', 'source': 'submitted', 'dest': 'acknowledged'}, + {'trigger': 'partial_fill', 'source': ['acknowledged', 'partially_filled'], + 'dest': 'partially_filled'}, + {'trigger': 'fill', 'source': ['acknowledged', 'partially_filled'], + 'dest': 'filled'}, + + # Cancellation + {'trigger': 'cancel', 'source': ['pending', 'validated', 'risk_checked', + 'submitted', 'acknowledged', 'partially_filled'], + 'dest': 'cancelled'}, + + # Rejection + {'trigger': 'reject', 'source': ['validated', 'risk_checked', 'submitted'], + 'dest': 'rejected'}, + + # Failure + {'trigger': 'fail', 'source': '*', 'dest': 'failed'}, + ] + + def __init__(self, context: OrderContext): + """Initialize state machine""" + self.context = context + self.machine = Machine( + model=self, + states=OrderStateMachine.states, + transitions=OrderStateMachine.transitions, + initial='pending', + auto_transitions=False, + send_event=True, + after_state_change=self._on_state_change + ) + + def _on_state_change(self, event): + """Log state changes""" + logger.info(f"Order {self.context.order_id}: {event.transition.source} " + f"-> {event.transition.dest}") + + +# === Order Manager === + +class OrderManager: + """ + Manages order lifecycle from creation to execution + """ + + def __init__(self, + db_manager: DatabaseManager, + ibkr_connector, + risk_manager=None): + """ + Initialize order manager + + Args: + db_manager: Database manager + ibkr_connector: IBKR connector instance + risk_manager: Risk manager instance + """ + self.db = db_manager + self.ibkr = ibkr_connector + self.risk_manager = risk_manager + + # Track active orders + self.active_orders: Dict[str, OrderStateMachine] = {} + + # Execution metrics + self.metrics = { + 'orders_created': 0, + 'orders_submitted': 0, + 'orders_filled': 0, + 'orders_cancelled': 0, + 'orders_rejected': 0, + 'orders_failed': 0, + 'total_volume': 0, + 'total_commission': Decimal('0.00') + } + + async def create_order(self, request: OrderRequest) -> Tuple[bool, OrderContext]: + """ + Create and process a new order + + Args: + request: Order request + + Returns: + Tuple of (success, order context) + """ + try: + # Create order context + context = OrderContext( + request=request, + order_id=str(uuid.uuid4()) + ) + + # Create state machine + state_machine = OrderStateMachine(context) + self.active_orders[context.order_id] = state_machine + + self.metrics['orders_created'] += 1 + + # Process through state machine + success = await self._process_order(state_machine) + + return success, context + + except Exception as e: + logger.error(f"Error creating order: {e}") + return False, None + + async def _process_order(self, state_machine: OrderStateMachine) -> bool: + """ + Process order through state machine + + Args: + state_machine: Order state machine + + Returns: + True if order successfully submitted + """ + context = state_machine.context + + try: + # Step 1: Validate + if not await self._validate_order(state_machine): + state_machine.reject() + return False + + state_machine.validate() + + # Step 2: Risk check + if not await self._check_risk(state_machine): + state_machine.reject() + return False + + state_machine.check_risk() + + # Step 3: Submit to broker + if not await self._submit_order(state_machine): + state_machine.fail() + return False + + state_machine.submit() + + # Step 4: Wait for acknowledgment + if not await self._wait_for_acknowledgment(state_machine): + state_machine.fail() + return False + + state_machine.acknowledge() + + return True + + except Exception as e: + logger.error(f"Order processing error: {e}") + state_machine.fail() + await self._save_order_state(state_machine) + return False + + async def _validate_order(self, state_machine: OrderStateMachine) -> bool: + """ + Validate order request + + Args: + state_machine: Order state machine + + Returns: + True if valid + """ + context = state_machine.context + request = context.request + + # Check idempotency + if request.idempotency_key: + existing = await self._check_idempotency(request.idempotency_key) + if existing: + context.validation_errors.append("Duplicate order") + return False + + # Market hours check + if not self._is_market_open(): + if request.order_type == OrderType.MARKET: + context.validation_errors.append("Market closed for market orders") + return False + + # Position validation + if request.side == OrderSide.SELL: + position = await self._get_position(request.ticker) + if not position or position.shares < request.quantity: + context.validation_errors.append("Insufficient shares to sell") + return False + context.position = position + + # Price validation + market_price = await self._get_market_price(request.ticker) + if market_price: + # Check for unreasonable prices + if request.limit_price: + price_diff = abs(float(request.limit_price) - market_price) / market_price + if price_diff > 0.10: # More than 10% away + context.validation_errors.append( + f"Limit price {request.limit_price} is >10% from market {market_price}" + ) + + # Check trading halts + if await self._is_halted(request.ticker): + context.validation_errors.append(f"{request.ticker} is halted") + return False + + return len(context.validation_errors) == 0 + + async def _check_risk(self, state_machine: OrderStateMachine) -> bool: + """ + Perform risk checks + + Args: + state_machine: Order state machine + + Returns: + True if risk checks pass + """ + if not self.risk_manager: + return True + + context = state_machine.context + request = context.request + + try: + # Check with risk manager + risk_result = await self.risk_manager.check_order( + ticker=request.ticker, + side=request.side.value, + quantity=request.quantity, + price=float(request.limit_price or 0) + ) + + if not risk_result['approved']: + context.validation_errors.extend(risk_result.get('reasons', [])) + return False + + # Add risk metadata + context.metadata['risk_score'] = risk_result.get('risk_score') + context.metadata['position_impact'] = risk_result.get('position_impact') + + return True + + except Exception as e: + logger.error(f"Risk check error: {e}") + context.validation_errors.append(f"Risk check failed: {e}") + return False + + async def _submit_order(self, state_machine: OrderStateMachine) -> bool: + """ + Submit order to broker + + Args: + state_machine: Order state machine + + Returns: + True if submitted successfully + """ + context = state_machine.context + request = context.request + + try: + # Prepare for bracket order if stop loss/take profit specified + if request.stop_loss and request.take_profit: + result = await self.ibkr.place_bracket_order( + ticker=request.ticker, + action=request.side.value, + quantity=request.quantity, + entry_price=float(request.limit_price), + stop_loss=float(request.stop_loss), + take_profit=float(request.take_profit), + idempotency_key=request.idempotency_key + ) + + if result: + context.ibkr_order_id = result['parent_id'] + context.metadata['bracket_order'] = result + + else: + # Regular order + order_result = await self.ibkr.place_order( + ticker=request.ticker, + action=request.side.value, + quantity=request.quantity, + order_type=request.order_type.value, + limit_price=float(request.limit_price) if request.limit_price else None, + stop_price=float(request.stop_price) if request.stop_price else None + ) + + if order_result: + context.ibkr_order_id = order_result + + if context.ibkr_order_id: + # Save to database + await self._save_order_to_db(state_machine) + self.metrics['orders_submitted'] += 1 + return True + + context.execution_errors.append("Failed to submit order to broker") + return False + + except Exception as e: + logger.error(f"Order submission error: {e}") + context.execution_errors.append(str(e)) + return False + + async def _wait_for_acknowledgment(self, state_machine: OrderStateMachine, + timeout: int = 5) -> bool: + """ + Wait for broker acknowledgment + + Args: + state_machine: Order state machine + timeout: Timeout in seconds + + Returns: + True if acknowledged + """ + context = state_machine.context + start_time = datetime.now() + + while (datetime.now() - start_time).seconds < timeout: + # Check order status with broker + if context.ibkr_order_id: + # In real implementation, would check actual order status + # For now, assume acknowledged + return True + + await asyncio.sleep(0.5) + + context.execution_errors.append("Acknowledgment timeout") + return False + + async def update_order_status(self, order_id: str, + new_status: str, + **kwargs): + """ + Update order status from broker events + + Args: + order_id: Order ID + new_status: New status + **kwargs: Additional status info + """ + if order_id not in self.active_orders: + return + + state_machine = self.active_orders[order_id] + context = state_machine.context + + try: + # Update state machine + if new_status == 'FILLED': + state_machine.fill() + self.metrics['orders_filled'] += 1 + self.metrics['total_volume'] += context.request.quantity + + elif new_status == 'PARTIALLY_FILLED': + state_machine.partial_fill() + context.metadata['filled_quantity'] = kwargs.get('filled_quantity', 0) + + elif new_status == 'CANCELLED': + state_machine.cancel() + self.metrics['orders_cancelled'] += 1 + + elif new_status == 'REJECTED': + state_machine.reject() + self.metrics['orders_rejected'] += 1 + + # Update database + if context.db_order: + self.db.update_order_status( + order_id=context.ibkr_order_id, + status=OrderStatus[new_status], + **kwargs + ) + + # Clean up if terminal state + if new_status in ['FILLED', 'CANCELLED', 'REJECTED', 'FAILED']: + del self.active_orders[order_id] + + except Exception as e: + logger.error(f"Error updating order status: {e}") + + async def cancel_order(self, order_id: str) -> bool: + """ + Cancel an order + + Args: + order_id: Order ID + + Returns: + True if cancelled successfully + """ + if order_id not in self.active_orders: + logger.warning(f"Order {order_id} not found") + return False + + state_machine = self.active_orders[order_id] + context = state_machine.context + + try: + # Cancel with broker + if context.ibkr_order_id: + success = await self.ibkr.cancel_order(context.ibkr_order_id) + if success: + state_machine.cancel() + await self._save_order_state(state_machine) + del self.active_orders[order_id] + return True + + return False + + except Exception as e: + logger.error(f"Error cancelling order: {e}") + return False + + # === Helper Methods === + + async def _check_idempotency(self, idempotency_key: str) -> Optional[Order]: + """Check for duplicate orders""" + with self.db.get_session() as session: + return session.query(Order).filter_by( + idempotency_key=idempotency_key + ).first() + + async def _get_position(self, ticker: str) -> Optional[Position]: + """Get current position for ticker""" + with self.db.get_session() as session: + return session.query(Position).filter_by(ticker=ticker).first() + + async def _get_market_price(self, ticker: str) -> Optional[float]: + """Get current market price""" + market_data = self.ibkr.get_market_data(ticker) + if market_data: + return market_data['last'] + return None + + async def _is_halted(self, ticker: str) -> bool: + """Check if ticker is halted""" + # Would check with market data provider + return False + + def _is_market_open(self) -> bool: + """Check if market is open""" + now = datetime.now() + # Simplified check - would use market calendar + return (9 <= now.hour < 16 and + now.weekday() < 5) # Mon-Fri + + async def _save_order_to_db(self, state_machine: OrderStateMachine): + """Save order to database""" + context = state_machine.context + request = context.request + + order_data = { + 'order_id': str(context.ibkr_order_id), + 'idempotency_key': request.idempotency_key, + 'ticker': request.ticker, + 'action': request.side.value, + 'order_type': request.order_type.value, + 'quantity': request.quantity, + 'limit_price': request.limit_price, + 'stop_price': request.stop_price, + 'stop_loss_price': request.stop_loss, + 'take_profit_price': request.take_profit, + 'status': OrderStatus.SUBMITTED, + 'signal_id': request.signal_id, + 'notes': request.notes, + 'submitted_at': datetime.now() + } + + context.db_order = self.db.save_order(order_data) + + async def _save_order_state(self, state_machine: OrderStateMachine): + """Save order state to database""" + context = state_machine.context + if context.db_order: + # Update order with final state + self.db.update_order_status( + order_id=str(context.ibkr_order_id), + status=OrderStatus[state_machine.state.upper()], + notes=json.dumps({ + 'validation_errors': context.validation_errors, + 'execution_errors': context.execution_errors, + 'metadata': context.metadata + }) + ) + + def get_active_orders(self) -> List[Dict[str, Any]]: + """Get all active orders""" + active = [] + for order_id, state_machine in self.active_orders.items(): + context = state_machine.context + active.append({ + 'order_id': order_id, + 'ticker': context.request.ticker, + 'side': context.request.side.value, + 'quantity': context.request.quantity, + 'state': state_machine.state, + 'created_at': context.created_at.isoformat() + }) + return active + + def get_metrics(self) -> Dict[str, Any]: + """Get order manager metrics""" + return { + **self.metrics, + 'active_orders': len(self.active_orders), + 'fill_rate': (self.metrics['orders_filled'] / + max(self.metrics['orders_submitted'], 1)) * 100 + } + + +# Example usage +async def main(): + """Example of using the order manager""" + from .database import DatabaseManager + from ..connectors.ibkr_resilient import ResilientIBKRConnector + + # Initialize components + db = DatabaseManager("postgresql://trader:password@localhost/trading_db") + ibkr = ResilientIBKRConnector(db_manager=db) + order_manager = OrderManager(db, ibkr) + + # Create an order + order_request = OrderRequest( + ticker="AAPL", + side=OrderSide.BUY, + quantity=100, + order_type=OrderType.LIMIT, + limit_price=Decimal("150.00"), + stop_loss=Decimal("145.00"), + take_profit=Decimal("160.00"), + idempotency_key=str(uuid.uuid4()) + ) + + success, context = await order_manager.create_order(order_request) + + if success: + logger.info(f"Order created: {context.order_id}") + else: + logger.error(f"Order failed: {context.validation_errors}") + + # Check metrics + print(order_manager.get_metrics()) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + asyncio.run(main()) \ No newline at end of file diff --git a/autonomous/core/risk_manager.py b/autonomous/core/risk_manager.py new file mode 100644 index 00000000..426b10f2 --- /dev/null +++ b/autonomous/core/risk_manager.py @@ -0,0 +1,769 @@ +""" +Risk Manager +============ + +Comprehensive risk management with enforcement of position limits, +loss limits, and portfolio risk metrics. +""" + +import logging +from typing import Dict, List, Optional, Any, Tuple +from datetime import datetime, timedelta, date +from decimal import Decimal +from enum import Enum +from dataclasses import dataclass, field +import numpy as np +from collections import defaultdict + +from pydantic import BaseModel, Field, validator + +from .database import ( + DatabaseManager, Position, Order, Trade, + OrderStatus, PerformanceMetric +) + +logger = logging.getLogger(__name__) + + +class RiskLevel(str, Enum): + """Risk level classification""" + LOW = "LOW" + MEDIUM = "MEDIUM" + HIGH = "HIGH" + CRITICAL = "CRITICAL" + + +class RiskViolationType(str, Enum): + """Types of risk violations""" + POSITION_SIZE = "position_size" + DAILY_LOSS = "daily_loss" + CONCENTRATION = "concentration" + CORRELATION = "correlation" + VOLATILITY = "volatility" + MARGIN = "margin" + PATTERN_DAY_TRADER = "pattern_day_trader" + + +@dataclass +class RiskMetrics: + """Portfolio risk metrics""" + total_exposure: Decimal = Decimal('0') + max_position_size: Decimal = Decimal('0') + concentration_risk: Decimal = Decimal('0') + portfolio_beta: float = 0.0 + portfolio_volatility: float = 0.0 + value_at_risk_95: Decimal = Decimal('0') + value_at_risk_99: Decimal = Decimal('0') + sharpe_ratio: float = 0.0 + sortino_ratio: float = 0.0 + max_drawdown: Decimal = Decimal('0') + current_drawdown: Decimal = Decimal('0') + daily_pnl: Decimal = Decimal('0') + realized_pnl: Decimal = Decimal('0') + unrealized_pnl: Decimal = Decimal('0') + margin_used: Decimal = Decimal('0') + margin_available: Decimal = Decimal('0') + correlation_risk: float = 0.0 + sector_concentration: Dict[str, float] = field(default_factory=dict) + + +@dataclass +class RiskLimits: + """Risk limits configuration""" + max_position_size: Decimal = Decimal('0.20') # 20% per position + max_daily_loss: Decimal = Decimal('0.05') # 5% daily loss + max_total_exposure: Decimal = Decimal('1.0') # 100% exposure + max_concentration: Decimal = Decimal('0.30') # 30% in single stock + max_sector_exposure: Decimal = Decimal('0.40') # 40% per sector + max_correlation: float = 0.70 # Max correlation between positions + max_volatility: float = 0.30 # 30% annualized volatility + min_sharpe_ratio: float = 0.5 # Minimum Sharpe ratio + max_drawdown: Decimal = Decimal('0.15') # 15% max drawdown + max_orders_per_day: int = 50 + max_trades_per_symbol_per_day: int = 4 # PDT rule + min_position_hold_time: int = 60 # Seconds + required_stop_loss: bool = True + max_leverage: Decimal = Decimal('2.0') # 2x leverage + + +class RiskCheckResult(BaseModel): + """Result of risk check""" + approved: bool + risk_score: float = Field(ge=0, le=100) + risk_level: RiskLevel + violations: List[RiskViolationType] = Field(default_factory=list) + reasons: List[str] = Field(default_factory=list) + position_impact: Dict[str, Any] = Field(default_factory=dict) + recommendations: List[str] = Field(default_factory=list) + + +class RiskManager: + """ + Comprehensive risk management system + """ + + def __init__(self, + db_manager: DatabaseManager, + limits: Optional[RiskLimits] = None, + enable_enforcement: bool = True): + """ + Initialize risk manager + + Args: + db_manager: Database manager + limits: Risk limits configuration + enable_enforcement: Whether to enforce limits + """ + self.db = db_manager + self.limits = limits or RiskLimits() + self.enable_enforcement = enable_enforcement + + # Cache for performance + self._position_cache: Dict[str, Position] = {} + self._metrics_cache: Optional[RiskMetrics] = None + self._cache_timestamp: Optional[datetime] = None + self._cache_ttl = 60 # Seconds + + # Track daily metrics + self._daily_trades: List[Trade] = [] + self._daily_orders: List[Order] = [] + self._starting_portfolio_value: Optional[Decimal] = None + + # Sector mapping (simplified) + self.sector_map = { + 'AAPL': 'Technology', 'MSFT': 'Technology', 'GOOGL': 'Technology', + 'NVDA': 'Technology', 'AVGO': 'Technology', 'TSM': 'Technology', + 'MU': 'Technology', 'META': 'Technology', + 'JPM': 'Financial', 'BAC': 'Financial', 'GS': 'Financial', + 'XOM': 'Energy', 'CVX': 'Energy', + 'JNJ': 'Healthcare', 'PFE': 'Healthcare', + # Add more mappings as needed + } + + async def check_order(self, + ticker: str, + side: str, + quantity: int, + price: float, + stop_loss: Optional[float] = None) -> RiskCheckResult: + """ + Check if an order passes risk management rules + + Args: + ticker: Stock symbol + side: 'BUY' or 'SELL' + quantity: Number of shares + price: Order price + stop_loss: Stop loss price + + Returns: + Risk check result + """ + violations = [] + reasons = [] + recommendations = [] + + # Get current metrics + metrics = await self.calculate_risk_metrics() + + # Calculate order value + order_value = Decimal(str(price * quantity)) + + # Get portfolio value + portfolio_value = await self._get_portfolio_value() + + if portfolio_value <= 0: + return RiskCheckResult( + approved=False, + risk_score=100, + risk_level=RiskLevel.CRITICAL, + reasons=["Invalid portfolio value"] + ) + + # === Check 1: Position Size === + position_pct = order_value / portfolio_value + + if position_pct > self.limits.max_position_size: + violations.append(RiskViolationType.POSITION_SIZE) + reasons.append( + f"Position size {position_pct:.1%} exceeds limit " + f"{self.limits.max_position_size:.1%}" + ) + recommendations.append( + f"Reduce quantity to {int(quantity * float(self.limits.max_position_size / position_pct))}" + ) + + # === Check 2: Daily Loss Limit === + if metrics.daily_pnl < 0: + daily_loss_pct = abs(metrics.daily_pnl / portfolio_value) + if daily_loss_pct >= self.limits.max_daily_loss: + violations.append(RiskViolationType.DAILY_LOSS) + reasons.append( + f"Daily loss {daily_loss_pct:.1%} at limit " + f"{self.limits.max_daily_loss:.1%}" + ) + recommendations.append("Stop trading for the day") + + # === Check 3: Concentration Risk === + existing_position = await self._get_position(ticker) + if existing_position and side == 'BUY': + new_position_value = (existing_position.market_value + + order_value) + concentration = new_position_value / portfolio_value + + if concentration > self.limits.max_concentration: + violations.append(RiskViolationType.CONCENTRATION) + reasons.append( + f"Concentration {concentration:.1%} exceeds limit " + f"{self.limits.max_concentration:.1%}" + ) + recommendations.append(f"Diversify into other stocks") + + # === Check 4: Sector Concentration === + sector = self.sector_map.get(ticker, 'Other') + sector_exposure = metrics.sector_concentration.get(sector, 0) + + if side == 'BUY': + new_sector_exposure = sector_exposure + float(position_pct) + if new_sector_exposure > float(self.limits.max_sector_exposure): + violations.append(RiskViolationType.CONCENTRATION) + reasons.append( + f"Sector exposure {new_sector_exposure:.1%} exceeds limit " + f"{self.limits.max_sector_exposure:.1%}" + ) + recommendations.append("Diversify into other sectors") + + # === Check 5: Stop Loss Required === + if self.limits.required_stop_loss and side == 'BUY': + if not stop_loss: + reasons.append("Stop loss required for buy orders") + recommendations.append( + f"Add stop loss at {price * 0.97:.2f} (-3%)" + ) + + # === Check 6: Pattern Day Trader Rule === + day_trades_count = await self._count_day_trades(ticker) + if day_trades_count >= self.limits.max_trades_per_symbol_per_day: + violations.append(RiskViolationType.PATTERN_DAY_TRADER) + reasons.append( + f"PDT rule: {day_trades_count} trades today in {ticker}" + ) + + # === Check 7: Volatility === + if metrics.portfolio_volatility > self.limits.max_volatility: + violations.append(RiskViolationType.VOLATILITY) + reasons.append( + f"Portfolio volatility {metrics.portfolio_volatility:.1%} " + f"exceeds limit {self.limits.max_volatility:.1%}" + ) + recommendations.append("Reduce position sizes or add hedges") + + # === Check 8: Margin === + if metrics.margin_available < order_value: + violations.append(RiskViolationType.MARGIN) + reasons.append( + f"Insufficient margin: need ${order_value:,.2f}, " + f"have ${metrics.margin_available:,.2f}" + ) + + # === Check 9: Correlation === + if side == 'BUY': + correlation_risk = await self._check_correlation(ticker) + if correlation_risk > self.limits.max_correlation: + violations.append(RiskViolationType.CORRELATION) + reasons.append( + f"High correlation {correlation_risk:.2f} with existing positions" + ) + recommendations.append("Diversify into uncorrelated assets") + + # === Check 10: Max Drawdown === + if metrics.current_drawdown > self.limits.max_drawdown: + violations.append(RiskViolationType.VOLATILITY) + reasons.append( + f"In drawdown {metrics.current_drawdown:.1%}, " + f"limit {self.limits.max_drawdown:.1%}" + ) + recommendations.append("Reduce risk until recovery") + + # Calculate risk score + risk_score = self._calculate_risk_score( + violations, metrics, position_pct + ) + + # Determine risk level + if risk_score >= 80: + risk_level = RiskLevel.CRITICAL + elif risk_score >= 60: + risk_level = RiskLevel.HIGH + elif risk_score >= 40: + risk_level = RiskLevel.MEDIUM + else: + risk_level = RiskLevel.LOW + + # Determine approval + approved = True + if self.enable_enforcement: + # Critical violations always reject + critical_violations = [ + RiskViolationType.DAILY_LOSS, + RiskViolationType.MARGIN, + RiskViolationType.PATTERN_DAY_TRADER + ] + if any(v in critical_violations for v in violations): + approved = False + + # High risk requires override + if risk_level == RiskLevel.CRITICAL: + approved = False + + return RiskCheckResult( + approved=approved, + risk_score=risk_score, + risk_level=risk_level, + violations=violations, + reasons=reasons, + position_impact={ + 'new_position_size': float(position_pct), + 'new_total_exposure': float(metrics.total_exposure + position_pct), + 'expected_volatility_change': self._estimate_volatility_impact( + ticker, position_pct + ) + }, + recommendations=recommendations + ) + + async def calculate_risk_metrics(self, force_refresh: bool = False) -> RiskMetrics: + """ + Calculate comprehensive risk metrics + + Args: + force_refresh: Force cache refresh + + Returns: + Risk metrics + """ + # Check cache + if not force_refresh and self._metrics_cache: + if (datetime.now() - self._cache_timestamp).seconds < self._cache_ttl: + return self._metrics_cache + + metrics = RiskMetrics() + + # Get positions + positions = await self._get_all_positions() + portfolio_value = await self._get_portfolio_value() + + if not positions or portfolio_value <= 0: + return metrics + + # Calculate exposure metrics + total_position_value = Decimal('0') + max_position_value = Decimal('0') + sector_values = defaultdict(Decimal) + + for position in positions: + position_value = position.market_value + total_position_value += position_value + + # Track max position + if position_value > max_position_value: + max_position_value = position_value + + # Track sector exposure + sector = self.sector_map.get(position.ticker, 'Other') + sector_values[sector] += position_value + + # Basic metrics + metrics.total_exposure = total_position_value / portfolio_value + metrics.max_position_size = max_position_value / portfolio_value + metrics.concentration_risk = max_position_value / total_position_value + + # Sector concentration + for sector, value in sector_values.items(): + metrics.sector_concentration[sector] = float(value / portfolio_value) + + # P&L metrics + metrics.daily_pnl = await self._calculate_daily_pnl() + metrics.realized_pnl = await self._calculate_realized_pnl() + metrics.unrealized_pnl = sum(p.unrealized_pnl for p in positions) + + # Risk metrics (simplified) + returns = await self._get_historical_returns() + if returns: + metrics.portfolio_volatility = np.std(returns) * np.sqrt(252) # Annualized + metrics.sharpe_ratio = self._calculate_sharpe_ratio(returns) + metrics.value_at_risk_95 = self._calculate_var(returns, 0.95) + metrics.value_at_risk_99 = self._calculate_var(returns, 0.99) + metrics.max_drawdown = self._calculate_max_drawdown(returns) + + # Correlation risk + metrics.correlation_risk = await self._calculate_correlation_risk() + + # Margin (simplified) + metrics.margin_used = total_position_value * Decimal('0.5') # 50% margin + metrics.margin_available = portfolio_value * Decimal('2') - metrics.margin_used + + # Cache results + self._metrics_cache = metrics + self._cache_timestamp = datetime.now() + + return metrics + + async def apply_risk_adjustment(self, + ticker: str, + base_size: float, + confidence: float) -> float: + """ + Apply risk-based position sizing + + Args: + ticker: Stock symbol + base_size: Base position size (% of portfolio) + confidence: Signal confidence (0-100) + + Returns: + Risk-adjusted position size + """ + metrics = await self.calculate_risk_metrics() + + # Start with base size + adjusted_size = base_size + + # Adjust for confidence + confidence_multiplier = 0.5 + (confidence / 100) * 0.5 # 0.5x to 1.0x + adjusted_size *= confidence_multiplier + + # Adjust for portfolio volatility + if metrics.portfolio_volatility > 0.20: # High volatility + adjusted_size *= 0.7 + elif metrics.portfolio_volatility < 0.10: # Low volatility + adjusted_size *= 1.2 + + # Adjust for drawdown + if metrics.current_drawdown > Decimal('0.10'): # In drawdown + adjusted_size *= 0.5 + + # Adjust for concentration + if ticker in [p.ticker for p in await self._get_all_positions()]: + adjusted_size *= 0.8 # Reduce if already have position + + # Kelly Criterion (simplified) + win_rate = 0.55 # Assumed win rate + avg_win_loss = 1.5 # Assumed win/loss ratio + kelly_fraction = (win_rate * avg_win_loss - (1 - win_rate)) / avg_win_loss + kelly_size = min(kelly_fraction, 0.25) # Cap at 25% + + # Blend base and Kelly + adjusted_size = (adjusted_size * 0.7) + (kelly_size * 0.3) + + # Ensure within limits + adjusted_size = max( + float(self.limits.max_position_size * Decimal('0.25')), # Min 25% of limit + min(adjusted_size, float(self.limits.max_position_size)) + ) + + return adjusted_size + + async def check_portfolio_health(self) -> Dict[str, Any]: + """ + Comprehensive portfolio health check + + Returns: + Health report dictionary + """ + metrics = await self.calculate_risk_metrics() + health_issues = [] + health_score = 100 + + # Check each metric + if metrics.total_exposure > Decimal('0.95'): + health_issues.append("Over-exposed (>95% invested)") + health_score -= 10 + + if metrics.concentration_risk > Decimal('0.40'): + health_issues.append("High concentration risk (>40% in one position)") + health_score -= 15 + + if metrics.portfolio_volatility > 0.30: + health_issues.append(f"High volatility ({metrics.portfolio_volatility:.1%})") + health_score -= 10 + + if metrics.sharpe_ratio < 0.5: + health_issues.append(f"Low Sharpe ratio ({metrics.sharpe_ratio:.2f})") + health_score -= 10 + + if metrics.current_drawdown > Decimal('0.10'): + health_issues.append(f"In drawdown ({metrics.current_drawdown:.1%})") + health_score -= 20 + + if metrics.daily_pnl < Decimal('-1000'): + health_issues.append(f"Large daily loss (${metrics.daily_pnl:,.2f})") + health_score -= 15 + + # Determine health status + if health_score >= 80: + status = "HEALTHY" + elif health_score >= 60: + status = "CAUTION" + elif health_score >= 40: + status = "WARNING" + else: + status = "CRITICAL" + + return { + 'status': status, + 'score': health_score, + 'issues': health_issues, + 'metrics': { + 'exposure': float(metrics.total_exposure), + 'volatility': metrics.portfolio_volatility, + 'sharpe': metrics.sharpe_ratio, + 'var_95': float(metrics.value_at_risk_95), + 'daily_pnl': float(metrics.daily_pnl), + 'drawdown': float(metrics.current_drawdown) + }, + 'recommendations': self._generate_recommendations(metrics, health_issues) + } + + # === Helper Methods === + + async def _get_all_positions(self) -> List[Position]: + """Get all active positions""" + return self.db.get_active_positions() + + async def _get_position(self, ticker: str) -> Optional[Position]: + """Get specific position""" + with self.db.get_session() as session: + return session.query(Position).filter_by(ticker=ticker).first() + + async def _get_portfolio_value(self) -> Decimal: + """Get total portfolio value""" + positions = await self._get_all_positions() + return sum(p.market_value for p in positions) + + async def _calculate_daily_pnl(self) -> Decimal: + """Calculate today's P&L""" + with self.db.get_session() as session: + today_trades = session.query(Trade).filter( + Trade.executed_at >= date.today() + ).all() + + daily_pnl = sum(t.pnl or 0 for t in today_trades) + + # Add unrealized P&L changes + positions = await self._get_all_positions() + for position in positions: + # Simplified - would compare to morning snapshot + daily_pnl += position.unrealized_pnl * Decimal('0.1') # Estimate + + return daily_pnl + + async def _calculate_realized_pnl(self) -> Decimal: + """Calculate realized P&L""" + with self.db.get_session() as session: + all_trades = session.query(Trade).all() + return sum(t.pnl or 0 for t in all_trades) + + async def _count_day_trades(self, ticker: str) -> int: + """Count day trades for PDT rule""" + with self.db.get_session() as session: + today_orders = session.query(Order).filter( + Order.ticker == ticker, + Order.created_at >= date.today() + ).all() + return len(today_orders) + + async def _check_correlation(self, ticker: str) -> float: + """Check correlation with existing positions""" + # Simplified correlation check + # In production, would calculate actual correlation matrix + positions = await self._get_all_positions() + + if not positions: + return 0.0 + + # High correlation for same sector + sector = self.sector_map.get(ticker, 'Other') + same_sector_positions = [ + p for p in positions + if self.sector_map.get(p.ticker, 'Other') == sector + ] + + if same_sector_positions: + return 0.8 # High correlation assumed for same sector + + return 0.3 # Low correlation for different sectors + + async def _get_historical_returns(self, days: int = 30) -> List[float]: + """Get historical portfolio returns""" + with self.db.get_session() as session: + snapshots = session.query(PerformanceMetric).filter( + PerformanceMetric.date >= datetime.now() - timedelta(days=days) + ).order_by(PerformanceMetric.date).all() + + if len(snapshots) < 2: + return [] + + returns = [] + for i in range(1, len(snapshots)): + if snapshots[i-1].total_pnl and snapshots[i].total_pnl: + daily_return = float( + (snapshots[i].total_pnl - snapshots[i-1].total_pnl) / + abs(snapshots[i-1].total_pnl) + ) + returns.append(daily_return) + + return returns + + def _calculate_risk_score(self, + violations: List[RiskViolationType], + metrics: RiskMetrics, + position_size: Decimal) -> float: + """Calculate risk score (0-100)""" + score = 0 + + # Violation weights + violation_weights = { + RiskViolationType.DAILY_LOSS: 30, + RiskViolationType.MARGIN: 25, + RiskViolationType.POSITION_SIZE: 20, + RiskViolationType.CONCENTRATION: 15, + RiskViolationType.PATTERN_DAY_TRADER: 20, + RiskViolationType.VOLATILITY: 10, + RiskViolationType.CORRELATION: 10 + } + + for violation in violations: + score += violation_weights.get(violation, 5) + + # Add metric-based scoring + if metrics.portfolio_volatility > 0.25: + score += 10 + if metrics.current_drawdown > Decimal('0.10'): + score += 15 + if float(position_size) > 0.15: + score += 10 + + return min(score, 100) + + def _calculate_sharpe_ratio(self, returns: List[float], + risk_free_rate: float = 0.03) -> float: + """Calculate Sharpe ratio""" + if not returns: + return 0.0 + + mean_return = np.mean(returns) * 252 # Annualized + std_return = np.std(returns) * np.sqrt(252) + + if std_return == 0: + return 0.0 + + return (mean_return - risk_free_rate) / std_return + + def _calculate_var(self, returns: List[float], + confidence: float) -> Decimal: + """Calculate Value at Risk""" + if not returns: + return Decimal('0') + + percentile = (1 - confidence) * 100 + var = np.percentile(returns, percentile) + return Decimal(str(abs(var))) + + def _calculate_max_drawdown(self, returns: List[float]) -> Decimal: + """Calculate maximum drawdown""" + if not returns: + return Decimal('0') + + cumulative = np.cumprod(1 + np.array(returns)) + running_max = np.maximum.accumulate(cumulative) + drawdown = (cumulative - running_max) / running_max + return Decimal(str(abs(min(drawdown)))) + + async def _calculate_correlation_risk(self) -> float: + """Calculate overall portfolio correlation risk""" + # Simplified - would use actual correlation matrix + positions = await self._get_all_positions() + + if len(positions) < 2: + return 0.0 + + # Check sector concentration + sectors = [self.sector_map.get(p.ticker, 'Other') for p in positions] + unique_sectors = len(set(sectors)) + + if unique_sectors == 1: + return 0.9 # All in same sector + elif unique_sectors == 2: + return 0.6 + else: + return 0.3 + + def _estimate_volatility_impact(self, ticker: str, + position_size: Decimal) -> float: + """Estimate impact on portfolio volatility""" + # Simplified estimate + # Individual stock volatility assumed at 30% + stock_vol = 0.30 + impact = float(position_size) * stock_vol * 0.5 # Rough estimate + return impact + + def _generate_recommendations(self, metrics: RiskMetrics, + issues: List[str]) -> List[str]: + """Generate risk management recommendations""" + recommendations = [] + + if metrics.concentration_risk > Decimal('0.30'): + recommendations.append("Diversify portfolio - reduce largest position") + + if metrics.portfolio_volatility > 0.25: + recommendations.append("Consider hedging with options or inverse ETFs") + + if metrics.current_drawdown > Decimal('0.10'): + recommendations.append("Reduce position sizes until recovery") + + if metrics.sharpe_ratio < 0.5: + recommendations.append("Review strategy - risk-adjusted returns are low") + + if "Over-exposed" in str(issues): + recommendations.append("Keep some cash reserve for opportunities") + + return recommendations + + +# Example usage +async def main(): + """Example of using the risk manager""" + from .database import DatabaseManager + + # Initialize + db = DatabaseManager("postgresql://trader:password@localhost/trading_db") + risk_manager = RiskManager(db) + + # Check an order + result = await risk_manager.check_order( + ticker="AAPL", + side="BUY", + quantity=1000, + price=150.00, + stop_loss=145.00 + ) + + print(f"Approved: {result.approved}") + print(f"Risk Score: {result.risk_score}") + print(f"Risk Level: {result.risk_level}") + print(f"Violations: {result.violations}") + print(f"Reasons: {result.reasons}") + print(f"Recommendations: {result.recommendations}") + + # Check portfolio health + health = await risk_manager.check_portfolio_health() + print(f"\nPortfolio Health: {health['status']}") + print(f"Health Score: {health['score']}") + print(f"Issues: {health['issues']}") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + import asyncio + asyncio.run(main()) \ No newline at end of file