"""Data caching layer for vendor data with rate limit awareness. This module provides a robust caching layer to handle API rate limits across all data vendors. Features: - Multi-backend support (memory, file, SQLite) - TTL-based expiration with configurable per-source TTLs - Rate limit tracking and backoff - Cache statistics and monitoring - Atomic cache operations for thread safety Issue #12: [DATA-11] Data caching layer - FRED rate limits """ import hashlib import json import logging import sqlite3 import threading import time from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime, timedelta from enum import Enum, auto from pathlib import Path from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic logger = logging.getLogger(__name__) T = TypeVar('T') class CacheStatus(Enum): """Status of a cache lookup.""" HIT = auto() MISS = auto() EXPIRED = auto() STALE = auto() # Expired but returned due to rate limit @dataclass class CacheEntry(Generic[T]): """A single cache entry with metadata.""" key: str value: T created_at: datetime expires_at: datetime access_count: int = 0 last_accessed: Optional[datetime] = None source: str = "" metadata: Dict[str, Any] = field(default_factory=dict) @property def is_expired(self) -> bool: """Check if entry is expired.""" return datetime.now() > self.expires_at @property def age_seconds(self) -> float: """Get age in seconds.""" return (datetime.now() - self.created_at).total_seconds() @property def ttl_remaining_seconds(self) -> float: """Get remaining TTL in seconds.""" return max(0, (self.expires_at - datetime.now()).total_seconds()) def touch(self) -> None: """Update access metadata.""" self.access_count += 1 self.last_accessed = datetime.now() @dataclass class CacheStats: """Statistics for cache operations.""" hits: int = 0 misses: int = 0 expired: int = 0 stale_served: int = 0 evictions: int = 0 size: int = 0 @property def hit_rate(self) -> float: """Calculate hit rate as percentage.""" total = self.hits + self.misses if total == 0: return 0.0 return (self.hits / total) * 100 def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return { "hits": self.hits, "misses": self.misses, "expired": self.expired, "stale_served": self.stale_served, "evictions": self.evictions, "size": self.size, "hit_rate": self.hit_rate } @dataclass class RateLimitState: """Track rate limit state for a source.""" source: str requests_made: int = 0 requests_limit: int = 120 # Default FRED limit window_start: datetime = field(default_factory=datetime.now) window_seconds: int = 60 backoff_until: Optional[datetime] = None consecutive_failures: int = 0 @property def is_rate_limited(self) -> bool: """Check if currently rate limited.""" if self.backoff_until and datetime.now() < self.backoff_until: return True return False @property def requests_remaining(self) -> int: """Get remaining requests in current window.""" self._maybe_reset_window() return max(0, self.requests_limit - self.requests_made) def _maybe_reset_window(self) -> None: """Reset window if expired.""" if (datetime.now() - self.window_start).total_seconds() > self.window_seconds: self.window_start = datetime.now() self.requests_made = 0 def record_request(self) -> None: """Record a request.""" self._maybe_reset_window() self.requests_made += 1 def record_rate_limit(self, backoff_seconds: int = 60) -> None: """Record a rate limit hit.""" self.consecutive_failures += 1 # Exponential backoff actual_backoff = backoff_seconds * (2 ** (self.consecutive_failures - 1)) self.backoff_until = datetime.now() + timedelta(seconds=actual_backoff) logger.warning(f"Rate limit hit for {self.source}, backing off for {actual_backoff}s") def record_success(self) -> None: """Record successful request.""" self.consecutive_failures = 0 self.backoff_until = None class CacheBackend(ABC): """Abstract base class for cache backends.""" @abstractmethod def get(self, key: str) -> Optional[CacheEntry]: """Get entry from cache.""" pass @abstractmethod def set(self, entry: CacheEntry) -> None: """Set entry in cache.""" pass @abstractmethod def delete(self, key: str) -> bool: """Delete entry from cache.""" pass @abstractmethod def clear(self) -> int: """Clear all entries. Returns number cleared.""" pass @abstractmethod def keys(self) -> List[str]: """Get all cache keys.""" pass @abstractmethod def size(self) -> int: """Get number of entries.""" pass class MemoryCache(CacheBackend): """In-memory cache with LRU eviction.""" def __init__(self, max_size: int = 1000): """Initialize memory cache. Args: max_size: Maximum number of entries """ self._cache: Dict[str, CacheEntry] = {} self._max_size = max_size self._lock = threading.RLock() self._access_order: List[str] = [] def get(self, key: str) -> Optional[CacheEntry]: """Get entry from cache.""" with self._lock: entry = self._cache.get(key) if entry: # Update access order for LRU if key in self._access_order: self._access_order.remove(key) self._access_order.append(key) return entry def set(self, entry: CacheEntry) -> None: """Set entry in cache with LRU eviction.""" with self._lock: # Evict if at capacity while len(self._cache) >= self._max_size and self._access_order: oldest_key = self._access_order.pop(0) self._cache.pop(oldest_key, None) self._cache[entry.key] = entry # Update access order if entry.key in self._access_order: self._access_order.remove(entry.key) self._access_order.append(entry.key) def delete(self, key: str) -> bool: """Delete entry from cache.""" with self._lock: if key in self._cache: del self._cache[key] if key in self._access_order: self._access_order.remove(key) return True return False def clear(self) -> int: """Clear all entries.""" with self._lock: count = len(self._cache) self._cache.clear() self._access_order.clear() return count def keys(self) -> List[str]: """Get all cache keys.""" with self._lock: return list(self._cache.keys()) def size(self) -> int: """Get number of entries.""" with self._lock: return len(self._cache) class FileCache(CacheBackend): """File-based cache using JSON serialization.""" def __init__(self, cache_dir: Optional[Path] = None): """Initialize file cache. Args: cache_dir: Directory for cache files """ self._cache_dir = cache_dir or Path.home() / ".cache" / "tradingagents" self._cache_dir.mkdir(parents=True, exist_ok=True) self._lock = threading.RLock() def _get_path(self, key: str) -> Path: """Get file path for key.""" # Use hash to avoid filesystem issues safe_key = hashlib.md5(key.encode()).hexdigest() return self._cache_dir / f"{safe_key}.json" def get(self, key: str) -> Optional[CacheEntry]: """Get entry from file cache.""" path = self._get_path(key) if not path.exists(): return None with self._lock: try: with open(path, 'r') as f: data = json.load(f) return CacheEntry( key=data['key'], value=data['value'], created_at=datetime.fromisoformat(data['created_at']), expires_at=datetime.fromisoformat(data['expires_at']), access_count=data.get('access_count', 0), last_accessed=datetime.fromisoformat(data['last_accessed']) if data.get('last_accessed') else None, source=data.get('source', ''), metadata=data.get('metadata', {}) ) except (json.JSONDecodeError, KeyError, ValueError): # Corrupted file path.unlink(missing_ok=True) return None def set(self, entry: CacheEntry) -> None: """Set entry in file cache.""" path = self._get_path(entry.key) with self._lock: data = { 'key': entry.key, 'value': entry.value, 'created_at': entry.created_at.isoformat(), 'expires_at': entry.expires_at.isoformat(), 'access_count': entry.access_count, 'last_accessed': entry.last_accessed.isoformat() if entry.last_accessed else None, 'source': entry.source, 'metadata': entry.metadata } with open(path, 'w') as f: json.dump(data, f) def delete(self, key: str) -> bool: """Delete entry from file cache.""" path = self._get_path(key) with self._lock: if path.exists(): path.unlink() return True return False def clear(self) -> int: """Clear all entries.""" with self._lock: count = 0 for path in self._cache_dir.glob("*.json"): path.unlink() count += 1 return count def keys(self) -> List[str]: """Get all cache keys (returns hashed keys).""" return [p.stem for p in self._cache_dir.glob("*.json")] def size(self) -> int: """Get number of entries.""" return len(list(self._cache_dir.glob("*.json"))) class DataCache: """Main data cache with rate limit awareness. Provides caching for vendor data with configurable TTLs, rate limit tracking, and stale-while-revalidate support. Example: cache = DataCache() # Cache with default TTL cache.set("fred:FEDFUNDS", data, source="fred") # Get with stale fallback if rate limited result = cache.get("fred:FEDFUNDS", serve_stale_if_rate_limited=True) # Use as decorator @cache.cached(ttl_seconds=3600, source="fred") def get_fred_data(series_id): return fetch_from_api(series_id) """ # Default TTLs by source (in seconds) DEFAULT_TTLS = { "fred": 3600 * 24, # 24 hours for FRED (data updates daily) "yfinance": 60, # 1 minute for real-time quotes "finnhub": 60, # 1 minute for real-time data "polygon": 300, # 5 minutes "alpha_vantage": 300, # 5 minutes "default": 300 # 5 minutes default } # Default rate limits by source DEFAULT_RATE_LIMITS = { "fred": (120, 60), # 120 requests per 60 seconds "yfinance": (2000, 60), # High limit (throttles internally) "finnhub": (60, 60), # 60 per minute "polygon": (5, 60), # 5 per minute (free tier) "alpha_vantage": (5, 60), # 5 per minute (free tier) "default": (100, 60) } def __init__( self, backend: Optional[CacheBackend] = None, default_ttl_seconds: int = 300 ): """Initialize data cache. Args: backend: Cache backend (defaults to MemoryCache) default_ttl_seconds: Default TTL for entries """ self._backend = backend or MemoryCache() self._default_ttl = default_ttl_seconds self._stats = CacheStats() self._rate_limits: Dict[str, RateLimitState] = {} self._lock = threading.RLock() def _generate_key(self, key: str, **kwargs) -> str: """Generate cache key from key and optional params.""" if kwargs: params_str = json.dumps(kwargs, sort_keys=True) return f"{key}:{hashlib.md5(params_str.encode()).hexdigest()[:8]}" return key def _get_ttl(self, source: str) -> int: """Get TTL for a source.""" return self.DEFAULT_TTLS.get(source, self.DEFAULT_TTLS["default"]) def _get_rate_limit_state(self, source: str) -> RateLimitState: """Get or create rate limit state for source.""" if source not in self._rate_limits: limit, window = self.DEFAULT_RATE_LIMITS.get( source, self.DEFAULT_RATE_LIMITS["default"] ) self._rate_limits[source] = RateLimitState( source=source, requests_limit=limit, window_seconds=window ) return self._rate_limits[source] def get( self, key: str, serve_stale_if_rate_limited: bool = True, **kwargs ) -> tuple[Optional[Any], CacheStatus]: """Get value from cache. Args: key: Cache key serve_stale_if_rate_limited: Return expired value if rate limited **kwargs: Additional key params Returns: Tuple of (value, status) """ full_key = self._generate_key(key, **kwargs) with self._lock: entry = self._backend.get(full_key) if entry is None: self._stats.misses += 1 return None, CacheStatus.MISS if not entry.is_expired: entry.touch() self._backend.set(entry) # Update metadata self._stats.hits += 1 return entry.value, CacheStatus.HIT # Entry is expired self._stats.expired += 1 # Check if we should serve stale if serve_stale_if_rate_limited: rate_state = self._get_rate_limit_state(entry.source) if rate_state.is_rate_limited: self._stats.stale_served += 1 return entry.value, CacheStatus.STALE return None, CacheStatus.EXPIRED def set( self, key: str, value: Any, ttl_seconds: Optional[int] = None, source: str = "default", metadata: Optional[Dict[str, Any]] = None, **kwargs ) -> None: """Set value in cache. Args: key: Cache key value: Value to cache ttl_seconds: TTL in seconds (uses source default if not specified) source: Data source name metadata: Optional metadata **kwargs: Additional key params """ full_key = self._generate_key(key, **kwargs) actual_ttl = ttl_seconds if ttl_seconds is not None else self._get_ttl(source) entry = CacheEntry( key=full_key, value=value, created_at=datetime.now(), expires_at=datetime.now() + timedelta(seconds=actual_ttl), source=source, metadata=metadata or {} ) with self._lock: self._backend.set(entry) self._stats.size = self._backend.size() def delete(self, key: str, **kwargs) -> bool: """Delete value from cache.""" full_key = self._generate_key(key, **kwargs) with self._lock: result = self._backend.delete(full_key) self._stats.size = self._backend.size() return result def clear(self, source: Optional[str] = None) -> int: """Clear cache entries. Args: source: Clear only entries from this source (None = all) Returns: Number of entries cleared """ with self._lock: if source is None: count = self._backend.clear() else: count = 0 for key in self._backend.keys(): entry = self._backend.get(key) if entry and entry.source == source: self._backend.delete(key) count += 1 self._stats.evictions += count self._stats.size = self._backend.size() return count def record_rate_limit(self, source: str, backoff_seconds: int = 60) -> None: """Record a rate limit hit for a source.""" with self._lock: state = self._get_rate_limit_state(source) state.record_rate_limit(backoff_seconds) def record_request(self, source: str) -> None: """Record a request for rate limit tracking.""" with self._lock: state = self._get_rate_limit_state(source) state.record_request() def record_success(self, source: str) -> None: """Record successful request.""" with self._lock: state = self._get_rate_limit_state(source) state.record_success() def is_rate_limited(self, source: str) -> bool: """Check if source is rate limited.""" with self._lock: state = self._get_rate_limit_state(source) return state.is_rate_limited def get_stats(self) -> CacheStats: """Get cache statistics.""" with self._lock: self._stats.size = self._backend.size() return self._stats def cached( self, ttl_seconds: Optional[int] = None, source: str = "default", key_prefix: str = "" ) -> Callable: """Decorator for caching function results. Example: @cache.cached(ttl_seconds=3600, source="fred") def get_fred_data(series_id): return fetch_from_api(series_id) """ def decorator(func: Callable) -> Callable: def wrapper(*args, **kwargs): # Generate cache key from function name and args key_parts = [key_prefix, func.__name__] if args: key_parts.append(str(args)) if kwargs: key_parts.append(json.dumps(kwargs, sort_keys=True)) cache_key = ":".join(filter(None, key_parts)) # Check cache value, status = self.get(cache_key, serve_stale_if_rate_limited=True) if status in (CacheStatus.HIT, CacheStatus.STALE): return value # Execute function result = func(*args, **kwargs) # Cache result self.set(cache_key, result, ttl_seconds=ttl_seconds, source=source) return result return wrapper return decorator # Global cache instance _global_cache: Optional[DataCache] = None _global_cache_lock = threading.Lock() def get_cache() -> DataCache: """Get the global cache instance.""" global _global_cache with _global_cache_lock: if _global_cache is None: _global_cache = DataCache() return _global_cache def reset_cache() -> None: """Reset the global cache instance. For testing.""" global _global_cache with _global_cache_lock: _global_cache = None