TradingAgents/tradingagents/dataflows/cache.py

629 lines
19 KiB
Python

"""Data caching layer for vendor data with rate limit awareness.
This module provides a robust caching layer to handle API rate limits across
all data vendors. Features:
- Multi-backend support (memory, file, SQLite)
- TTL-based expiration with configurable per-source TTLs
- Rate limit tracking and backoff
- Cache statistics and monitoring
- Atomic cache operations for thread safety
Issue #12: [DATA-11] Data caching layer - FRED rate limits
"""
import functools
import hashlib
import json
import logging
import sqlite3
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum, auto
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type variable for the value type carried by a generic CacheEntry.
T = TypeVar('T')
class CacheStatus(Enum):
    """Outcome reported alongside the value of a cache lookup."""

    # A live (non-expired) entry was found.
    HIT = auto()
    # No entry exists for the key.
    MISS = auto()
    # An entry exists but its TTL has elapsed.
    EXPIRED = auto()
    # Expired but returned due to rate limit.
    STALE = auto()
@dataclass
class CacheEntry(Generic[T]):
    """One cached value together with its bookkeeping metadata.

    Attributes:
        key: Key the entry is stored under.
        value: The cached payload.
        created_at: Local (naive) time the entry was created.
        expires_at: Local (naive) time after which the entry counts as expired.
        access_count: Number of times ``touch()`` has been called.
        last_accessed: Time of the most recent ``touch()``, or None if never.
        source: Name of the data source the value came from.
        metadata: Free-form extra information attached to the entry.
    """

    key: str
    value: T
    created_at: datetime
    expires_at: datetime
    access_count: int = 0
    last_accessed: Optional[datetime] = None
    source: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def is_expired(self) -> bool:
        """True once the current time is strictly past ``expires_at``."""
        return self.expires_at < datetime.now()

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since ``created_at``."""
        elapsed = datetime.now() - self.created_at
        return elapsed.total_seconds()

    @property
    def ttl_remaining_seconds(self) -> float:
        """Seconds left before expiry, floored at zero."""
        remaining = (self.expires_at - datetime.now()).total_seconds()
        return max(0, remaining)

    def touch(self) -> None:
        """Record one access: stamp the time and bump the counter."""
        self.last_accessed = datetime.now()
        self.access_count += 1
@dataclass
class CacheStats:
    """Running counters describing cache activity.

    Attributes:
        hits: Lookups answered by a live entry.
        misses: Lookups that found no entry.
        expired: Lookups that found an expired entry.
        stale_served: Expired entries that were returned anyway.
        evictions: Entries removed from the cache.
        size: Number of entries currently stored.
    """

    hits: int = 0
    misses: int = 0
    expired: int = 0
    stale_served: int = 0
    evictions: int = 0
    size: int = 0

    @property
    def hit_rate(self) -> float:
        """Hits as a percentage of hits + misses (0.0 when there were none)."""
        lookups = self.hits + self.misses
        return (self.hits / lookups) * 100 if lookups else 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return every counter plus the derived ``hit_rate`` as a plain dict."""
        counters = ("hits", "misses", "expired", "stale_served", "evictions", "size")
        snapshot: Dict[str, Any] = {name: getattr(self, name) for name in counters}
        snapshot["hit_rate"] = self.hit_rate
        return snapshot
