""" Security Validation Layer ========================= Comprehensive input validation, sanitization, and security checks to prevent injection attacks and ensure data integrity. """ import re import logging import hashlib import hmac import secrets from typing import Any, Dict, List, Optional, Union, Type from datetime import datetime, timedelta from decimal import Decimal, InvalidOperation from enum import Enum import json from pydantic import ( BaseModel, Field, validator, root_validator, ValidationError, constr, condecimal, conint ) from typing_extensions import Annotated logger = logging.getLogger(__name__) # === Custom Types with Validation === # Ticker symbol: 1-10 uppercase letters/numbers, no special chars TickerSymbol = Annotated[ str, constr( regex=r'^[A-Z0-9]{1,10}$', strip_whitespace=True, to_upper=True ) ] # Price: positive decimal with max 2 decimal places Price = Annotated[ Decimal, condecimal( gt=0, max_digits=10, decimal_places=2 ) ] # Quantity: positive integer within reasonable bounds Quantity = Annotated[ int, conint( gt=0, le=1000000 # Max 1 million shares ) ] # Percentage: 0-100 Percentage = Annotated[ float, Field(ge=0.0, le=100.0) ] class SecurityLevel(str, Enum): """Security validation levels""" LOW = "low" MEDIUM = "medium" HIGH = "high" CRITICAL = "critical" # === Input Validators === class TickerValidator(BaseModel): """Validator for ticker symbols""" ticker: TickerSymbol @validator('ticker') def validate_ticker(cls, v): """Additional ticker validation""" # Check against blacklist of invalid tickers blacklist = ['TEST', 'DUMMY', 'NULL', 'UNDEFINED'] if v in blacklist: raise ValueError(f"Invalid ticker: {v}") # Check for SQL injection patterns if cls._contains_sql_injection(v): raise ValueError("Potential SQL injection detected") return v @staticmethod def _contains_sql_injection(value: str) -> bool: """Check for SQL injection patterns""" sql_patterns = [ r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|UNION|CREATE|ALTER)\b)", r"(-{2}|\/\*|\*\/)", # SQL comments r"(;|\||&&)", # Command separators r"(\bOR\b.*=.*)", # OR conditions r"('|\")", # Quotes ] for pattern in sql_patterns: if re.search(pattern, value, re.IGNORECASE): return True return False class OrderValidator(BaseModel): """Comprehensive order validation""" ticker: TickerSymbol side: str = Field(regex=r'^(BUY|SELL)$') quantity: Quantity order_type: str = Field(regex=r'^(MARKET|LIMIT|STOP|STOP_LIMIT)$') limit_price: Optional[Price] = None stop_price: Optional[Price] = None time_in_force: str = Field( default="DAY", regex=r'^(DAY|GTC|IOC|FOK)$' ) account_id: Optional[constr(max_length=50)] = None notes: Optional[constr(max_length=500)] = None @root_validator def validate_prices(cls, values): """Validate price requirements based on order type""" order_type = values.get('order_type') limit_price = values.get('limit_price') stop_price = values.get('stop_price') if order_type == 'LIMIT' and not limit_price: raise ValueError("Limit price required for LIMIT orders") if order_type in ['STOP', 'STOP_LIMIT'] and not stop_price: raise ValueError("Stop price required for STOP orders") if order_type == 'STOP_LIMIT' and not limit_price: raise ValueError("Limit price required for STOP_LIMIT orders") # Check for unreasonable prices if limit_price and limit_price > 100000: raise ValueError(f"Limit price ${limit_price} exceeds maximum") if stop_price and limit_price: side = values.get('side') if side == 'BUY' and stop_price < limit_price: raise ValueError("Stop price must be above limit for buy stop orders") elif side == 'SELL' and stop_price > limit_price: raise ValueError("Stop price must be below limit for sell stop orders") return values @validator('notes') def sanitize_notes(cls, v): """Sanitize notes field""" if v: # Remove potential XSS/injection content v = cls._sanitize_string(v) return v @staticmethod def _sanitize_string(value: str) -> str: """Remove dangerous characters from string""" # Remove HTML/Script tags value = re.sub(r'<[^>]*>', '', value) # Remove JavaScript value = re.sub(r'javascript:', '', value, flags=re.IGNORECASE) # Remove SQL keywords sql_keywords = ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'EXEC', 'UNION'] for keyword in sql_keywords: value = re.sub(rf'\b{keyword}\b', '', value, flags=re.IGNORECASE) return value.strip() class ConfigValidator(BaseModel): """Validator for configuration settings""" max_position_size: Percentage max_daily_loss: Percentage max_orders_per_day: conint(gt=0, le=1000) confidence_threshold: Percentage stop_loss_percent: Percentage api_keys: Dict[str, str] = Field(default_factory=dict) @validator('api_keys') def validate_api_keys(cls, v): """Validate API key format""" for key_name, key_value in v.items(): # Check for exposed secrets if cls._is_placeholder(key_value): raise ValueError(f"Invalid API key for {key_name}") # Check key format (basic validation) if len(key_value) < 10: raise ValueError(f"API key {key_name} is too short") # Check for common test keys if key_value in ['test', 'demo', '12345', 'password']: raise ValueError(f"Invalid API key for {key_name}") return v @staticmethod def _is_placeholder(value: str) -> bool: """Check if value is a placeholder""" placeholders = [ 'your_key_here', 'placeholder', 'xxxx', 'todo', 'changeme' ] return any(p in value.lower() for p in placeholders) class WebhookValidator(BaseModel): """Validator for webhook URLs""" url: constr( regex=r'^https:\/\/(discord\.com|hooks\.slack\.com|api\.telegram\.org)\/.*', max_length=500 ) enabled: bool = True @validator('url') def validate_webhook_url(cls, v): """Validate webhook URL security""" # Check for localhost/internal IPs (SSRF prevention) internal_patterns = [ r'localhost', r'127\.0\.0\.1', r'0\.0\.0\.0', r'192\.168\.', r'10\.', r'172\.(1[6-9]|2[0-9]|3[0-1])\.' ] for pattern in internal_patterns: if re.search(pattern, v, re.IGNORECASE): raise ValueError("Webhook URL cannot point to internal network") return v # === Request Signing & Verification === class RequestSigner: """Sign and verify requests for authentication""" def __init__(self, secret_key: str): """ Initialize request signer Args: secret_key: Secret key for signing """ self.secret_key = secret_key.encode('utf-8') def sign_request(self, data: Dict[str, Any]) -> str: """ Sign a request payload Args: data: Request data Returns: Signature string """ # Sort keys for consistent signing sorted_data = json.dumps(data, sort_keys=True) # Create HMAC signature signature = hmac.new( self.secret_key, sorted_data.encode('utf-8'), hashlib.sha256 ).hexdigest() return signature def verify_request(self, data: Dict[str, Any], signature: str) -> bool: """ Verify a request signature Args: data: Request data signature: Provided signature Returns: True if signature is valid """ expected_signature = self.sign_request(data) # Use constant-time comparison to prevent timing attacks return hmac.compare_digest(expected_signature, signature) # === Rate Limiting === class RateLimiter: """Rate limiting for API endpoints""" def __init__(self): self.requests: Dict[str, List[datetime]] = {} def check_rate_limit(self, identifier: str, max_requests: int = 100, window_seconds: int = 60) -> bool: """ Check if request is within rate limit Args: identifier: Client identifier (IP, API key, etc.) max_requests: Maximum requests allowed window_seconds: Time window in seconds Returns: True if within limit """ now = datetime.now() window_start = now - timedelta(seconds=window_seconds) # Get request history if identifier not in self.requests: self.requests[identifier] = [] # Remove old requests self.requests[identifier] = [ req_time for req_time in self.requests[identifier] if req_time > window_start ] # Check limit if len(self.requests[identifier]) >= max_requests: return False # Add current request self.requests[identifier].append(now) return True # === Secure Configuration === class SecureConfig: """Secure configuration management""" def __init__(self, config_data: Dict[str, Any]): """ Initialize secure config Args: config_data: Configuration dictionary """ self.config = self._sanitize_config(config_data) self._validate_security_settings() def _sanitize_config(self, config: Dict[str, Any]) -> Dict[str, Any]: """Sanitize configuration data""" sanitized = {} for key, value in config.items(): # Skip sensitive keys from logs if any(sensitive in key.lower() for sensitive in ['password', 'secret', 'key', 'token']): # Don't include actual value in sanitized version sanitized[key] = "***REDACTED***" if value else None else: if isinstance(value, str): # Sanitize strings sanitized[key] = self._sanitize_value(value) elif isinstance(value, dict): # Recursively sanitize nested dicts sanitized[key] = self._sanitize_config(value) else: sanitized[key] = value return sanitized def _sanitize_value(self, value: str) -> str: """Sanitize a configuration value""" # Remove potential command injection dangerous_chars = [';', '|', '&', '$', '`', '\\', '\n', '\r'] for char in dangerous_chars: value = value.replace(char, '') # Remove path traversal value = value.replace('../', '').replace('..\\', '') return value def _validate_security_settings(self): """Validate security-critical settings""" # Check for secure defaults if self.config.get('ssl_enabled', True) is False: logger.warning("SSL is disabled - this is insecure!") if self.config.get('debug_mode', False) is True: logger.warning("Debug mode is enabled - disable in production!") if self.config.get('allow_all_origins', False) is True: logger.warning("CORS allow_all_origins is enabled - security risk!") # === API Security === class APISecurityValidator: """Validator for API security""" @staticmethod def validate_api_key(api_key: str) -> bool: """ Validate API key format and strength Args: api_key: API key to validate Returns: True if valid """ # Check length if len(api_key) < 32: return False # Check for common patterns if api_key.startswith('sk_test_') or api_key.startswith('pk_test_'): logger.warning("Test API key detected") # Check entropy (simplified) unique_chars = len(set(api_key)) if unique_chars < 10: return False # Low entropy return True @staticmethod def generate_api_key() -> str: """ Generate a secure API key Returns: Secure API key """ # Generate 32 bytes of random data random_bytes = secrets.token_bytes(32) # Convert to hex string api_key = f"sk_live_{random_bytes.hex()}" return api_key @staticmethod def hash_api_key(api_key: str) -> str: """ Hash an API key for storage Args: api_key: API key to hash Returns: Hashed API key """ # Use SHA-256 for hashing return hashlib.sha256(api_key.encode('utf-8')).hexdigest() # === XSS Prevention === class XSSPrevention: """Cross-site scripting prevention""" @staticmethod def sanitize_html(text: str) -> str: """ Sanitize HTML content Args: text: Text to sanitize Returns: Sanitized text """ # HTML entity encoding html_escapes = { '<': '<', '>': '>', '"': '"', "'": ''', '&': '&', '/': '/', '`': '`', '=': '=' } for char, escape in html_escapes.items(): text = text.replace(char, escape) return text @staticmethod def sanitize_json(data: Dict[str, Any]) -> Dict[str, Any]: """ Sanitize JSON data Args: data: JSON data Returns: Sanitized data """ sanitized = {} for key, value in data.items(): if isinstance(value, str): sanitized[key] = XSSPrevention.sanitize_html(value) elif isinstance(value, dict): sanitized[key] = XSSPrevention.sanitize_json(value) elif isinstance(value, list): sanitized[key] = [ XSSPrevention.sanitize_html(item) if isinstance(item, str) else item for item in value ] else: sanitized[key] = value return sanitized # === Composite Security Validator === class SecurityValidator: """Main security validator combining all checks""" def __init__(self, security_level: SecurityLevel = SecurityLevel.HIGH): """ Initialize security validator Args: security_level: Security validation level """ self.security_level = security_level self.rate_limiter = RateLimiter() self.request_signer = None def validate_order(self, order_data: Dict[str, Any]) -> Dict[str, Any]: """ Validate and sanitize order data Args: order_data: Raw order data Returns: Validated order data Raises: ValidationError: If validation fails """ try: # Validate with Pydantic validated = OrderValidator(**order_data) # Additional security checks for high security if self.security_level in [SecurityLevel.HIGH, SecurityLevel.CRITICAL]: # Check for suspicious patterns if self._is_suspicious_order(validated.dict()): raise ValueError("Order flagged as suspicious") return validated.dict() except ValidationError as e: logger.error(f"Order validation failed: {e}") raise def validate_config(self, config_data: Dict[str, Any]) -> Dict[str, Any]: """ Validate configuration data Args: config_data: Raw configuration Returns: Validated configuration Raises: ValidationError: If validation fails """ try: validated = ConfigValidator(**config_data) return validated.dict() except ValidationError as e: logger.error(f"Config validation failed: {e}") raise def _is_suspicious_order(self, order: Dict[str, Any]) -> bool: """ Check for suspicious order patterns Args: order: Order data Returns: True if suspicious """ # Check for unusual quantity if order['quantity'] > 10000: logger.warning(f"Large order quantity: {order['quantity']}") return True # Check for price manipulation attempts if order.get('limit_price'): # Check for penny stock manipulation if order['limit_price'] < 1 and order['quantity'] > 1000: logger.warning("Potential penny stock manipulation") return True return False def sanitize_user_input(self, input_data: Any) -> Any: """ Sanitize any user input Args: input_data: User input Returns: Sanitized input """ if isinstance(input_data, str): # Remove dangerous characters input_data = re.sub(r'[<>&\'"`]', '', input_data) # Truncate to reasonable length input_data = input_data[:1000] elif isinstance(input_data, dict): input_data = XSSPrevention.sanitize_json(input_data) return input_data # === Example Usage === def main(): """Example of using security validators""" # Initialize validator validator = SecurityValidator(SecurityLevel.HIGH) # Validate order order_data = { "ticker": "AAPL", "side": "BUY", "quantity": 100, "order_type": "LIMIT", "limit_price": "150.50", "notes": "Test order " } try: validated_order = validator.validate_order(order_data) print(f"Validated order: {validated_order}") except ValidationError as e: print(f"Validation failed: {e}") # Generate secure API key api_key = APISecurityValidator.generate_api_key() print(f"Generated API key: {api_key}") # Hash for storage hashed = APISecurityValidator.hash_api_key(api_key) print(f"Hashed key: {hashed}") # Rate limiting rate_limiter = RateLimiter() for i in range(5): allowed = rate_limiter.check_rate_limit("user123", max_requests=3, window_seconds=10) print(f"Request {i+1}: {'Allowed' if allowed else 'Blocked'}") if __name__ == "__main__": main()