TradingAgents/tradingagents/security/validators.py

264 lines
7.2 KiB
Python

"""
Input validation and sanitization functions.
"""
import re
from datetime import datetime
from typing import Optional
import os
def validate_ticker(ticker: str, max_length: int = 10) -> str:
    """
    Validate and sanitize stock ticker symbol.

    Surrounding whitespace is stripped and the symbol is upper-cased, so
    ``" nvda "`` is accepted and returned as ``"NVDA"``. The character
    check runs on the stripped input *before* upper-casing so that
    non-ASCII letters whose uppercase form is ASCII (for example the
    German sharp s) cannot slip through.

    Args:
        ticker: Ticker symbol to validate
        max_length: Maximum allowed length for ticker

    Returns:
        Sanitized ticker symbol in uppercase

    Raises:
        ValueError: If ticker is empty, not a string, too long, contains
            characters other than ASCII letters, digits, dots and hyphens,
            contains no letter or digit at all, or looks like a path
            traversal attempt

    Examples:
        >>> validate_ticker("AAPL")
        'AAPL'
        >>> validate_ticker("nvda")
        'NVDA'
        >>> validate_ticker("../etc/passwd")
        Traceback (most recent call last):
        ValueError: Invalid ticker symbol...
    """
    if not ticker:
        raise ValueError("Ticker symbol cannot be empty")
    if not isinstance(ticker, str):
        raise ValueError("Ticker symbol must be a string")

    # Remove whitespace
    ticker = ticker.strip()

    # Whitespace-only input is empty once stripped
    if not ticker:
        raise ValueError("Ticker symbol cannot be empty")

    # Check length
    if len(ticker) > max_length:
        raise ValueError(f"Ticker symbol too long (max {max_length} characters)")

    # Only allow alphanumeric characters, dots, and hyphens (common in international tickers)
    # Examples: AAPL, BRK.A, RDS-B
    # Checked before upper-casing: str.upper() can turn non-ASCII letters
    # into ASCII ones, which would otherwise pass this pattern.
    if not re.fullmatch(r'[A-Za-z0-9.-]+', ticker):
        raise ValueError(
            "Invalid ticker symbol. Only alphanumeric characters, dots, and hyphens are allowed"
        )

    ticker = ticker.upper()

    # Prevent path traversal
    if '..' in ticker or '/' in ticker or '\\' in ticker:
        raise ValueError("Invalid ticker symbol: path traversal detected")

    # A symbol made only of punctuation ("." or "-") is not a ticker, and
    # "." would resolve to the current directory if used in a path.
    if not any(ch.isalnum() for ch in ticker):
        raise ValueError(
            "Invalid ticker symbol: must contain at least one letter or digit"
        )

    return ticker
def validate_date(date_str: str, allow_future: bool = False) -> str:
    """
    Validate date string.

    The input is parsed with ``datetime.strptime`` and the canonical
    zero-padded ``YYYY-MM-DD`` form of the parsed date is returned.
    ``strptime`` tolerates some non-canonical spellings (for example
    ``2024-1-5``), so returning the re-formatted date rather than the raw
    input guarantees the result always has exactly the documented shape.

    Args:
        date_str: Date string in YYYY-MM-DD format
        allow_future: Whether to allow future dates

    Returns:
        Validated date string in canonical YYYY-MM-DD form

    Raises:
        ValueError: If date is empty, not a string, not parseable, in the
            future (unless ``allow_future`` is true), or before 1900

    Examples:
        >>> validate_date("2024-01-15")
        '2024-01-15'
        >>> validate_date("2024-13-01")
        Traceback (most recent call last):
        ValueError: Invalid date format...
    """
    if not date_str:
        raise ValueError("Date cannot be empty")
    if not isinstance(date_str, str):
        raise ValueError("Date must be a string")

    # Remove whitespace
    date_str = date_str.strip()

    # Validate format and parse
    try:
        parsed_date = datetime.strptime(date_str, "%Y-%m-%d").date()
    except ValueError as e:
        # Chain the cause so the underlying parse error stays in the traceback
        raise ValueError(f"Invalid date format. Use YYYY-MM-DD: {e}") from e

    # Check if date is in the future (compared against the local calendar date)
    if not allow_future and parsed_date > datetime.now().date():
        raise ValueError("Date cannot be in the future")

    # Check if date is too far in the past (before stock markets existed)
    if parsed_date.year < 1900:
        raise ValueError("Date cannot be before 1900")

    # The returned string is generated from the parsed date, so it can only
    # contain digits and hyphens; no separate path-traversal check is needed.
    return parsed_date.isoformat()