@dataclass
class RateLimitState:
    """Track request budget and backoff state for one data source.

    Attributes:
        source: Name of the data source being tracked.
        requests_made: Requests recorded in the current window.
        requests_limit: Maximum requests allowed per window.
        window_start: Local (naive) time the current window began.
        window_seconds: Length of the counting window in seconds.
        backoff_until: Time until which the source is considered rate
            limited, or None when no backoff is active.
        consecutive_failures: Rate-limit hits since the last success.
        max_backoff_seconds: Upper bound for the exponentially growing
            backoff. Without a bound, repeated failures double the delay
            until ``datetime`` arithmetic overflows.
    """

    source: str
    requests_made: int = 0
    requests_limit: int = 120  # Default FRED limit
    window_start: datetime = field(default_factory=datetime.now)
    window_seconds: int = 60
    backoff_until: Optional[datetime] = None
    consecutive_failures: int = 0
    max_backoff_seconds: int = 3600

    @property
    def is_rate_limited(self) -> bool:
        """True while an active backoff period has not yet elapsed.

        Only the backoff deadline is consulted; the per-window request
        counter does not influence this flag.
        """
        return self.backoff_until is not None and datetime.now() < self.backoff_until

    @property
    def requests_remaining(self) -> int:
        """Requests still available in the current window (never negative).

        Reading this property rolls the window over first if it has expired.
        """
        self._maybe_reset_window()
        return max(0, self.requests_limit - self.requests_made)

    def _maybe_reset_window(self) -> None:
        """Start a fresh window (and zero the counter) once the old one expired."""
        now = datetime.now()
        if (now - self.window_start).total_seconds() > self.window_seconds:
            self.window_start = now
            self.requests_made = 0

    def record_request(self) -> None:
        """Count one request against the current window."""
        self._maybe_reset_window()
        self.requests_made += 1

    def record_rate_limit(self, backoff_seconds: int = 60) -> None:
        """Record a rate-limit hit and start (or extend) the backoff period.

        The delay doubles with every consecutive failure, starting at
        ``backoff_seconds``, and is capped so it can neither grow without
        bound nor overflow ``datetime``.

        Args:
            backoff_seconds: Base delay used for the first failure.
        """
        self.consecutive_failures += 1
        # The cap never undercuts a base delay the caller asked for explicitly.
        ceiling = max(self.max_backoff_seconds, backoff_seconds)
        # Exponential backoff, bounded by the ceiling.
        actual_backoff = min(
            backoff_seconds * (2 ** (self.consecutive_failures - 1)),
            ceiling,
        )
        self.backoff_until = datetime.now() + timedelta(seconds=actual_backoff)
        logger.warning(f"Rate limit hit for {self.source}, backing off for {actual_backoff}s")

    def record_success(self) -> None:
        """Clear the failure streak and any active backoff."""
        self.consecutive_failures = 0
        self.backoff_until = None
class CacheBackend(ABC):
    """Interface every cache storage backend has to implement."""

    @abstractmethod
    def get(self, key: str) -> Optional[CacheEntry]:
        """Return the entry stored under ``key``, or None if there is none."""

    @abstractmethod
    def set(self, entry: CacheEntry) -> None:
        """Store ``entry`` under its own ``key``."""

    @abstractmethod
    def delete(self, key: str) -> bool:
        """Remove the entry stored under ``key``; report whether one existed."""

    @abstractmethod
    def clear(self) -> int:
        """Remove every entry and return how many were removed."""

    @abstractmethod
    def keys(self) -> List[str]:
        """Return the keys of all stored entries."""

    @abstractmethod
    def size(self) -> int:
        """Return the number of stored entries."""
class MemoryCache(CacheBackend):
    """In-memory cache with least-recently-used (LRU) eviction.

    Entries live in a single dict whose insertion order doubles as the
    recency queue: the first key is the least recently used one and the
    last key the most recently used one. Every public method holds a
    re-entrant lock for its whole duration.
    """

    def __init__(self, max_size: int = 1000):
        """Initialize memory cache.

        Args:
            max_size: Maximum number of entries kept before the least
                recently used one is evicted.
        """
        self._cache: Dict[str, CacheEntry] = {}
        self._max_size = max_size
        self._lock = threading.RLock()

    def get(self, key: str) -> Optional[CacheEntry]:
        """Return the entry for ``key`` (marking it most recently used), or None."""
        with self._lock:
            entry = self._cache.pop(key, None)
            if entry is not None:
                # Re-inserting moves the key to the most-recently-used end.
                self._cache[key] = entry
            return entry

    def set(self, entry: CacheEntry) -> None:
        """Store ``entry``, evicting least recently used entries if full.

        Overwriting an existing key never evicts anything: the cache does
        not grow, so no other entry needs to make room for it.
        """
        with self._lock:
            if entry.key in self._cache:
                # Drop the old slot so the re-insert lands at the MRU end.
                del self._cache[entry.key]
            else:
                # Only a brand-new key can push the cache over capacity.
                while self._cache and len(self._cache) >= self._max_size:
                    del self._cache[next(iter(self._cache))]
            self._cache[entry.key] = entry

    def delete(self, key: str) -> bool:
        """Remove the entry for ``key``; return True if it existed."""
        with self._lock:
            return self._cache.pop(key, None) is not None

    def clear(self) -> int:
        """Remove every entry and return how many were removed."""
        with self._lock:
            count = len(self._cache)
            self._cache.clear()
            return count

    def keys(self) -> List[str]:
        """Return all keys, ordered from least to most recently used."""
        with self._lock:
            return list(self._cache)

    def size(self) -> int:
        """Return the number of stored entries."""
        with self._lock:
            return len(self._cache)
