diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..54467f6b --- /dev/null +++ b/.dockerignore @@ -0,0 +1,70 @@ +# Git +.git +.gitignore + +# Python +__pycache__ +*.py[cod] +*$py.class +*.so +.Python +venv/ +env/ +ENV/ +.venv +.ipynb_checkpoints +*.egg-info/ +dist/ +build/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ +.DS_Store + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ +.hypothesis/ + +# Logs +*.log +logs/ +*.pid + +# Environment +.env.local +.env.development +.env.test + +# Data (exclude large files) +*.csv +*.db +*.sqlite +data/ +backups/ + +# Documentation +*.md +docs/ +!README.md + +# Docker +Dockerfile +docker-compose*.yml +.dockerignore + +# Monitoring +monitoring/dashboards/ +prometheus_data/ +grafana_data/ + +# Temporary files +tmp/ +temp/ +cache/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..f3552261 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,92 @@ +# Multi-stage Dockerfile for Autonomous Trading System +# Optimized for security and minimal image size + +# === Stage 1: Builder === +FROM python:3.11-slim as builder + +# Set environment variables +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +# Install build dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + g++ \ + make \ + postgresql-client \ + libpq-dev \ + && rm -rf /var/lib/apt/lists/* + +# Create virtual environment +RUN python -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# Copy requirements files +COPY requirements.txt requirements_autonomous.txt /tmp/ + +# Install Python dependencies +RUN pip install --upgrade pip setuptools wheel && \ + pip install -r /tmp/requirements.txt && \ + pip install -r /tmp/requirements_autonomous.txt + +# === Stage 2: Runtime === +FROM python:3.11-slim + +# Security: Create non-root user +RUN groupadd -r trader && \ + useradd -r -g trader -d /home/trader -s /bin/bash -m trader + +# Set environment variables +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PYTHONPATH=/app \ + PATH="/opt/venv/bin:$PATH" \ + # Default environment (can be overridden) + ENVIRONMENT=production \ + LOG_LEVEL=INFO + +# Install runtime dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + postgresql-client \ + curl \ + ca-certificates \ + tzdata \ + && rm -rf /var/lib/apt/lists/* \ + && ln -sf /usr/share/zoneinfo/America/New_York /etc/localtime \ + && echo "America/New_York" > /etc/timezone + +# Copy virtual environment from builder +COPY --from=builder /opt/venv /opt/venv + +# Create app directory +WORKDIR /app + +# Copy application code +COPY --chown=trader:trader . /app/ + +# Create necessary directories with proper permissions +RUN mkdir -p /app/logs /app/data /app/cache && \ + chown -R trader:trader /app/logs /app/data /app/cache && \ + chmod 755 /app/logs /app/data /app/cache + +# Security: Set proper file permissions +RUN find /app -type f -name "*.py" -exec chmod 644 {} \; && \ + find /app -type d -exec chmod 755 {} \; && \ + chmod +x /app/autonomous_trader.py + +# Switch to non-root user +USER trader + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \ + CMD python -c "import sys; sys.path.insert(0, '/app'); from autonomous.core.health_check import health_check; exit(0 if health_check() else 1)" || exit 1 + +# Expose ports +# 8000 - API/Dashboard +# 9090 - Prometheus metrics +EXPOSE 8000 9090 + +# Default command (can be overridden) +CMD ["python", "autonomous_trader.py"] \ No newline at end of file diff --git a/autonomous/core/cache.py b/autonomous/core/cache.py new file mode 100644 index 00000000..033b3afb --- /dev/null +++ b/autonomous/core/cache.py @@ -0,0 +1,811 @@ +""" +Redis Cache Layer +================= + +High-performance caching with Redis for market data, AI decisions, +and expensive computations. +""" + +import json +import logging +import pickle +from typing import Any, Optional, Dict, List, Union +from datetime import datetime, timedelta +from decimal import Decimal +import hashlib +import asyncio + +import redis.asyncio as redis +from redis.asyncio.lock import Lock +from redis.exceptions import RedisError, ConnectionError + +logger = logging.getLogger(__name__) + + +class CacheKey: + """Cache key builder with namespacing""" + + # Namespaces + MARKET_DATA = "market" + AI_DECISION = "ai" + SIGNAL = "signal" + CONGRESSIONAL = "congress" + NEWS = "news" + TECHNICAL = "tech" + POSITION = "position" + METRICS = "metrics" + + @staticmethod + def build(*parts: Union[str, int, float]) -> str: + """Build a cache key from parts""" + return ":".join(str(p) for p in parts) + + @staticmethod + def market_data(ticker: str) -> str: + """Key for market data""" + return CacheKey.build(CacheKey.MARKET_DATA, ticker) + + @staticmethod + def ai_decision(ticker: str, date: str) -> str: + """Key for AI trading decision""" + return CacheKey.build(CacheKey.AI_DECISION, ticker, date) + + @staticmethod + def signal(ticker: str, signal_type: str) -> str: + """Key for trading signal""" + return CacheKey.build(CacheKey.SIGNAL, ticker, signal_type) + + @staticmethod + def congressional_trades(days_back: int) -> str: + """Key for congressional trades""" + return CacheKey.build(CacheKey.CONGRESSIONAL, f"last_{days_back}_days") + + @staticmethod + def news_sentiment(ticker: str) -> str: + """Key for news sentiment""" + return CacheKey.build(CacheKey.NEWS, ticker, "sentiment") + + @staticmethod + def technical_indicators(ticker: str) -> str: + """Key for technical indicators""" + return CacheKey.build(CacheKey.TECHNICAL, ticker) + + @staticmethod + def risk_metrics() -> str: + """Key for risk metrics""" + return CacheKey.build(CacheKey.METRICS, "risk") + + +class RedisCache: + """ + Redis cache manager with connection pooling and error handling + """ + + def __init__(self, + host: str = "localhost", + port: int = 6379, + db: int = 0, + password: Optional[str] = None, + max_connections: int = 50, + socket_timeout: int = 5, + retry_on_timeout: bool = True): + """ + Initialize Redis cache + + Args: + host: Redis host + port: Redis port + db: Redis database number + password: Redis password + max_connections: Maximum connection pool size + socket_timeout: Socket timeout in seconds + retry_on_timeout: Retry on timeout errors + """ + self.host = host + self.port = port + self.db = db + + # Connection pool for better performance + self.pool = redis.ConnectionPool( + host=host, + port=port, + db=db, + password=password, + max_connections=max_connections, + socket_timeout=socket_timeout, + socket_connect_timeout=socket_timeout, + retry_on_timeout=retry_on_timeout, + decode_responses=False # Handle encoding ourselves + ) + + self.redis: Optional[redis.Redis] = None + self.connected = False + + # Default TTLs (seconds) + self.default_ttls = { + CacheKey.MARKET_DATA: 60, # 1 minute for market data + CacheKey.AI_DECISION: 3600, # 1 hour for AI decisions + CacheKey.SIGNAL: 300, # 5 minutes for signals + CacheKey.CONGRESSIONAL: 3600, # 1 hour for congressional + CacheKey.NEWS: 1800, # 30 minutes for news + CacheKey.TECHNICAL: 300, # 5 minutes for technical + CacheKey.METRICS: 60, # 1 minute for metrics + } + + # Cache statistics + self.stats = { + 'hits': 0, + 'misses': 0, + 'errors': 0, + 'evictions': 0 + } + + async def connect(self) -> bool: + """ + Connect to Redis + + Returns: + True if connected successfully + """ + try: + self.redis = redis.Redis(connection_pool=self.pool) + + # Test connection + await self.redis.ping() + + self.connected = True + logger.info(f"Connected to Redis at {self.host}:{self.port}") + + # Set memory policy + await self._configure_memory_policy() + + return True + + except (RedisError, ConnectionError) as e: + logger.error(f"Failed to connect to Redis: {e}") + self.connected = False + return False + + async def disconnect(self): + """Disconnect from Redis""" + if self.redis: + await self.redis.close() + await self.pool.disconnect() + self.connected = False + logger.info("Disconnected from Redis") + + async def _configure_memory_policy(self): + """Configure Redis memory policy""" + try: + # Set max memory policy to LRU (Least Recently Used) + await self.redis.config_set('maxmemory-policy', 'allkeys-lru') + + # Set max memory (optional, depends on your setup) + # await self.redis.config_set('maxmemory', '1gb') + + except Exception as e: + logger.warning(f"Could not configure memory policy: {e}") + + # === Core Cache Operations === + + async def get(self, key: str) -> Optional[Any]: + """ + Get value from cache + + Args: + key: Cache key + + Returns: + Cached value or None if not found + """ + if not self.connected: + return None + + try: + value = await self.redis.get(key) + + if value is None: + self.stats['misses'] += 1 + return None + + self.stats['hits'] += 1 + + # Deserialize based on data type + return self._deserialize(value) + + except Exception as e: + logger.error(f"Cache get error for {key}: {e}") + self.stats['errors'] += 1 + return None + + async def set(self, + key: str, + value: Any, + ttl: Optional[int] = None, + nx: bool = False) -> bool: + """ + Set value in cache + + Args: + key: Cache key + value: Value to cache + ttl: Time to live in seconds + nx: Only set if key doesn't exist + + Returns: + True if set successfully + """ + if not self.connected: + return False + + try: + # Determine TTL + if ttl is None: + namespace = key.split(':')[0] + ttl = self.default_ttls.get(namespace, 300) + + # Serialize value + serialized = self._serialize(value) + + # Set with TTL + if nx: + result = await self.redis.set(key, serialized, ex=ttl, nx=True) + else: + result = await self.redis.setex(key, ttl, serialized) + + return bool(result) + + except Exception as e: + logger.error(f"Cache set error for {key}: {e}") + self.stats['errors'] += 1 + return False + + async def delete(self, *keys: str) -> int: + """ + Delete keys from cache + + Args: + *keys: Keys to delete + + Returns: + Number of keys deleted + """ + if not self.connected or not keys: + return 0 + + try: + return await self.redis.delete(*keys) + except Exception as e: + logger.error(f"Cache delete error: {e}") + return 0 + + async def exists(self, key: str) -> bool: + """ + Check if key exists + + Args: + key: Cache key + + Returns: + True if exists + """ + if not self.connected: + return False + + try: + return bool(await self.redis.exists(key)) + except Exception: + return False + + async def expire(self, key: str, ttl: int) -> bool: + """ + Set expiration on key + + Args: + key: Cache key + ttl: Time to live in seconds + + Returns: + True if expiration set + """ + if not self.connected: + return False + + try: + return bool(await self.redis.expire(key, ttl)) + except Exception: + return False + + # === Batch Operations === + + async def mget(self, keys: List[str]) -> Dict[str, Any]: + """ + Get multiple values + + Args: + keys: List of keys + + Returns: + Dictionary of key-value pairs + """ + if not self.connected or not keys: + return {} + + try: + values = await self.redis.mget(keys) + result = {} + + for key, value in zip(keys, values): + if value is not None: + result[key] = self._deserialize(value) + self.stats['hits'] += 1 + else: + self.stats['misses'] += 1 + + return result + + except Exception as e: + logger.error(f"Cache mget error: {e}") + return {} + + async def mset(self, + data: Dict[str, Any], + ttl: Optional[int] = None) -> bool: + """ + Set multiple values + + Args: + data: Dictionary of key-value pairs + ttl: Time to live in seconds + + Returns: + True if all set successfully + """ + if not self.connected or not data: + return False + + try: + # Use pipeline for atomic operations + pipe = self.redis.pipeline() + + for key, value in data.items(): + serialized = self._serialize(value) + + if ttl is None: + namespace = key.split(':')[0] + key_ttl = self.default_ttls.get(namespace, 300) + else: + key_ttl = ttl + + pipe.setex(key, key_ttl, serialized) + + results = await pipe.execute() + return all(results) + + except Exception as e: + logger.error(f"Cache mset error: {e}") + return False + + # === Pattern Operations === + + async def keys_pattern(self, pattern: str) -> List[str]: + """ + Get keys matching pattern + + Args: + pattern: Redis pattern (e.g., "market:*") + + Returns: + List of matching keys + """ + if not self.connected: + return [] + + try: + keys = await self.redis.keys(pattern) + return [k.decode('utf-8') if isinstance(k, bytes) else k for k in keys] + except Exception as e: + logger.error(f"Cache keys error: {e}") + return [] + + async def delete_pattern(self, pattern: str) -> int: + """ + Delete keys matching pattern + + Args: + pattern: Redis pattern + + Returns: + Number of keys deleted + """ + keys = await self.keys_pattern(pattern) + if keys: + return await self.delete(*keys) + return 0 + + # === Distributed Locking === + + async def acquire_lock(self, + name: str, + timeout: int = 10, + blocking: bool = True) -> Optional[Lock]: + """ + Acquire distributed lock + + Args: + name: Lock name + timeout: Lock timeout in seconds + blocking: Whether to block waiting for lock + + Returns: + Lock object or None + """ + if not self.connected: + return None + + try: + lock = Lock( + self.redis, + f"lock:{name}", + timeout=timeout, + blocking=blocking, + blocking_timeout=5 if blocking else 0 + ) + + if await lock.acquire(): + return lock + return None + + except Exception as e: + logger.error(f"Lock acquire error: {e}") + return None + + # === Specialized Cache Methods === + + async def cache_market_data(self, + ticker: str, + data: Dict[str, Any], + ttl: int = 60) -> bool: + """ + Cache market data with automatic expiration + + Args: + ticker: Stock ticker + data: Market data dictionary + ttl: Time to live in seconds + + Returns: + True if cached successfully + """ + key = CacheKey.market_data(ticker) + + # Add timestamp + data['cached_at'] = datetime.now().isoformat() + + return await self.set(key, data, ttl) + + async def get_market_data(self, ticker: str) -> Optional[Dict[str, Any]]: + """ + Get cached market data + + Args: + ticker: Stock ticker + + Returns: + Market data or None + """ + key = CacheKey.market_data(ticker) + data = await self.get(key) + + if data and isinstance(data, dict): + # Check if data is stale (older than 2 minutes) + cached_at = data.get('cached_at') + if cached_at: + cache_time = datetime.fromisoformat(cached_at) + if (datetime.now() - cache_time).seconds > 120: + # Data is stale, delete it + await self.delete(key) + return None + + return data + + async def cache_ai_decision(self, + ticker: str, + date: str, + decision: str, + ttl: int = 3600) -> bool: + """ + Cache AI trading decision + + Args: + ticker: Stock ticker + date: Analysis date + decision: AI decision + ttl: Time to live (default 1 hour) + + Returns: + True if cached successfully + """ + key = CacheKey.ai_decision(ticker, date) + + data = { + 'decision': decision, + 'ticker': ticker, + 'date': date, + 'cached_at': datetime.now().isoformat() + } + + return await self.set(key, data, ttl) + + async def get_ai_decision(self, + ticker: str, + date: str) -> Optional[str]: + """ + Get cached AI decision + + Args: + ticker: Stock ticker + date: Analysis date + + Returns: + AI decision or None + """ + key = CacheKey.ai_decision(ticker, date) + data = await self.get(key) + + if data and isinstance(data, dict): + return data.get('decision') + return None + + async def invalidate_ticker(self, ticker: str): + """ + Invalidate all cache entries for a ticker + + Args: + ticker: Stock ticker + """ + patterns = [ + f"{CacheKey.MARKET_DATA}:{ticker}*", + f"{CacheKey.AI_DECISION}:{ticker}*", + f"{CacheKey.SIGNAL}:{ticker}*", + f"{CacheKey.TECHNICAL}:{ticker}*", + f"{CacheKey.NEWS}:{ticker}*" + ] + + for pattern in patterns: + await self.delete_pattern(pattern) + + # === Helper Methods === + + def _serialize(self, value: Any) -> bytes: + """ + Serialize value for storage + + Args: + value: Value to serialize + + Returns: + Serialized bytes + """ + # Handle different types + if isinstance(value, (str, int, float)): + return json.dumps(value).encode('utf-8') + elif isinstance(value, Decimal): + return json.dumps(str(value)).encode('utf-8') + elif isinstance(value, (dict, list)): + return json.dumps(value, default=str).encode('utf-8') + else: + # Use pickle for complex objects + return pickle.dumps(value) + + def _deserialize(self, value: bytes) -> Any: + """ + Deserialize value from storage + + Args: + value: Serialized bytes + + Returns: + Deserialized value + """ + if not value: + return None + + # Try JSON first + try: + decoded = value.decode('utf-8') + return json.loads(decoded) + except (json.JSONDecodeError, UnicodeDecodeError): + pass + + # Try pickle + try: + return pickle.loads(value) + except Exception as e: + logger.error(f"Deserialization error: {e}") + return None + + async def get_stats(self) -> Dict[str, Any]: + """ + Get cache statistics + + Returns: + Statistics dictionary + """ + info = {} + + if self.connected: + try: + # Get Redis server info + redis_info = await self.redis.info('stats') + info['redis'] = { + 'total_connections': redis_info.get('total_connections_received', 0), + 'commands_processed': redis_info.get('total_commands_processed', 0), + 'keyspace_hits': redis_info.get('keyspace_hits', 0), + 'keyspace_misses': redis_info.get('keyspace_misses', 0), + } + + # Get memory info + memory_info = await self.redis.info('memory') + info['memory'] = { + 'used_memory': memory_info.get('used_memory_human', '0'), + 'peak_memory': memory_info.get('used_memory_peak_human', '0'), + } + + # Get database size + info['db_size'] = await self.redis.dbsize() + + except Exception as e: + logger.error(f"Error getting Redis stats: {e}") + + # Add local stats + total_requests = self.stats['hits'] + self.stats['misses'] + hit_rate = (self.stats['hits'] / total_requests * 100) if total_requests > 0 else 0 + + info['cache_stats'] = { + **self.stats, + 'hit_rate': f"{hit_rate:.1f}%", + 'total_requests': total_requests + } + + return info + + async def clear_all(self) -> bool: + """ + Clear all cache (USE WITH CAUTION) + + Returns: + True if cleared successfully + """ + if not self.connected: + return False + + try: + await self.redis.flushdb() + logger.warning("Cache cleared - all data deleted") + return True + except Exception as e: + logger.error(f"Error clearing cache: {e}") + return False + + +# === Cache Decorator === + +def cached(ttl: int = 300, key_prefix: str = ""): + """ + Decorator for caching function results + + Args: + ttl: Time to live in seconds + key_prefix: Prefix for cache key + """ + def decorator(func): + async def wrapper(*args, **kwargs): + # Build cache key from function name and arguments + cache_key_parts = [key_prefix or func.__name__] + + # Add args to key + for arg in args: + if hasattr(arg, '__dict__'): + # Skip object instances + continue + cache_key_parts.append(str(arg)) + + # Add kwargs to key + for k, v in sorted(kwargs.items()): + cache_key_parts.append(f"{k}={v}") + + cache_key = ":".join(cache_key_parts) + + # Try to get from cache + if hasattr(wrapper, '_cache'): + cached_value = await wrapper._cache.get(cache_key) + if cached_value is not None: + return cached_value + + # Call function + result = await func(*args, **kwargs) + + # Cache result + if hasattr(wrapper, '_cache') and result is not None: + await wrapper._cache.set(cache_key, result, ttl) + + return result + + return wrapper + return decorator + + +# === Cache Manager Singleton === + +class CacheManager: + """Singleton cache manager""" + + _instance: Optional[RedisCache] = None + + @classmethod + def get_instance(cls) -> RedisCache: + """Get or create cache instance""" + if cls._instance is None: + cls._instance = RedisCache() + return cls._instance + + @classmethod + async def initialize(cls, + host: str = "localhost", + port: int = 6379, + **kwargs) -> RedisCache: + """ + Initialize cache manager + + Args: + host: Redis host + port: Redis port + **kwargs: Additional Redis options + + Returns: + Redis cache instance + """ + cls._instance = RedisCache(host, port, **kwargs) + await cls._instance.connect() + return cls._instance + + +# Example usage +async def main(): + """Example of using the cache layer""" + + # Initialize cache + cache = await CacheManager.initialize() + + # Cache market data + market_data = { + 'last': 150.25, + 'bid': 150.20, + 'ask': 150.30, + 'volume': 1000000 + } + await cache.cache_market_data("AAPL", market_data) + + # Get cached data + cached = await cache.get_market_data("AAPL") + print(f"Cached market data: {cached}") + + # Cache AI decision + await cache.cache_ai_decision("NVDA", "2024-01-01", "BUY with high confidence") + + # Get cached decision + decision = await cache.get_ai_decision("NVDA", "2024-01-01") + print(f"Cached AI decision: {decision}") + + # Get cache statistics + stats = await cache.get_stats() + print(f"Cache stats: {stats}") + + # Clean up + await cache.disconnect() + + +if __name__ == "__main__": + import asyncio + asyncio.run(main()) \ No newline at end of file diff --git a/autonomous/security/validators.py b/autonomous/security/validators.py new file mode 100644 index 00000000..60abba1a --- /dev/null +++ b/autonomous/security/validators.py @@ -0,0 +1,687 @@ +""" +Security Validation Layer +========================= + +Comprehensive input validation, sanitization, and security checks +to prevent injection attacks and ensure data integrity. +""" + +import re +import logging +import hashlib +import hmac +import secrets +from typing import Any, Dict, List, Optional, Union, Type +from datetime import datetime, timedelta +from decimal import Decimal, InvalidOperation +from enum import Enum +import json + +from pydantic import ( + BaseModel, Field, validator, root_validator, + ValidationError, constr, condecimal, conint +) +from typing_extensions import Annotated + +logger = logging.getLogger(__name__) + + +# === Custom Types with Validation === + +# Ticker symbol: 1-10 uppercase letters/numbers, no special chars +TickerSymbol = Annotated[ + str, + constr( + regex=r'^[A-Z0-9]{1,10}$', + strip_whitespace=True, + to_upper=True + ) +] + +# Price: positive decimal with max 2 decimal places +Price = Annotated[ + Decimal, + condecimal( + gt=0, + max_digits=10, + decimal_places=2 + ) +] + +# Quantity: positive integer within reasonable bounds +Quantity = Annotated[ + int, + conint( + gt=0, + le=1000000 # Max 1 million shares + ) +] + +# Percentage: 0-100 +Percentage = Annotated[ + float, + Field(ge=0.0, le=100.0) +] + + +class SecurityLevel(str, Enum): + """Security validation levels""" + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + CRITICAL = "critical" + + +# === Input Validators === + +class TickerValidator(BaseModel): + """Validator for ticker symbols""" + ticker: TickerSymbol + + @validator('ticker') + def validate_ticker(cls, v): + """Additional ticker validation""" + # Check against blacklist of invalid tickers + blacklist = ['TEST', 'DUMMY', 'NULL', 'UNDEFINED'] + if v in blacklist: + raise ValueError(f"Invalid ticker: {v}") + + # Check for SQL injection patterns + if cls._contains_sql_injection(v): + raise ValueError("Potential SQL injection detected") + + return v + + @staticmethod + def _contains_sql_injection(value: str) -> bool: + """Check for SQL injection patterns""" + sql_patterns = [ + r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|UNION|CREATE|ALTER)\b)", + r"(-{2}|\/\*|\*\/)", # SQL comments + r"(;|\||&&)", # Command separators + r"(\bOR\b.*=.*)", # OR conditions + r"('|\")", # Quotes + ] + + for pattern in sql_patterns: + if re.search(pattern, value, re.IGNORECASE): + return True + return False + + +class OrderValidator(BaseModel): + """Comprehensive order validation""" + ticker: TickerSymbol + side: str = Field(regex=r'^(BUY|SELL)$') + quantity: Quantity + order_type: str = Field(regex=r'^(MARKET|LIMIT|STOP|STOP_LIMIT)$') + limit_price: Optional[Price] = None + stop_price: Optional[Price] = None + time_in_force: str = Field( + default="DAY", + regex=r'^(DAY|GTC|IOC|FOK)$' + ) + account_id: Optional[constr(max_length=50)] = None + notes: Optional[constr(max_length=500)] = None + + @root_validator + def validate_prices(cls, values): + """Validate price requirements based on order type""" + order_type = values.get('order_type') + limit_price = values.get('limit_price') + stop_price = values.get('stop_price') + + if order_type == 'LIMIT' and not limit_price: + raise ValueError("Limit price required for LIMIT orders") + + if order_type in ['STOP', 'STOP_LIMIT'] and not stop_price: + raise ValueError("Stop price required for STOP orders") + + if order_type == 'STOP_LIMIT' and not limit_price: + raise ValueError("Limit price required for STOP_LIMIT orders") + + # Check for unreasonable prices + if limit_price and limit_price > 100000: + raise ValueError(f"Limit price ${limit_price} exceeds maximum") + + if stop_price and limit_price: + side = values.get('side') + if side == 'BUY' and stop_price < limit_price: + raise ValueError("Stop price must be above limit for buy stop orders") + elif side == 'SELL' and stop_price > limit_price: + raise ValueError("Stop price must be below limit for sell stop orders") + + return values + + @validator('notes') + def sanitize_notes(cls, v): + """Sanitize notes field""" + if v: + # Remove potential XSS/injection content + v = cls._sanitize_string(v) + return v + + @staticmethod + def _sanitize_string(value: str) -> str: + """Remove dangerous characters from string""" + # Remove HTML/Script tags + value = re.sub(r'<[^>]*>', '', value) + + # Remove JavaScript + value = re.sub(r'javascript:', '', value, flags=re.IGNORECASE) + + # Remove SQL keywords + sql_keywords = ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'EXEC', 'UNION'] + for keyword in sql_keywords: + value = re.sub(rf'\b{keyword}\b', '', value, flags=re.IGNORECASE) + + return value.strip() + + +class ConfigValidator(BaseModel): + """Validator for configuration settings""" + max_position_size: Percentage + max_daily_loss: Percentage + max_orders_per_day: conint(gt=0, le=1000) + confidence_threshold: Percentage + stop_loss_percent: Percentage + api_keys: Dict[str, str] = Field(default_factory=dict) + + @validator('api_keys') + def validate_api_keys(cls, v): + """Validate API key format""" + for key_name, key_value in v.items(): + # Check for exposed secrets + if cls._is_placeholder(key_value): + raise ValueError(f"Invalid API key for {key_name}") + + # Check key format (basic validation) + if len(key_value) < 10: + raise ValueError(f"API key {key_name} is too short") + + # Check for common test keys + if key_value in ['test', 'demo', '12345', 'password']: + raise ValueError(f"Invalid API key for {key_name}") + + return v + + @staticmethod + def _is_placeholder(value: str) -> bool: + """Check if value is a placeholder""" + placeholders = [ + 'your_key_here', + 'placeholder', + 'xxxx', + 'todo', + 'changeme' + ] + return any(p in value.lower() for p in placeholders) + + +class WebhookValidator(BaseModel): + """Validator for webhook URLs""" + url: constr( + regex=r'^https:\/\/(discord\.com|hooks\.slack\.com|api\.telegram\.org)\/.*', + max_length=500 + ) + enabled: bool = True + + @validator('url') + def validate_webhook_url(cls, v): + """Validate webhook URL security""" + # Check for localhost/internal IPs (SSRF prevention) + internal_patterns = [ + r'localhost', + r'127\.0\.0\.1', + r'0\.0\.0\.0', + r'192\.168\.', + r'10\.', + r'172\.(1[6-9]|2[0-9]|3[0-1])\.' + ] + + for pattern in internal_patterns: + if re.search(pattern, v, re.IGNORECASE): + raise ValueError("Webhook URL cannot point to internal network") + + return v + + +# === Request Signing & Verification === + +class RequestSigner: + """Sign and verify requests for authentication""" + + def __init__(self, secret_key: str): + """ + Initialize request signer + + Args: + secret_key: Secret key for signing + """ + self.secret_key = secret_key.encode('utf-8') + + def sign_request(self, data: Dict[str, Any]) -> str: + """ + Sign a request payload + + Args: + data: Request data + + Returns: + Signature string + """ + # Sort keys for consistent signing + sorted_data = json.dumps(data, sort_keys=True) + + # Create HMAC signature + signature = hmac.new( + self.secret_key, + sorted_data.encode('utf-8'), + hashlib.sha256 + ).hexdigest() + + return signature + + def verify_request(self, + data: Dict[str, Any], + signature: str) -> bool: + """ + Verify a request signature + + Args: + data: Request data + signature: Provided signature + + Returns: + True if signature is valid + """ + expected_signature = self.sign_request(data) + + # Use constant-time comparison to prevent timing attacks + return hmac.compare_digest(expected_signature, signature) + + +# === Rate Limiting === + +class RateLimiter: + """Rate limiting for API endpoints""" + + def __init__(self): + self.requests: Dict[str, List[datetime]] = {} + + def check_rate_limit(self, + identifier: str, + max_requests: int = 100, + window_seconds: int = 60) -> bool: + """ + Check if request is within rate limit + + Args: + identifier: Client identifier (IP, API key, etc.) + max_requests: Maximum requests allowed + window_seconds: Time window in seconds + + Returns: + True if within limit + """ + now = datetime.now() + window_start = now - timedelta(seconds=window_seconds) + + # Get request history + if identifier not in self.requests: + self.requests[identifier] = [] + + # Remove old requests + self.requests[identifier] = [ + req_time for req_time in self.requests[identifier] + if req_time > window_start + ] + + # Check limit + if len(self.requests[identifier]) >= max_requests: + return False + + # Add current request + self.requests[identifier].append(now) + return True + + +# === Secure Configuration === + +class SecureConfig: + """Secure configuration management""" + + def __init__(self, config_data: Dict[str, Any]): + """ + Initialize secure config + + Args: + config_data: Configuration dictionary + """ + self.config = self._sanitize_config(config_data) + self._validate_security_settings() + + def _sanitize_config(self, config: Dict[str, Any]) -> Dict[str, Any]: + """Sanitize configuration data""" + sanitized = {} + + for key, value in config.items(): + # Skip sensitive keys from logs + if any(sensitive in key.lower() for sensitive in + ['password', 'secret', 'key', 'token']): + # Don't include actual value in sanitized version + sanitized[key] = "***REDACTED***" if value else None + else: + if isinstance(value, str): + # Sanitize strings + sanitized[key] = self._sanitize_value(value) + elif isinstance(value, dict): + # Recursively sanitize nested dicts + sanitized[key] = self._sanitize_config(value) + else: + sanitized[key] = value + + return sanitized + + def _sanitize_value(self, value: str) -> str: + """Sanitize a configuration value""" + # Remove potential command injection + dangerous_chars = [';', '|', '&', '$', '`', '\\', '\n', '\r'] + for char in dangerous_chars: + value = value.replace(char, '') + + # Remove path traversal + value = value.replace('../', '').replace('..\\', '') + + return value + + def _validate_security_settings(self): + """Validate security-critical settings""" + # Check for secure defaults + if self.config.get('ssl_enabled', True) is False: + logger.warning("SSL is disabled - this is insecure!") + + if self.config.get('debug_mode', False) is True: + logger.warning("Debug mode is enabled - disable in production!") + + if self.config.get('allow_all_origins', False) is True: + logger.warning("CORS allow_all_origins is enabled - security risk!") + + +# === API Security === + +class APISecurityValidator: + """Validator for API security""" + + @staticmethod + def validate_api_key(api_key: str) -> bool: + """ + Validate API key format and strength + + Args: + api_key: API key to validate + + Returns: + True if valid + """ + # Check length + if len(api_key) < 32: + return False + + # Check for common patterns + if api_key.startswith('sk_test_') or api_key.startswith('pk_test_'): + logger.warning("Test API key detected") + + # Check entropy (simplified) + unique_chars = len(set(api_key)) + if unique_chars < 10: + return False # Low entropy + + return True + + @staticmethod + def generate_api_key() -> str: + """ + Generate a secure API key + + Returns: + Secure API key + """ + # Generate 32 bytes of random data + random_bytes = secrets.token_bytes(32) + + # Convert to hex string + api_key = f"sk_live_{random_bytes.hex()}" + + return api_key + + @staticmethod + def hash_api_key(api_key: str) -> str: + """ + Hash an API key for storage + + Args: + api_key: API key to hash + + Returns: + Hashed API key + """ + # Use SHA-256 for hashing + return hashlib.sha256(api_key.encode('utf-8')).hexdigest() + + +# === XSS Prevention === + +class XSSPrevention: + """Cross-site scripting prevention""" + + @staticmethod + def sanitize_html(text: str) -> str: + """ + Sanitize HTML content + + Args: + text: Text to sanitize + + Returns: + Sanitized text + """ + # HTML entity encoding + html_escapes = { + '<': '<', + '>': '>', + '"': '"', + "'": ''', + '&': '&', + '/': '/', + '`': '`', + '=': '=' + } + + for char, escape in html_escapes.items(): + text = text.replace(char, escape) + + return text + + @staticmethod + def sanitize_json(data: Dict[str, Any]) -> Dict[str, Any]: + """ + Sanitize JSON data + + Args: + data: JSON data + + Returns: + Sanitized data + """ + sanitized = {} + + for key, value in data.items(): + if isinstance(value, str): + sanitized[key] = XSSPrevention.sanitize_html(value) + elif isinstance(value, dict): + sanitized[key] = XSSPrevention.sanitize_json(value) + elif isinstance(value, list): + sanitized[key] = [ + XSSPrevention.sanitize_html(item) if isinstance(item, str) else item + for item in value + ] + else: + sanitized[key] = value + + return sanitized + + +# === Composite Security Validator === + +class SecurityValidator: + """Main security validator combining all checks""" + + def __init__(self, security_level: SecurityLevel = SecurityLevel.HIGH): + """ + Initialize security validator + + Args: + security_level: Security validation level + """ + self.security_level = security_level + self.rate_limiter = RateLimiter() + self.request_signer = None + + def validate_order(self, order_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Validate and sanitize order data + + Args: + order_data: Raw order data + + Returns: + Validated order data + + Raises: + ValidationError: If validation fails + """ + try: + # Validate with Pydantic + validated = OrderValidator(**order_data) + + # Additional security checks for high security + if self.security_level in [SecurityLevel.HIGH, SecurityLevel.CRITICAL]: + # Check for suspicious patterns + if self._is_suspicious_order(validated.dict()): + raise ValueError("Order flagged as suspicious") + + return validated.dict() + + except ValidationError as e: + logger.error(f"Order validation failed: {e}") + raise + + def validate_config(self, config_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Validate configuration data + + Args: + config_data: Raw configuration + + Returns: + Validated configuration + + Raises: + ValidationError: If validation fails + """ + try: + validated = ConfigValidator(**config_data) + return validated.dict() + except ValidationError as e: + logger.error(f"Config validation failed: {e}") + raise + + def _is_suspicious_order(self, order: Dict[str, Any]) -> bool: + """ + Check for suspicious order patterns + + Args: + order: Order data + + Returns: + True if suspicious + """ + # Check for unusual quantity + if order['quantity'] > 10000: + logger.warning(f"Large order quantity: {order['quantity']}") + return True + + # Check for price manipulation attempts + if order.get('limit_price'): + # Check for penny stock manipulation + if order['limit_price'] < 1 and order['quantity'] > 1000: + logger.warning("Potential penny stock manipulation") + return True + + return False + + def sanitize_user_input(self, input_data: Any) -> Any: + """ + Sanitize any user input + + Args: + input_data: User input + + Returns: + Sanitized input + """ + if isinstance(input_data, str): + # Remove dangerous characters + input_data = re.sub(r'[<>&\'"`]', '', input_data) + + # Truncate to reasonable length + input_data = input_data[:1000] + + elif isinstance(input_data, dict): + input_data = XSSPrevention.sanitize_json(input_data) + + return input_data + + +# === Example Usage === + +def main(): + """Example of using security validators""" + + # Initialize validator + validator = SecurityValidator(SecurityLevel.HIGH) + + # Validate order + order_data = { + "ticker": "AAPL", + "side": "BUY", + "quantity": 100, + "order_type": "LIMIT", + "limit_price": "150.50", + "notes": "Test order " + } + + try: + validated_order = validator.validate_order(order_data) + print(f"Validated order: {validated_order}") + except ValidationError as e: + print(f"Validation failed: {e}") + + # Generate secure API key + api_key = APISecurityValidator.generate_api_key() + print(f"Generated API key: {api_key}") + + # Hash for storage + hashed = APISecurityValidator.hash_api_key(api_key) + print(f"Hashed key: {hashed}") + + # Rate limiting + rate_limiter = RateLimiter() + for i in range(5): + allowed = rate_limiter.check_rate_limit("user123", max_requests=3, window_seconds=10) + print(f"Request {i+1}: {'Allowed' if allowed else 'Blocked'}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml new file mode 100644 index 00000000..5737e1fe --- /dev/null +++ b/docker-compose.dev.yml @@ -0,0 +1,124 @@ +# Development Docker Compose Configuration +# Use with: docker-compose -f docker-compose.yml -f docker-compose.dev.yml up + +version: '3.8' + +services: + # === Override main trader service for development === + trader: + build: + context: . + dockerfile: Dockerfile + # Development build args + args: + - BUILD_ENV=development + environment: + # Development settings + ENVIRONMENT: development + LOG_LEVEL: DEBUG + TRADING_ENABLED: false # Always false in dev + PAPER_TRADING: true # Always paper in dev + + # Hot reload + PYTHONUNBUFFERED: 1 + FLASK_ENV: development + FLASK_DEBUG: 1 + + volumes: + # Mount source code for hot reload + - ./:/app:rw + - ./logs:/app/logs:rw + - ./data:/app/data:rw + # Exclude virtual environment + - /app/venv + - /app/.venv + + # Override command for development + command: > + sh -c " + echo 'Starting in development mode...' + pip install -e . + python -m watchdog.auto_restart --directory=. --pattern='*.py' --recursive -- python autonomous_trader.py + " + + # No resource limits in development + deploy: + resources: + limits: + cpus: '4' + memory: 8G + + # === Development Database with sample data === + postgres: + environment: + POSTGRES_DB: trading_dev + POSTGRES_USER: dev_trader + POSTGRES_PASSWORD: devpass + ports: + - "5433:5432" # Different port for dev + volumes: + - ./dev/sample_data.sql:/docker-entrypoint-initdb.d/02-sample-data.sql:ro + + # === Redis with monitoring === + redis: + ports: + - "6380:6379" # Different port for dev + command: > + redis-server + --maxmemory 512mb + --maxmemory-policy allkeys-lru + --loglevel debug + + # === Redis Commander for Redis GUI === + redis-commander: + image: rediscommander/redis-commander:latest + container_name: trading_redis_commander + environment: + REDIS_HOST: redis + REDIS_PORT: 6379 + REDIS_PASSWORD: ${REDIS_PASSWORD:-changeme} + ports: + - "8081:8081" + networks: + - trading_network + depends_on: + - redis + + # === Adminer for Database GUI === + adminer: + image: adminer:latest + container_name: trading_adminer + ports: + - "8082:8080" + environment: + ADMINER_DEFAULT_SERVER: postgres + networks: + - trading_network + depends_on: + - postgres + + # === Jupyter Notebook for Research === + jupyter: + build: + context: . + dockerfile: Dockerfile + container_name: trading_jupyter + command: > + jupyter lab + --ip=0.0.0.0 + --port=8888 + --no-browser + --allow-root + --NotebookApp.token='' + --NotebookApp.password='' + volumes: + - ./:/app:rw + - ./notebooks:/app/notebooks:rw + ports: + - "8888:8888" + environment: + PYTHONPATH: /app + networks: + - trading_network + profiles: + - research \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..576f395a --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,220 @@ +version: '3.8' + +services: + # === PostgreSQL Database === + postgres: + image: timescale/timescaledb:latest-pg15 + container_name: trading_postgres + environment: + POSTGRES_DB: ${DB_NAME:-trading_db} + POSTGRES_USER: ${DB_USER:-trader} + POSTGRES_PASSWORD: ${DB_PASSWORD:-changeme} + POSTGRES_INITDB_ARGS: "--encoding=UTF-8 --lc-collate=C --lc-ctype=C" + volumes: + - postgres_data:/var/lib/postgresql/data + - ./init.sql:/docker-entrypoint-initdb.d/init.sql:ro + ports: + - "${DB_PORT:-5432}:5432" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${DB_USER:-trader} -d ${DB_NAME:-trading_db}"] + interval: 10s + timeout: 5s + retries: 5 + networks: + - trading_network + restart: unless-stopped + + # === Redis Cache === + redis: + image: redis:7-alpine + container_name: trading_redis + command: > + redis-server + --maxmemory 2gb + --maxmemory-policy allkeys-lru + --appendonly yes + --requirepass ${REDIS_PASSWORD:-changeme} + volumes: + - redis_data:/data + ports: + - "${REDIS_PORT:-6379}:6379" + healthcheck: + test: ["CMD", "redis-cli", "--raw", "incr", "ping"] + interval: 10s + timeout: 5s + retries: 5 + networks: + - trading_network + restart: unless-stopped + + # === Main Trading Application === + trader: + build: + context: . + dockerfile: Dockerfile + image: autonomous-trader:latest + container_name: trading_app + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + environment: + # Database + DB_HOST: postgres + DB_PORT: 5432 + DB_NAME: ${DB_NAME:-trading_db} + DB_USER: ${DB_USER:-trader} + DB_PASSWORD: ${DB_PASSWORD:-changeme} + + # Redis + REDIS_HOST: redis + REDIS_PORT: 6379 + REDIS_PASSWORD: ${REDIS_PASSWORD:-changeme} + + # IBKR Settings + IBKR_HOST: ${IBKR_HOST:-host.docker.internal} # Use host machine's TWS + IBKR_PORT: ${IBKR_PORT:-7497} + IBKR_CLIENT_ID: ${IBKR_CLIENT_ID:-1} + + # API Keys (from .env file) + OPENAI_API_KEY: ${OPENAI_API_KEY} + ALPHA_VANTAGE_API_KEY: ${ALPHA_VANTAGE_API_KEY} + QUIVER_API_KEY: ${QUIVER_API_KEY} + + # Notification Settings + DISCORD_WEBHOOK_URL: ${DISCORD_WEBHOOK_URL} + TELEGRAM_BOT_TOKEN: ${TELEGRAM_BOT_TOKEN} + TELEGRAM_CHAT_ID: ${TELEGRAM_CHAT_ID} + + # Trading Settings + TRADING_ENABLED: ${TRADING_ENABLED:-false} + PAPER_TRADING: ${PAPER_TRADING:-true} + + # Application Settings + ENVIRONMENT: ${ENVIRONMENT:-production} + LOG_LEVEL: ${LOG_LEVEL:-INFO} + + volumes: + - ./logs:/app/logs + - ./data:/app/data + - ./.env:/app/.env:ro + ports: + - "${APP_PORT:-8000}:8000" # API/Dashboard + - "${METRICS_PORT:-9090}:9090" # Prometheus metrics + networks: + - trading_network + restart: unless-stopped + # Resource limits + deploy: + resources: + limits: + cpus: '2' + memory: 4G + reservations: + cpus: '1' + memory: 2G + + # === Monitoring - Prometheus === + prometheus: + image: prom/prometheus:latest + container_name: trading_prometheus + command: + - '--config.file=/etc/prometheus/prometheus.yml' + - '--storage.tsdb.path=/prometheus' + - '--web.console.libraries=/usr/share/prometheus/console_libraries' + - '--web.console.templates=/usr/share/prometheus/consoles' + volumes: + - ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml:ro + - prometheus_data:/prometheus + ports: + - "${PROMETHEUS_PORT:-9091}:9090" + networks: + - trading_network + restart: unless-stopped + + # === Monitoring - Grafana === + grafana: + image: grafana/grafana:latest + container_name: trading_grafana + environment: + GF_SECURITY_ADMIN_USER: ${GRAFANA_USER:-admin} + GF_SECURITY_ADMIN_PASSWORD: ${GRAFANA_PASSWORD:-changeme} + GF_INSTALL_PLUGINS: redis-datasource + volumes: + - grafana_data:/var/lib/grafana + - ./monitoring/grafana/provisioning:/etc/grafana/provisioning:ro + ports: + - "${GRAFANA_PORT:-3000}:3000" + networks: + - trading_network + depends_on: + - prometheus + restart: unless-stopped + + # === Backup Service (Optional) === + backup: + image: postgres:15-alpine + container_name: trading_backup + environment: + PGHOST: postgres + PGPORT: 5432 + PGDATABASE: ${DB_NAME:-trading_db} + PGUSER: ${DB_USER:-trader} + PGPASSWORD: ${DB_PASSWORD:-changeme} + volumes: + - ./backups:/backups + command: > + sh -c " + while true; do + pg_dump -Fc -f /backups/backup_$$(date +%Y%m%d_%H%M%S).dump + find /backups -type f -mtime +7 -delete + sleep 86400 + done + " + networks: + - trading_network + depends_on: + postgres: + condition: service_healthy + restart: unless-stopped + profiles: + - backup + + # === Nginx Reverse Proxy (Optional) === + nginx: + image: nginx:alpine + container_name: trading_nginx + volumes: + - ./nginx/nginx.conf:/etc/nginx/nginx.conf:ro + - ./nginx/ssl:/etc/nginx/ssl:ro + ports: + - "80:80" + - "443:443" + networks: + - trading_network + depends_on: + - trader + - grafana + restart: unless-stopped + profiles: + - production + +# === Networks === +networks: + trading_network: + driver: bridge + ipam: + config: + - subnet: 172.28.0.0/16 + +# === Volumes === +volumes: + postgres_data: + driver: local + redis_data: + driver: local + prometheus_data: + driver: local + grafana_data: + driver: local \ No newline at end of file diff --git a/init.sql b/init.sql new file mode 100644 index 00000000..9f33c1f0 --- /dev/null +++ b/init.sql @@ -0,0 +1,154 @@ +-- Database initialization script for Autonomous Trading System +-- Creates necessary extensions and initial setup + +-- Enable TimescaleDB extension +CREATE EXTENSION IF NOT EXISTS timescaledb CASCADE; + +-- Enable UUID generation +CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; + +-- Enable crypto functions +CREATE EXTENSION IF NOT EXISTS pgcrypto; + +-- Create schema +CREATE SCHEMA IF NOT EXISTS trading; + +-- Set default search path +SET search_path TO trading, public; + +-- Create custom types +CREATE TYPE order_status AS ENUM ( + 'pending', + 'submitted', + 'partially_filled', + 'filled', + 'cancelled', + 'rejected', + 'failed' +); + +CREATE TYPE order_side AS ENUM ('BUY', 'SELL'); +CREATE TYPE order_type AS ENUM ('MARKET', 'LIMIT', 'STOP', 'STOP_LIMIT', 'BRACKET'); +CREATE TYPE risk_level AS ENUM ('LOW', 'MEDIUM', 'HIGH', 'CRITICAL'); + +-- Grant permissions to user +GRANT ALL PRIVILEGES ON SCHEMA trading TO trader; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA trading TO trader; +GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA trading TO trader; + +-- Create indexes for better performance +-- These will be created after tables are created by SQLAlchemy +-- Just documenting the important ones here + +COMMENT ON SCHEMA trading IS 'Autonomous Trading System Schema'; + +-- Function to update last_updated timestamps +CREATE OR REPLACE FUNCTION update_last_updated() +RETURNS TRIGGER AS $$ +BEGIN + NEW.last_updated = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Function to calculate portfolio metrics +CREATE OR REPLACE FUNCTION calculate_portfolio_metrics() +RETURNS TABLE( + total_value DECIMAL, + total_pnl DECIMAL, + position_count INTEGER, + avg_position_size DECIMAL +) AS $$ +BEGIN + RETURN QUERY + SELECT + SUM(market_value) as total_value, + SUM(unrealized_pnl + realized_pnl) as total_pnl, + COUNT(*) as position_count, + AVG(market_value) as avg_position_size + FROM positions + WHERE shares > 0; +END; +$$ LANGUAGE plpgsql; + +-- Function to get recent trades +CREATE OR REPLACE FUNCTION get_recent_trades(hours INTEGER DEFAULT 24) +RETURNS TABLE( + ticker VARCHAR, + action VARCHAR, + quantity INTEGER, + price DECIMAL, + executed_at TIMESTAMP WITH TIME ZONE +) AS $$ +BEGIN + RETURN QUERY + SELECT + t.ticker, + t.action, + t.quantity, + t.price, + t.executed_at + FROM trades t + WHERE t.executed_at >= NOW() - INTERVAL '1 hour' * hours + ORDER BY t.executed_at DESC; +END; +$$ LANGUAGE plpgsql; + +-- Create materialized view for performance metrics (updated hourly) +-- This will be created after the tables exist + +-- Set up row-level security (RLS) for multi-tenant support +-- ALTER TABLE positions ENABLE ROW LEVEL SECURITY; +-- ALTER TABLE orders ENABLE ROW LEVEL SECURITY; + +-- Create audit log table for compliance +CREATE TABLE IF NOT EXISTS audit_log ( + id SERIAL PRIMARY KEY, + table_name VARCHAR(50) NOT NULL, + operation VARCHAR(10) NOT NULL, + user_name VARCHAR(50), + changed_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + old_data JSONB, + new_data JSONB +); + +-- Audit trigger function +CREATE OR REPLACE FUNCTION audit_trigger_function() +RETURNS TRIGGER AS $$ +BEGIN + IF TG_OP = 'INSERT' THEN + INSERT INTO audit_log(table_name, operation, user_name, new_data) + VALUES (TG_TABLE_NAME, TG_OP, current_user, row_to_json(NEW)); + RETURN NEW; + ELSIF TG_OP = 'UPDATE' THEN + INSERT INTO audit_log(table_name, operation, user_name, old_data, new_data) + VALUES (TG_TABLE_NAME, TG_OP, current_user, row_to_json(OLD), row_to_json(NEW)); + RETURN NEW; + ELSIF TG_OP = 'DELETE' THEN + INSERT INTO audit_log(table_name, operation, user_name, old_data) + VALUES (TG_TABLE_NAME, TG_OP, current_user, row_to_json(OLD)); + RETURN OLD; + END IF; + RETURN NULL; +END; +$$ LANGUAGE plpgsql; + +-- Performance settings +ALTER SYSTEM SET shared_buffers = '256MB'; +ALTER SYSTEM SET effective_cache_size = '1GB'; +ALTER SYSTEM SET maintenance_work_mem = '64MB'; +ALTER SYSTEM SET checkpoint_completion_target = 0.9; +ALTER SYSTEM SET wal_buffers = '16MB'; +ALTER SYSTEM SET default_statistics_target = 100; +ALTER SYSTEM SET random_page_cost = 1.1; + +-- Reload configuration +SELECT pg_reload_conf(); + +-- Create initial admin notification +DO $$ +BEGIN + RAISE NOTICE 'Database initialization complete for Autonomous Trading System'; + RAISE NOTICE 'TimescaleDB enabled for time-series optimization'; + RAISE NOTICE 'Audit logging configured for compliance'; +END $$; \ No newline at end of file