def sanitize_path_component(value: str, max_length: int = 255) -> str:
    """
    Sanitize a value for safe use in file paths.

    Path separators and NUL bytes are removed first, then every ``..``
    sequence, then any remaining character outside ``[a-zA-Z0-9_.-]`` is
    replaced by an underscore, and finally leading/trailing dots and
    dashes are stripped. Removing separators before ``..`` matters:
    otherwise deleting a separator could join two single dots into a new
    ``..`` that survives in the result.

    Args:
        value: Value to sanitize (non-string values are converted with ``str``)
        max_length: Maximum allowed length

    Returns:
        Sanitized value safe for use in file paths

    Raises:
        ValueError: If the value is empty, becomes empty after
            sanitization, or is longer than ``max_length`` after
            sanitization

    Examples:
        >>> sanitize_path_component("AAPL")
        'AAPL'
        >>> sanitize_path_component("../../../etc/passwd")
        'etcpasswd'
        >>> sanitize_path_component("2024-01-15")
        '2024-01-15'
    """
    if not value:
        raise ValueError("Path component cannot be empty")
    if not isinstance(value, str):
        value = str(value)

    # Remove path separators and null bytes first, so that their removal
    # cannot re-create a ".." sequence after the traversal pass below.
    value = value.replace('/', '').replace('\\', '').replace('\0', '')

    # Remove path traversal attempts
    value = value.replace('..', '')

    # Allow only safe characters: alphanumeric, dash, underscore, dot
    # This allows dates (2024-01-15) and tickers (AAPL, BRK.A)
    value = re.sub(r'[^a-zA-Z0-9_.-]', '_', value)

    # Remove leading/trailing dots or dashes
    value = value.strip('.-')

    # Check length
    if len(value) > max_length:
        raise ValueError(f"Path component too long (max {max_length} characters)")
    if not value:
        raise ValueError("Path component cannot be empty after sanitization")

    return value
def validate_api_key(api_key: Optional[str], key_name: str = "API_KEY") -> str:
    """
    Validate that an API key is set and not empty.

    Surrounding whitespace is stripped. Any whitespace remaining inside
    the key (spaces, tabs, newlines) is rejected: such a key is certainly
    malformed, and an embedded newline could be abused for header
    injection if the key is later placed in an HTTP header.

    Args:
        api_key: API key to validate
        key_name: Name of the API key (for error messages)

    Returns:
        The validated API key with surrounding whitespace removed

    Raises:
        ValueError: If API key is not set, empty, not a string, or
            contains whitespace

    Warns:
        UserWarning: If the key is shorter than 10 characters

    Examples:
        >>> validate_api_key("sk-1234567890", "OPENAI_API_KEY")
        'sk-1234567890'
        >>> validate_api_key(None, "OPENAI_API_KEY")
        Traceback (most recent call last):
        ValueError: OPENAI_API_KEY is not set...
    """
    if not api_key:
        raise ValueError(
            f"{key_name} is not set. "
            f"Please set it in your .env file or environment variables."
        )
    if not isinstance(api_key, str):
        raise ValueError(f"{key_name} must be a string")

    # Remove whitespace
    api_key = api_key.strip()
    if not api_key:
        raise ValueError(f"{key_name} cannot be empty")

    # Warn if API key looks suspicious (too short)
    if len(api_key) < 10:
        import warnings
        warnings.warn(
            f"{key_name} seems unusually short. Please verify it's correct.",
            UserWarning,
            stacklevel=2,  # attribute the warning to the caller, not this helper
        )

    # Reject any embedded whitespace, not only the space character
    if any(ch.isspace() for ch in api_key):
        raise ValueError(f"{key_name} should not contain spaces")

    return api_key
def validate_url(url: str, allowed_schemes: Optional[list] = None) -> str:
    """
    Validate URL to prevent SSRF and other URL-based attacks.

    The scheme must be in ``allowed_schemes``. If the URL has a hostname,
    it must not be a well-known local name, and, when it is a literal IP
    address, it must not be private, loopback, link-local or unspecified.
    IPv4-mapped IPv6 literals are checked against the embedded IPv4
    address.

    Note: hostnames are not resolved, so a DNS name that points at an
    internal address is not detected here.

    Args:
        url: URL to validate
        allowed_schemes: List of allowed URL schemes (default: ['http', 'https'])

    Returns:
        Validated URL (unchanged)

    Raises:
        ValueError: If URL is empty, cannot be parsed, uses a disallowed
            scheme, or targets a local/private address
    """
    from urllib.parse import urlparse
    import ipaddress

    if allowed_schemes is None:
        allowed_schemes = ['http', 'https']
    if not url:
        raise ValueError("URL cannot be empty")

    try:
        parsed = urlparse(url)
    except Exception as e:
        raise ValueError(f"Invalid URL: {e}") from e

    # Check scheme
    if parsed.scheme not in allowed_schemes:
        raise ValueError(
            f"Invalid URL scheme: {parsed.scheme}. "
            f"Allowed schemes: {', '.join(allowed_schemes)}"
        )

    # Prevent localhost/private IP access (SSRF protection)
    if parsed.hostname:
        # A trailing dot denotes the same host ("localhost." == "localhost")
        host = parsed.hostname.lower().rstrip('.')

        # Block common private network hostnames
        private_hostnames = ('localhost', '127.0.0.1', '0.0.0.0', '::1')
        if host in private_hostnames:
            raise ValueError("Access to localhost is not allowed")

        # Only the parsing sits inside the try block: the rejection below
        # also raises ValueError and must not be swallowed by the handler.
        try:
            ip = ipaddress.ip_address(host)
        except ValueError:
            # Not an IP address, that's fine
            ip = None

        if ip is not None:
            # ::ffff:a.b.c.d reaches the IPv4 host a.b.c.d, so judge that
            mapped = getattr(ip, 'ipv4_mapped', None)
            if mapped is not None:
                ip = mapped
            if (
                ip.is_private
                or ip.is_loopback
                or ip.is_link_local
                or ip.is_unspecified
            ):
                raise ValueError("Access to private/loopback addresses is not allowed")

    return url