class FileCache(CacheBackend):
    """File-based cache storing one JSON document per entry.

    Each entry is written to ``<md5(key)>.json`` inside the cache
    directory. Cached values must therefore be JSON-serializable. The
    lock only serializes access made through this instance.
    """

    def __init__(self, cache_dir: Optional[Path] = None):
        """Initialize file cache.

        Args:
            cache_dir: Directory for cache files. Defaults to
                ``~/.cache/tradingagents``. Created if missing.
        """
        self._cache_dir = (
            Path(cache_dir) if cache_dir else Path.home() / ".cache" / "tradingagents"
        )
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        self._lock = threading.RLock()

    def _get_path(self, key: str) -> Path:
        """Return the file path that stores ``key``."""
        # Hash the key so arbitrary characters never reach the filesystem.
        # md5 is used purely as a filename digest, not for security.
        safe_key = hashlib.md5(key.encode()).hexdigest()
        return self._cache_dir / f"{safe_key}.json"

    def _read_entry(self, path: Path) -> Optional[CacheEntry]:
        """Load the entry stored at ``path``.

        Returns None when the file does not exist. A file that cannot be
        parsed into an entry is treated as corrupted: it is deleted and
        None is returned.
        """
        try:
            with open(path, 'r', encoding="utf-8") as f:
                data = json.load(f)
            last_accessed = data.get('last_accessed')
            return CacheEntry(
                key=data['key'],
                value=data['value'],
                created_at=datetime.fromisoformat(data['created_at']),
                expires_at=datetime.fromisoformat(data['expires_at']),
                access_count=data.get('access_count', 0),
                last_accessed=datetime.fromisoformat(last_accessed) if last_accessed else None,
                source=data.get('source', ''),
                metadata=data.get('metadata', {})
            )
        except FileNotFoundError:
            return None
        except (KeyError, ValueError, TypeError, AttributeError):
            # Corrupted file (invalid JSON, missing fields, or wrong shape).
            path.unlink(missing_ok=True)
            return None

    def get(self, key: str) -> Optional[CacheEntry]:
        """Return the entry stored for ``key``, or None if absent or corrupted."""
        with self._lock:
            return self._read_entry(self._get_path(key))

    def set(self, entry: CacheEntry) -> None:
        """Write ``entry`` to disk, replacing any previous version.

        Raises:
            TypeError: If the entry's value or metadata is not
                JSON-serializable. Nothing on disk is modified in that case.
        """
        path = self._get_path(entry.key)
        # Serialize before touching the filesystem so a failure here cannot
        # leave a truncated file behind.
        payload = json.dumps({
            'key': entry.key,
            'value': entry.value,
            'created_at': entry.created_at.isoformat(),
            'expires_at': entry.expires_at.isoformat(),
            'access_count': entry.access_count,
            'last_accessed': entry.last_accessed.isoformat() if entry.last_accessed else None,
            'source': entry.source,
            'metadata': entry.metadata
        })
        # The ".tmp" suffix keeps the scratch file out of the "*.json" globs.
        tmp_path = path.with_suffix(".tmp")
        with self._lock:
            with open(tmp_path, 'w', encoding="utf-8") as f:
                f.write(payload)
            # Rename over the target so readers see either the old or the
            # new document, never a partially written one.
            tmp_path.replace(path)

    def delete(self, key: str) -> bool:
        """Remove the file for ``key``; return True if it existed."""
        path = self._get_path(key)
        with self._lock:
            try:
                path.unlink()
            except FileNotFoundError:
                return False
            return True

    def clear(self) -> int:
        """Remove every entry file and return how many were removed."""
        with self._lock:
            count = 0
            for path in self._cache_dir.glob("*.json"):
                try:
                    path.unlink()
                except FileNotFoundError:
                    # Already gone (removed concurrently); not ours to count.
                    continue
                count += 1
            return count

    def keys(self) -> List[str]:
        """Return the original keys of all readable entries.

        The keys are read back from the stored documents rather than taken
        from the hashed file names, so each one can be passed straight to
        ``get()`` or ``delete()``. Corrupted files are removed and skipped.
        """
        with self._lock:
            found: List[str] = []
            for path in self._cache_dir.glob("*.json"):
                entry = self._read_entry(path)
                if entry is not None:
                    found.append(entry.key)
            return found

    def size(self) -> int:
        """Return the number of entry files in the cache directory."""
        return sum(1 for _ in self._cache_dir.glob("*.json"))
class DataCache:
    """Main data cache with rate limit awareness.

    Provides caching for vendor data with configurable TTLs, rate limit
    tracking, and serving of expired entries while a source is rate
    limited.

    Example:
        cache = DataCache()
        # Cache with the source's default TTL
        cache.set("fred:FEDFUNDS", data, source="fred")
        # Get with stale fallback if rate limited
        value, status = cache.get("fred:FEDFUNDS", serve_stale_if_rate_limited=True)
        # Use as decorator
        @cache.cached(ttl_seconds=3600, source="fred")
        def get_fred_data(series_id):
            return fetch_from_api(series_id)
    """

    # Default TTLs by source (in seconds)
    DEFAULT_TTLS = {
        "fred": 3600 * 24,     # 24 hours for FRED (data updates daily)
        "yfinance": 60,        # 1 minute for real-time quotes
        "finnhub": 60,         # 1 minute for real-time data
        "polygon": 300,        # 5 minutes
        "alpha_vantage": 300,  # 5 minutes
        "default": 300         # 5 minutes default
    }

    # Default rate limits by source: (requests, window in seconds)
    DEFAULT_RATE_LIMITS = {
        "fred": (120, 60),         # 120 requests per 60 seconds
        "yfinance": (2000, 60),    # High limit (throttles internally)
        "finnhub": (60, 60),       # 60 per minute
        "polygon": (5, 60),        # 5 per minute (free tier)
        "alpha_vantage": (5, 60),  # 5 per minute (free tier)
        "default": (100, 60)
    }

    def __init__(
        self,
        backend: Optional[CacheBackend] = None,
        default_ttl_seconds: int = 300
    ):
        """Initialize data cache.

        Args:
            backend: Cache backend (defaults to MemoryCache)
            default_ttl_seconds: TTL applied to entries whose source has no
                dedicated entry in ``DEFAULT_TTLS``
        """
        self._backend = backend or MemoryCache()
        self._default_ttl = default_ttl_seconds
        self._stats = CacheStats()
        self._rate_limits: Dict[str, RateLimitState] = {}
        self._lock = threading.RLock()

    def _generate_key(self, key: str, **kwargs) -> str:
        """Generate cache key from key and optional params.

        With params, a short digest of their JSON form (keys sorted, so
        argument order does not matter) is appended to ``key``.
        """
        if kwargs:
            params_str = json.dumps(kwargs, sort_keys=True)
            return f"{key}:{hashlib.md5(params_str.encode()).hexdigest()[:8]}"
        return key

    def _get_ttl(self, source: str) -> int:
        """Get TTL in seconds for a source.

        Sources with a dedicated ``DEFAULT_TTLS`` entry use it; everything
        else (including the generic "default" source) uses the
        ``default_ttl_seconds`` this cache was constructed with.
        """
        if source != "default" and source in self.DEFAULT_TTLS:
            return self.DEFAULT_TTLS[source]
        return self._default_ttl

    def _get_rate_limit_state(self, source: str) -> RateLimitState:
        """Get or create rate limit state for source."""
        if source not in self._rate_limits:
            limit, window = self.DEFAULT_RATE_LIMITS.get(
                source,
                self.DEFAULT_RATE_LIMITS["default"]
            )
            self._rate_limits[source] = RateLimitState(
                source=source,
                requests_limit=limit,
                window_seconds=window
            )
        return self._rate_limits[source]

    def get(
        self,
        key: str,
        serve_stale_if_rate_limited: bool = True,
        **kwargs
    ) -> tuple[Optional[Any], CacheStatus]:
        """Get value from cache.

        Args:
            key: Cache key
            serve_stale_if_rate_limited: Return expired value if the
                entry's source is currently rate limited
            **kwargs: Additional key params

        Returns:
            Tuple of (value, status). The value is None for MISS and
            EXPIRED; inspect the status to tell those apart from a cached
            None.
        """
        full_key = self._generate_key(key, **kwargs)
        with self._lock:
            entry = self._backend.get(full_key)
            if entry is None:
                self._stats.misses += 1
                return None, CacheStatus.MISS

            if not entry.is_expired:
                entry.touch()
                # Write back so backends that store copies persist the
                # updated access metadata.
                self._backend.set(entry)
                self._stats.hits += 1
                return entry.value, CacheStatus.HIT

            # Entry is expired
            self._stats.expired += 1
            if serve_stale_if_rate_limited:
                rate_state = self._get_rate_limit_state(entry.source)
                if rate_state.is_rate_limited:
                    self._stats.stale_served += 1
                    return entry.value, CacheStatus.STALE
            return None, CacheStatus.EXPIRED

    def set(
        self,
        key: str,
        value: Any,
        ttl_seconds: Optional[int] = None,
        source: str = "default",
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> None:
        """Set value in cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl_seconds: TTL in seconds (uses source default if not specified)
            source: Data source name
            metadata: Optional metadata
            **kwargs: Additional key params
        """
        full_key = self._generate_key(key, **kwargs)
        actual_ttl = ttl_seconds if ttl_seconds is not None else self._get_ttl(source)
        # One timestamp for both fields so expires_at - created_at is exactly the TTL.
        now = datetime.now()
        entry = CacheEntry(
            key=full_key,
            value=value,
            created_at=now,
            expires_at=now + timedelta(seconds=actual_ttl),
            source=source,
            metadata=metadata or {}
        )
        with self._lock:
            self._backend.set(entry)
            self._stats.size = self._backend.size()

    def delete(self, key: str, **kwargs) -> bool:
        """Delete value from cache. Returns True if an entry was removed."""
        full_key = self._generate_key(key, **kwargs)
        with self._lock:
            result = self._backend.delete(full_key)
            self._stats.size = self._backend.size()
            return result

    def clear(self, source: Optional[str] = None) -> int:
        """Clear cache entries.

        Args:
            source: Clear only entries from this source (None = all)

        Returns:
            Number of entries cleared
        """
        with self._lock:
            if source is None:
                count = self._backend.clear()
            else:
                count = 0
                for key in self._backend.keys():
                    entry = self._backend.get(key)
                    if entry and entry.source == source:
                        self._backend.delete(key)
                        count += 1
            self._stats.evictions += count
            self._stats.size = self._backend.size()
            return count

    def record_rate_limit(self, source: str, backoff_seconds: int = 60) -> None:
        """Record a rate limit hit for a source."""
        with self._lock:
            self._get_rate_limit_state(source).record_rate_limit(backoff_seconds)

    def record_request(self, source: str) -> None:
        """Record a request for rate limit tracking."""
        with self._lock:
            self._get_rate_limit_state(source).record_request()

    def record_success(self, source: str) -> None:
        """Record successful request."""
        with self._lock:
            self._get_rate_limit_state(source).record_success()

    def is_rate_limited(self, source: str) -> bool:
        """Check if source is rate limited."""
        with self._lock:
            return self._get_rate_limit_state(source).is_rate_limited

    def get_stats(self) -> CacheStats:
        """Get cache statistics.

        Returns the cache's own statistics object (not a copy) after
        refreshing its ``size`` from the backend.
        """
        with self._lock:
            self._stats.size = self._backend.size()
            return self._stats

    def cached(
        self,
        ttl_seconds: Optional[int] = None,
        source: str = "default",
        key_prefix: str = ""
    ) -> Callable:
        """Decorator for caching function results.

        The cache key is built from ``key_prefix``, the function's name,
        ``str(args)`` and the JSON form of the keyword arguments. A HIT or
        STALE lookup returns the cached value without calling the function.

        Args:
            ttl_seconds: TTL for stored results (source default if None)
            source: Data source name recorded on stored entries
            key_prefix: Optional prefix prepended to every generated key

        Example:
            @cache.cached(ttl_seconds=3600, source="fred")
            def get_fred_data(series_id):
                return fetch_from_api(series_id)
        """
        def decorator(func: Callable) -> Callable:
            # Preserve the wrapped function's name, docstring and signature
            # metadata for introspection and debugging.
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # Generate cache key from function name and args
                key_parts = [key_prefix, func.__name__]
                if args:
                    key_parts.append(str(args))
                if kwargs:
                    key_parts.append(json.dumps(kwargs, sort_keys=True))
                cache_key = ":".join(filter(None, key_parts))

                # Check cache
                value, status = self.get(cache_key, serve_stale_if_rate_limited=True)
                if status in (CacheStatus.HIT, CacheStatus.STALE):
                    return value

                # Execute function and cache its result
                result = func(*args, **kwargs)
                self.set(cache_key, result, ttl_seconds=ttl_seconds, source=source)
                return result
            return wrapper
        return decorator
# Global cache instance, created on first use by get_cache() and
# discarded by reset_cache(). None means "not created yet".
_global_cache: Optional[DataCache] = None
# Guards every read and write of _global_cache.
_global_cache_lock = threading.Lock()
def get_cache() -> DataCache:
    """Return the shared DataCache, creating it on the first call."""
    global _global_cache
    with _global_cache_lock:
        instance = _global_cache
        if instance is None:
            # First caller builds the singleton with default settings.
            instance = _global_cache = DataCache()
    return instance
def reset_cache() -> None:
    """Drop the shared cache so the next get_cache() builds a new one.

    Intended for tests.
    """
    global _global_cache
    _global_cache_lock.acquire()
    try:
        _global_cache = None
    finally:
        _global_cache_lock.release()