diff --git a/tests/unit/dataflows/test_cache.py b/tests/unit/dataflows/test_cache.py
new file mode 100644
--- /dev/null
+++ b/tests/unit/dataflows/test_cache.py
@@ -0,0 +1,748 @@
+"""Tests for data caching layer.
+
+Issue #12: [DATA-11] Data caching layer - FRED rate limits
+"""
+
+import pytest
+import time
+import threading
+import tempfile
+from datetime import datetime, timedelta
+from pathlib import Path
+from unittest.mock import Mock, patch
+
+from tradingagents.dataflows.cache import (
+    CacheEntry,
+    CacheStats,
+    CacheStatus,
+    RateLimitState,
+    MemoryCache,
+    FileCache,
+    DataCache,
+    get_cache,
+    reset_cache,
+)
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.fixture(autouse=True)
+def reset_global_cache():
+    """Reset global cache before each test."""
+    reset_cache()
+    yield
+    reset_cache()
+
+
+class TestCacheEntry:
+    """Tests for CacheEntry dataclass."""
+
+    def test_entry_creation(self):
+        """Test creating a cache entry."""
+        entry = CacheEntry(
+            key="test_key",
+            value={"data": "test"},
+            created_at=datetime.now(),
+            expires_at=datetime.now() + timedelta(hours=1),
+            source="fred"
+        )
+
+        assert entry.key == "test_key"
+        assert entry.value == {"data": "test"}
+        assert entry.source == "fred"
+        assert entry.access_count == 0
+
+    def test_is_expired_false(self):
+        """Test is_expired returns False for valid entry."""
+        entry = CacheEntry(
+            key="test",
+            value="data",
+            created_at=datetime.now(),
+            expires_at=datetime.now() + timedelta(hours=1)
+        )
+        assert entry.is_expired is False
+
+    def test_is_expired_true(self):
+        """Test is_expired returns True for expired entry."""
+        entry = CacheEntry(
+            key="test",
+            value="data",
+            created_at=datetime.now() - timedelta(hours=2),
+            expires_at=datetime.now() - timedelta(hours=1)
+        )
+        assert entry.is_expired is True
+
+    def test_age_seconds(self):
+        """Test age_seconds calculation."""
+        entry = CacheEntry(
+            key="test",
+            value="data",
+            created_at=datetime.now() - timedelta(seconds=60),
+            expires_at=datetime.now() + timedelta(hours=1)
+        )
+        assert 59 < entry.age_seconds < 61
+
+    def test_ttl_remaining(self):
+        """Test ttl_remaining_seconds calculation."""
+        entry = CacheEntry(
+            key="test",
+            value="data",
+            created_at=datetime.now(),
+            expires_at=datetime.now() + timedelta(seconds=3600)
+        )
+        assert 3599 < entry.ttl_remaining_seconds <= 3600
+
+    def test_touch_updates_metadata(self):
+        """Test touch updates access metadata."""
+        entry = CacheEntry(
+            key="test",
+            value="data",
+            created_at=datetime.now(),
+            expires_at=datetime.now() + timedelta(hours=1)
+        )
+
+        assert entry.access_count == 0
+        assert entry.last_accessed is None
+
+        entry.touch()
+
+        assert entry.access_count == 1
+        assert entry.last_accessed is not None
+
+
+class TestCacheStats:
+    """Tests for CacheStats dataclass."""
+
+    def test_default_values(self):
+        """Test default values."""
+        stats = CacheStats()
+        assert stats.hits == 0
+        assert stats.misses == 0
+        assert stats.hit_rate == 0.0
+
+    def test_hit_rate_calculation(self):
+        """Test hit rate calculation."""
+        stats = CacheStats(hits=75, misses=25)
+        assert stats.hit_rate == 75.0
+
+    def test_hit_rate_no_requests(self):
+        """Test hit rate with no requests."""
+        stats = CacheStats()
+        assert stats.hit_rate == 0.0
+
+    def test_to_dict(self):
+        """Test conversion to dictionary."""
+        stats = CacheStats(hits=10, misses=5, evictions=2)
+        d = stats.to_dict()
+
+        assert d["hits"] == 10
+        assert d["misses"] == 5
+        assert d["evictions"] == 2
+        assert "hit_rate" in d
+
+
+class TestRateLimitState:
+    """Tests for RateLimitState dataclass."""
+
+    def test_default_values(self):
+        """Test default values."""
+        state = RateLimitState(source="fred")
+        assert state.source == "fred"
+        assert state.requests_made == 0
+        assert state.is_rate_limited is False
+
+    def test_record_request(self):
+        """Test recording requests."""
+        state = RateLimitState(source="test", requests_limit=5)
+
+        for i in range(3):
+            state.record_request()
+
+        assert state.requests_made == 3
+        assert state.requests_remaining == 2
+
+    def test_is_rate_limited_after_backoff(self):
+        """Test rate limiting after recording limit hit."""
+        state = RateLimitState(source="test")
+
+        assert state.is_rate_limited is False
+
+        state.record_rate_limit(backoff_seconds=1)
+
+        assert state.is_rate_limited is True
+
+        # Wait for backoff to expire
+        time.sleep(1.1)
+        assert state.is_rate_limited is False
+
+    def test_record_success_clears_backoff(self):
+        """Test that success clears backoff."""
+        state = RateLimitState(source="test")
+        state.record_rate_limit(backoff_seconds=60)
+        assert state.is_rate_limited is True
+
+        state.record_success()
+        assert state.is_rate_limited is False
+        assert state.consecutive_failures == 0
+
+    def test_exponential_backoff(self):
+        """Test exponential backoff on consecutive failures."""
+        state = RateLimitState(source="test")
+
+        # First failure - 1 second backoff
+        state.record_rate_limit(backoff_seconds=1)
+        assert state.consecutive_failures == 1
+        first_delay = (state.backoff_until - datetime.now()).total_seconds()
+        assert 0 < first_delay <= 1
+
+        # Simulate recovery
+        state.backoff_until = None
+
+        # Second failure - 2 second backoff
+        state.record_rate_limit(backoff_seconds=1)
+        assert state.consecutive_failures == 2
+        second_delay = (state.backoff_until - datetime.now()).total_seconds()
+        assert 1 < second_delay <= 2
+
+    def test_backoff_is_capped(self):
+        """Test backoff never exceeds max_backoff_seconds."""
+        state = RateLimitState(source="test", max_backoff_seconds=120)
+
+        for _ in range(100):
+            state.record_rate_limit(backoff_seconds=60)
+
+        delay = (state.backoff_until - datetime.now()).total_seconds()
+        assert state.consecutive_failures == 100
+        assert 0 < delay <= 120
+
+
+class TestMemoryCache:
+    """Tests for MemoryCache backend."""
+
+    def test_get_set(self):
+        """Test basic get/set operations."""
+        cache = MemoryCache()
+
+        entry = CacheEntry(
+            key="test",
+            value="data",
+            created_at=datetime.now(),
+            expires_at=datetime.now() + timedelta(hours=1)
+        )
+
+        cache.set(entry)
+        result = cache.get("test")
+
+        assert result is not None
+        assert result.value == "data"
+
+    def test_get_missing(self):
+        """Test getting missing key."""
+        cache = MemoryCache()
+        assert cache.get("nonexistent") is None
+
+    def test_delete(self):
+        """Test deleting entry."""
+        cache = MemoryCache()
+
+        entry = CacheEntry(
+            key="test",
+            value="data",
+            created_at=datetime.now(),
+            expires_at=datetime.now() + timedelta(hours=1)
+        )
+
+        cache.set(entry)
+        assert cache.delete("test") is True
+        assert cache.get("test") is None
+
+    def test_delete_missing(self):
+        """Test deleting missing key."""
+        cache = MemoryCache()
+        assert cache.delete("nonexistent") is False
+
+    def test_clear(self):
+        """Test clearing cache."""
+        cache = MemoryCache()
+
+        for i in range(5):
+            cache.set(CacheEntry(
+                key=f"key_{i}",
+                value=f"value_{i}",
+                created_at=datetime.now(),
+                expires_at=datetime.now() + timedelta(hours=1)
+            ))
+
+        count = cache.clear()
+        assert count == 5
+        assert cache.size() == 0
+
+    def test_lru_eviction(self):
+        """Test LRU eviction when at capacity."""
+        cache = MemoryCache(max_size=3)
+
+        # Add 3 entries
+        for i in range(3):
+            cache.set(CacheEntry(
+                key=f"key_{i}",
+                value=f"value_{i}",
+                created_at=datetime.now(),
+                expires_at=datetime.now() + timedelta(hours=1)
+            ))
+
+        # Access key_1 to make it recently used
+        cache.get("key_1")
+
+        # Add new entry, should evict key_0 (least recently used)
+        cache.set(CacheEntry(
+            key="key_3",
+            value="value_3",
+            created_at=datetime.now(),
+            expires_at=datetime.now() + timedelta(hours=1)
+        ))
+
+        assert cache.size() == 3
+        assert cache.get("key_0") is None
+        assert cache.get("key_1") is not None
+
+    def test_overwrite_at_capacity_does_not_evict(self):
+        """Test updating an existing key keeps the other entries."""
+        cache = MemoryCache(max_size=2)
+
+        for i in range(2):
+            cache.set(CacheEntry(
+                key=f"key_{i}",
+                value=f"value_{i}",
+                created_at=datetime.now(),
+                expires_at=datetime.now() + timedelta(hours=1)
+            ))
+
+        # Overwrite key_1; key_0 is the LRU entry and must survive
+        cache.set(CacheEntry(
+            key="key_1",
+            value="updated",
+            created_at=datetime.now(),
+            expires_at=datetime.now() + timedelta(hours=1)
+        ))
+
+        assert cache.size() == 2
+        assert cache.get("key_0") is not None
+        assert cache.get("key_1").value == "updated"
+
+    def test_thread_safety(self):
+        """Test thread-safe operations."""
+        cache = MemoryCache()
+        errors = []
+
+        def write_entries(start):
+            try:
+                for i in range(100):
+                    cache.set(CacheEntry(
+                        key=f"key_{start}_{i}",
+                        value=f"value_{start}_{i}",
+                        created_at=datetime.now(),
+                        expires_at=datetime.now() + timedelta(hours=1)
+                    ))
+            except Exception as e:
+                errors.append(e)
+
+        threads = [threading.Thread(target=write_entries, args=(i,)) for i in range(5)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert len(errors) == 0
+        assert cache.size() == 500
+
+
+class TestFileCache:
+    """Tests for FileCache backend."""
+
+    def test_get_set(self):
+        """Test basic get/set operations."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cache = FileCache(cache_dir=Path(tmpdir))
+
+            entry = CacheEntry(
+                key="test",
+                value={"data": "test"},
+                created_at=datetime.now(),
+                expires_at=datetime.now() + timedelta(hours=1),
+                source="test"
+            )
+
+            cache.set(entry)
+            result = cache.get("test")
+
+            assert result is not None
+            assert result.value == {"data": "test"}
+            assert result.source == "test"
+
+    def test_get_missing(self):
+        """Test getting missing key."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cache = FileCache(cache_dir=Path(tmpdir))
+            assert cache.get("nonexistent") is None
+
+    def test_delete(self):
+        """Test deleting entry."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cache = FileCache(cache_dir=Path(tmpdir))
+
+            entry = CacheEntry(
+                key="test",
+                value="data",
+                created_at=datetime.now(),
+                expires_at=datetime.now() + timedelta(hours=1)
+            )
+
+            cache.set(entry)
+            assert cache.delete("test") is True
+            assert cache.get("test") is None
+
+    def test_clear(self):
+        """Test clearing cache."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cache = FileCache(cache_dir=Path(tmpdir))
+
+            for i in range(3):
+                cache.set(CacheEntry(
+                    key=f"key_{i}",
+                    value=f"value_{i}",
+                    created_at=datetime.now(),
+                    expires_at=datetime.now() + timedelta(hours=1)
+                ))
+
+            count = cache.clear()
+            assert count == 3
+            assert cache.size() == 0
+
+    def test_keys_returns_original_keys(self):
+        """Test keys() returns keys usable with get()."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cache = FileCache(cache_dir=Path(tmpdir))
+
+            cache.set(CacheEntry(
+                key="fred:FEDFUNDS",
+                value="data",
+                created_at=datetime.now(),
+                expires_at=datetime.now() + timedelta(hours=1)
+            ))
+
+            assert cache.keys() == ["fred:FEDFUNDS"]
+            assert cache.get(cache.keys()[0]) is not None
+
+    def test_failed_write_keeps_previous_entry(self):
+        """Test a non-serializable value does not corrupt the cache."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cache = FileCache(cache_dir=Path(tmpdir))
+
+            cache.set(CacheEntry(
+                key="test",
+                value="good",
+                created_at=datetime.now(),
+                expires_at=datetime.now() + timedelta(hours=1)
+            ))
+
+            with pytest.raises(TypeError):
+                cache.set(CacheEntry(
+                    key="test",
+                    value=object(),
+                    created_at=datetime.now(),
+                    expires_at=datetime.now() + timedelta(hours=1)
+                ))
+
+            assert cache.get("test").value == "good"
+            assert cache.size() == 1
+
+
+class TestDataCache:
+    """Tests for DataCache main class."""
+
+    def test_get_set_basic(self):
+        """Test basic get/set operations."""
+        cache = DataCache()
+
+        cache.set("test_key", {"data": "test"}, source="fred")
+        value, status = cache.get("test_key")
+
+        assert status == CacheStatus.HIT
+        assert value == {"data": "test"}
+
+    def test_get_miss(self):
+        """Test cache miss."""
+        cache = DataCache()
+
+        value, status = cache.get("nonexistent")
+
+        assert status == CacheStatus.MISS
+        assert value is None
+
+    def test_get_expired(self):
+        """Test getting expired entry."""
+        cache = DataCache()
+
+        # Set with very short TTL
+        cache.set("test", "data", ttl_seconds=0, source="test")
+
+        # Wait for expiration
+        time.sleep(0.1)
+
+        value, status = cache.get("test", serve_stale_if_rate_limited=False)
+
+        assert status == CacheStatus.EXPIRED
+        assert value is None
+
+    def test_serve_stale_when_rate_limited(self):
+        """Test serving stale data when rate limited."""
+        cache = DataCache()
+
+        # Set entry that will expire
+        cache.set("test", "stale_data", ttl_seconds=0, source="test")
+        time.sleep(0.1)
+
+        # Simulate rate limit
+        cache.record_rate_limit("test", backoff_seconds=60)
+
+        # Should get stale data
+        value, status = cache.get("test", serve_stale_if_rate_limited=True)
+
+        assert status == CacheStatus.STALE
+        assert value == "stale_data"
+
+    def test_delete(self):
+        """Test deleting entry."""
+        cache = DataCache()
+
+        cache.set("test", "data", source="test")
+        assert cache.delete("test") is True
+
+        value, status = cache.get("test")
+        assert status == CacheStatus.MISS
+
+    def test_clear_all(self):
+        """Test clearing all entries."""
+        cache = DataCache()
+
+        cache.set("key1", "value1", source="fred")
+        cache.set("key2", "value2", source="yfinance")
+
+        count = cache.clear()
+        assert count == 2
+
+    def test_clear_by_source(self):
+        """Test clearing entries by source."""
+        cache = DataCache()
+
+        cache.set("fred_key", "fred_data", source="fred")
+        cache.set("yf_key", "yf_data", source="yfinance")
+
+        count = cache.clear(source="fred")
+        assert count == 1
+
+        # yfinance entry should still exist
+        value, status = cache.get("yf_key")
+        assert status == CacheStatus.HIT
+
+    def test_key_with_params(self):
+        """Test key generation with params."""
+        cache = DataCache()
+
+        cache.set("series", "data1", source="fred", series_id="FEDFUNDS")
+        cache.set("series", "data2", source="fred", series_id="DGS10")
+
+        value1, _ = cache.get("series", series_id="FEDFUNDS")
+        value2, _ = cache.get("series", series_id="DGS10")
+
+        assert value1 == "data1"
+        assert value2 == "data2"
+
+    def test_stats_tracking(self):
+        """Test statistics tracking."""
+        cache = DataCache()
+
+        # Miss
+        cache.get("missing")
+
+        # Hit
+        cache.set("present", "data", source="test")
+        cache.get("present")
+        cache.get("present")
+
+        stats = cache.get_stats()
+        assert stats.misses == 1
+        assert stats.hits == 2
+
+    def test_rate_limit_tracking(self):
+        """Test rate limit state tracking."""
+        cache = DataCache()
+
+        assert cache.is_rate_limited("fred") is False
+
+        cache.record_rate_limit("fred", backoff_seconds=1)
+        assert cache.is_rate_limited("fred") is True
+
+        time.sleep(1.1)
+        assert cache.is_rate_limited("fred") is False
+
+    def test_cached_decorator(self):
+        """Test @cached decorator."""
+        cache = DataCache()
+        call_count = [0]
+
+        @cache.cached(ttl_seconds=300, source="test")
+        def expensive_function(x):
+            call_count[0] += 1
+            return x * 2
+
+        # First call - executes function
+        result1 = expensive_function(5)
+        assert result1 == 10
+        assert call_count[0] == 1
+
+        # Second call - from cache
+        result2 = expensive_function(5)
+        assert result2 == 10
+        assert call_count[0] == 1
+
+        # Different argument - executes function
+        result3 = expensive_function(10)
+        assert result3 == 20
+        assert call_count[0] == 2
+
+    def test_default_ttls_by_source(self):
+        """Test default TTLs are applied by source."""
+        cache = DataCache()
+
+        # FRED default is 24 hours
+        cache.set("fred_data", "data", source="fred")
+        entry = cache._backend.get(cache._generate_key("fred_data"))
+
+        # Should have ~24 hour TTL
+        assert entry.ttl_remaining_seconds > 3600 * 23
+
+    def test_custom_default_ttl(self):
+        """Test default_ttl_seconds applies to sources without a preset."""
+        cache = DataCache(default_ttl_seconds=7200)
+
+        cache.set("custom", "data", source="my_vendor")
+        entry = cache._backend.get(cache._generate_key("custom"))
+
+        assert 7100 < entry.ttl_remaining_seconds <= 7200
+
+    def test_hit_at_capacity_keeps_entries(self):
+        """Test cache hits do not evict entries from a full backend."""
+        cache = DataCache(backend=MemoryCache(max_size=2))
+
+        cache.set("key1", "value1", source="test")
+        cache.set("key2", "value2", source="test")
+
+        assert cache.get("key2")[1] == CacheStatus.HIT
+        assert cache.get("key1")[1] == CacheStatus.HIT
+
+    def test_clear_by_source_file_backend(self):
+        """Test clearing by source works with the file backend."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cache = DataCache(backend=FileCache(cache_dir=Path(tmpdir)))
+
+            cache.set("fred_key", "fred_data", source="fred")
+            cache.set("yf_key", "yf_data", source="yfinance")
+
+            assert cache.clear(source="fred") == 1
+            assert cache.get("fred_key")[1] == CacheStatus.MISS
+            assert cache.get("yf_key")[1] == CacheStatus.HIT
+
+    def test_cached_preserves_function_metadata(self):
+        """Test @cached keeps the wrapped function's name and docstring."""
+        cache = DataCache()
+
+        @cache.cached(source="test")
+        def documented(x):
+            """Original docstring."""
+            return x
+
+        assert documented.__name__ == "documented"
+        assert documented.__doc__ == "Original docstring."
+
+    def test_get_stats_returns_snapshot(self):
+        """Test mutating returned stats does not affect the cache."""
+        cache = DataCache()
+        cache.get("missing")
+
+        stats = cache.get_stats()
+        stats.misses = 99
+
+        assert cache.get_stats().misses == 1
+
+
+class TestGlobalCache:
+    """Tests for global cache functions."""
+
+    def test_get_cache_singleton(self):
+        """Test get_cache returns singleton."""
+        cache1 = get_cache()
+        cache2 = get_cache()
+        assert cache1 is cache2
+
+    def test_reset_cache(self):
+        """Test reset_cache creates new instance."""
+        cache1 = get_cache()
+        reset_cache()
+        cache2 = get_cache()
+        assert cache1 is not cache2
+
+
+class TestCacheIntegration:
+    """Integration tests for cache with rate limiting."""
+
+    def test_rate_limited_fetch_pattern(self):
+        """Test typical pattern: cache + rate limit handling."""
+        cache = DataCache()
+        fetch_count = [0]
+
+        def fetch_data(key):
+            """Simulate data fetch with rate limit."""
+            # Check rate limit first
+            if cache.is_rate_limited("api"):
+                # Try stale cache
+                value, status = cache.get(key, serve_stale_if_rate_limited=True)
+                if status == CacheStatus.STALE:
+                    return value
+                raise RuntimeError("Rate limited and no stale data")
+
+            # Check cache
+            value, status = cache.get(key)
+            if status == CacheStatus.HIT:
+                return value
+
+            # Fetch fresh data
+            fetch_count[0] += 1
+            cache.record_request("api")
+
+            # Simulate API response
+            data = f"data_for_{key}"
+
+            cache.set(key, data, source="api", ttl_seconds=1)
+            cache.record_success("api")
+
+            return data
+
+        # First fetch - from API
+        result1 = fetch_data("key1")
+        assert result1 == "data_for_key1"
+        assert fetch_count[0] == 1
+
+        # Second fetch - from cache
+        result2 = fetch_data("key1")
+        assert result2 == "data_for_key1"
+        assert fetch_count[0] == 1  # No additional fetch
+
+        # Wait for expiration and simulate rate limit
+        time.sleep(1.1)
+        cache.record_rate_limit("api", backoff_seconds=60)
+
+        # Should get stale data
+        result3 = fetch_data("key1")
+        assert result3 == "data_for_key1"
+        assert fetch_count[0] == 1  # Still no additional fetch
diff --git a/tradingagents/dataflows/cache.py b/tradingagents/dataflows/cache.py
new file mode 100644
--- /dev/null
+++ b/tradingagents/dataflows/cache.py
@@ -0,0 +1,698 @@
+"""Data caching layer for vendor data with rate limit awareness.
+
+This module provides a robust caching layer to handle API rate limits across
+all data vendors. Features:
+- Pluggable backends (in-memory LRU, JSON files on disk)
+- TTL-based expiration with configurable per-source TTLs
+- Rate limit tracking and backoff
+- Cache statistics and monitoring
+- Atomic cache operations for thread safety
+
+Issue #12: [DATA-11] Data caching layer - FRED rate limits
+"""
+
+import functools
+import hashlib
+import json
+import logging
+import sqlite3
+import threading
+import time
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+from dataclasses import dataclass, field, replace
+from datetime import datetime, timedelta
+from enum import Enum, auto
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Generic
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar('T')
+
+
+class CacheStatus(Enum):
+    """Status of a cache lookup."""
+    HIT = auto()
+    MISS = auto()
+    EXPIRED = auto()
+    STALE = auto()  # Expired but returned due to rate limit
+
+
+@dataclass
+class CacheEntry(Generic[T]):
+    """A single cache entry with metadata."""
+    key: str
+    value: T
+    created_at: datetime
+    expires_at: datetime
+    access_count: int = 0
+    last_accessed: Optional[datetime] = None
+    source: str = ""
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    @property
+    def is_expired(self) -> bool:
+        """Check if entry is expired."""
+        return datetime.now() > self.expires_at
+
+    @property
+    def age_seconds(self) -> float:
+        """Get age in seconds."""
+        return (datetime.now() - self.created_at).total_seconds()
+
+    @property
+    def ttl_remaining_seconds(self) -> float:
+        """Get remaining TTL in seconds."""
+        return max(0, (self.expires_at - datetime.now()).total_seconds())
+
+    def touch(self) -> None:
+        """Update access metadata."""
+        self.access_count += 1
+        self.last_accessed = datetime.now()
+
+
+@dataclass
+class CacheStats:
+    """Statistics for cache operations."""
+    hits: int = 0
+    misses: int = 0
+    expired: int = 0
+    stale_served: int = 0
+    evictions: int = 0
+    size: int = 0
+
+    @property
+    def hit_rate(self) -> float:
+        """Calculate hit rate as percentage."""
+        total = self.hits + self.misses
+        if total == 0:
+            return 0.0
+        return (self.hits / total) * 100
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary."""
+        return {
+            "hits": self.hits,
+            "misses": self.misses,
+            "expired": self.expired,
+            "stale_served": self.stale_served,
+            "evictions": self.evictions,
+            "size": self.size,
+            "hit_rate": self.hit_rate
+        }
+
+
+@dataclass
+class RateLimitState:
+    """Track rate limit state for a source."""
+    source: str
+    requests_made: int = 0
+    requests_limit: int = 120  # Default FRED limit
+    window_start: datetime = field(default_factory=datetime.now)
+    window_seconds: int = 60
+    backoff_until: Optional[datetime] = None
+    consecutive_failures: int = 0
+    max_backoff_seconds: int = 3600  # Upper bound for exponential backoff
+
+    @property
+    def is_rate_limited(self) -> bool:
+        """Check if currently rate limited."""
+        if self.backoff_until and datetime.now() < self.backoff_until:
+            return True
+        return False
+
+    @property
+    def requests_remaining(self) -> int:
+        """Get remaining requests in current window."""
+        self._maybe_reset_window()
+        return max(0, self.requests_limit - self.requests_made)
+
+    def _maybe_reset_window(self) -> None:
+        """Reset window if expired."""
+        if (datetime.now() - self.window_start).total_seconds() > self.window_seconds:
+            self.window_start = datetime.now()
+            self.requests_made = 0
+
+    def record_request(self) -> None:
+        """Record a request."""
+        self._maybe_reset_window()
+        self.requests_made += 1
+
+    def record_rate_limit(self, backoff_seconds: int = 60) -> None:
+        """Record a rate limit hit and start an exponential backoff.
+
+        The delay doubles with every consecutive failure but is capped at
+        ``max_backoff_seconds`` (or the base delay, if that is larger), so
+        repeated failures cannot overflow ``timedelta``.
+
+        Args:
+            backoff_seconds: Base delay applied on the first failure
+        """
+        self.consecutive_failures += 1
+        # Bound the exponent as well, so 2 ** n stays a small integer
+        exponent = min(self.consecutive_failures - 1, 32)
+        cap = max(self.max_backoff_seconds, backoff_seconds)
+        actual_backoff = min(backoff_seconds * (2 ** exponent), cap)
+        self.backoff_until = datetime.now() + timedelta(seconds=actual_backoff)
+        logger.warning(
+            "Rate limit hit for %s, backing off for %ss",
+            self.source, actual_backoff
+        )
+
+    def record_success(self) -> None:
+        """Record successful request."""
+        self.consecutive_failures = 0
+        self.backoff_until = None
+
+
+class CacheBackend(ABC):
+    """Abstract base class for cache backends."""
+
+    @abstractmethod
+    def get(self, key: str) -> Optional[CacheEntry]:
+        """Get entry from cache."""
+        pass
+
+    @abstractmethod
+    def set(self, entry: CacheEntry) -> None:
+        """Set entry in cache."""
+        pass
+
+    @abstractmethod
+    def delete(self, key: str) -> bool:
+        """Delete entry from cache."""
+        pass
+
+    @abstractmethod
+    def clear(self) -> int:
+        """Clear all entries. Returns number cleared."""
+        pass
+
+    @abstractmethod
+    def keys(self) -> List[str]:
+        """Get all cache keys, each usable with get() and delete()."""
+        pass
+
+    @abstractmethod
+    def size(self) -> int:
+        """Get number of entries."""
+        pass
+
+
+class MemoryCache(CacheBackend):
+    """In-memory cache with LRU eviction."""
+
+    def __init__(self, max_size: int = 1000):
+        """Initialize memory cache.
+
+        Args:
+            max_size: Maximum number of entries
+        """
+        # Insertion order doubles as recency order: least recent first
+        self._cache: "OrderedDict[str, CacheEntry]" = OrderedDict()
+        self._max_size = max_size
+        self._lock = threading.RLock()
+
+    def get(self, key: str) -> Optional[CacheEntry]:
+        """Get entry from cache and mark it as most recently used."""
+        with self._lock:
+            entry = self._cache.get(key)
+            if entry is not None:
+                self._cache.move_to_end(key)
+            return entry
+
+    def set(self, entry: CacheEntry) -> None:
+        """Set entry in cache with LRU eviction.
+
+        Overwriting an existing key never evicts: the entry count does not
+        grow, so only genuinely new keys can push out the oldest entry.
+        """
+        with self._lock:
+            if entry.key not in self._cache:
+                # Evict least recently used entries until there is room
+                while self._cache and len(self._cache) >= self._max_size:
+                    self._cache.popitem(last=False)
+
+            self._cache[entry.key] = entry
+            self._cache.move_to_end(entry.key)
+
+    def delete(self, key: str) -> bool:
+        """Delete entry from cache. Returns True if the key existed."""
+        with self._lock:
+            return self._cache.pop(key, None) is not None
+
+    def clear(self) -> int:
+        """Clear all entries. Returns number cleared."""
+        with self._lock:
+            count = len(self._cache)
+            self._cache.clear()
+            return count
+
+    def keys(self) -> List[str]:
+        """Get all cache keys."""
+        with self._lock:
+            return list(self._cache.keys())
+
+    def size(self) -> int:
+        """Get number of entries."""
+        with self._lock:
+            return len(self._cache)
+
+
+class FileCache(CacheBackend):
+    """File-based cache using JSON serialization.
+
+    Each entry lives in its own file named after the MD5 of its key. Values
+    must be JSON-serializable; note that JSON round-trips tuples as lists.
+    """
+
+    def __init__(self, cache_dir: Optional[Path] = None):
+        """Initialize file cache.
+
+        Args:
+            cache_dir: Directory for cache files
+        """
+        self._cache_dir = cache_dir or Path.home() / ".cache" / "tradingagents"
+        self._cache_dir.mkdir(parents=True, exist_ok=True)
+        self._lock = threading.RLock()
+
+    def _get_path(self, key: str) -> Path:
+        """Get file path for key."""
+        # Use hash to avoid filesystem issues
+        safe_key = hashlib.md5(key.encode()).hexdigest()
+        return self._cache_dir / f"{safe_key}.json"
+
+    def _read_entry(self, path: Path) -> Optional[CacheEntry]:
+        """Load the entry stored at ``path``.
+
+        Returns None when the file is missing. A file that cannot be parsed
+        into an entry is treated as corrupted and removed.
+        """
+        try:
+            with open(path, 'r', encoding='utf-8') as f:
+                data = json.load(f)
+
+            return CacheEntry(
+                key=data['key'],
+                value=data['value'],
+                created_at=datetime.fromisoformat(data['created_at']),
+                expires_at=datetime.fromisoformat(data['expires_at']),
+                access_count=data.get('access_count', 0),
+                last_accessed=datetime.fromisoformat(data['last_accessed']) if data.get('last_accessed') else None,
+                source=data.get('source', ''),
+                metadata=data.get('metadata', {})
+            )
+        except FileNotFoundError:
+            # Never written, or removed by another thread or process
+            return None
+        except (json.JSONDecodeError, KeyError, TypeError, ValueError):
+            # Corrupted file
+            path.unlink(missing_ok=True)
+            return None
+
+    def get(self, key: str) -> Optional[CacheEntry]:
+        """Get entry from file cache."""
+        with self._lock:
+            return self._read_entry(self._get_path(key))
+
+    def set(self, entry: CacheEntry) -> None:
+        """Set entry in file cache.
+
+        The payload is written to a temporary file and then renamed over
+        the final path, so a failed or interrupted write never leaves a
+        truncated entry behind.
+
+        Raises:
+            TypeError: If the entry value or metadata is not JSON-serializable
+        """
+        path = self._get_path(entry.key)
+        tmp_path = path.with_name(path.name + ".tmp")
+
+        with self._lock:
+            data = {
+                'key': entry.key,
+                'value': entry.value,
+                'created_at': entry.created_at.isoformat(),
+                'expires_at': entry.expires_at.isoformat(),
+                'access_count': entry.access_count,
+                'last_accessed': entry.last_accessed.isoformat() if entry.last_accessed else None,
+                'source': entry.source,
+                'metadata': entry.metadata
+            }
+
+            try:
+                with open(tmp_path, 'w', encoding='utf-8') as f:
+                    json.dump(data, f)
+                tmp_path.replace(path)
+            finally:
+                # No-op after a successful rename; removes debris on failure
+                tmp_path.unlink(missing_ok=True)
+
+    def delete(self, key: str) -> bool:
+        """Delete entry from file cache. Returns True if the key existed."""
+        path = self._get_path(key)
+        with self._lock:
+            try:
+                path.unlink()
+            except FileNotFoundError:
+                return False
+            return True
+
+    def clear(self) -> int:
+        """Clear all entries. Returns number cleared."""
+        with self._lock:
+            count = 0
+            for path in self._cache_dir.glob("*.json"):
+                path.unlink(missing_ok=True)
+                count += 1
+            return count
+
+    def keys(self) -> List[str]:
+        """Get all cache keys.
+
+        The original keys are read back from the stored entries (file names
+        are only hashes), so every returned key works with get() and delete().
+        """
+        with self._lock:
+            found = []
+            for path in self._cache_dir.glob("*.json"):
+                entry = self._read_entry(path)
+                if entry is not None:
+                    found.append(entry.key)
+            return found
+
+    def size(self) -> int:
+        """Get number of entries."""
+        return len(list(self._cache_dir.glob("*.json")))
+
+
+class DataCache:
+    """Main data cache with rate limit awareness.
+
+    Provides caching for vendor data with configurable TTLs,
+    rate limit tracking, and stale-while-revalidate support.
+
+    Example:
+        cache = DataCache()
+
+        # Cache with default TTL
+        cache.set("fred:FEDFUNDS", data, source="fred")
+
+        # Get with stale fallback if rate limited
+        value, status = cache.get("fred:FEDFUNDS", serve_stale_if_rate_limited=True)
+
+        # Use as decorator
+        @cache.cached(ttl_seconds=3600, source="fred")
+        def get_fred_data(series_id):
+            return fetch_from_api(series_id)
+    """
+
+    # Default TTLs by source (in seconds)
+    DEFAULT_TTLS = {
+        "fred": 3600 * 24,      # 24 hours for FRED (data updates daily)
+        "yfinance": 60,         # 1 minute for real-time quotes
+        "finnhub": 60,          # 1 minute for real-time data
+        "polygon": 300,         # 5 minutes
+        "alpha_vantage": 300,   # 5 minutes
+        "default": 300          # 5 minutes default
+    }
+
+    # Default rate limits by source
+    DEFAULT_RATE_LIMITS = {
+        "fred": (120, 60),          # 120 requests per 60 seconds
+        "yfinance": (2000, 60),     # High limit (throttles internally)
+        "finnhub": (60, 60),        # 60 per minute
+        "polygon": (5, 60),         # 5 per minute (free tier)
+        "alpha_vantage": (5, 60),   # 5 per minute (free tier)
+        "default": (100, 60)
+    }
+
+    def __init__(
+        self,
+        backend: Optional[CacheBackend] = None,
+        default_ttl_seconds: int = 300
+    ):
+        """Initialize data cache.
+
+        Args:
+            backend: Cache backend (defaults to MemoryCache)
+            default_ttl_seconds: TTL for entries whose source has no
+                dedicated value in DEFAULT_TTLS
+        """
+        self._backend = backend if backend is not None else MemoryCache()
+        self._default_ttl = default_ttl_seconds
+        self._stats = CacheStats()
+        self._rate_limits: Dict[str, RateLimitState] = {}
+        self._lock = threading.RLock()
+
+    def _generate_key(self, key: str, **kwargs) -> str:
+        """Generate cache key from key and optional params.
+
+        Params are serialized with sorted keys so argument order does not
+        matter; values JSON cannot encode fall back to ``str()``.
+        """
+        if kwargs:
+            params_str = json.dumps(kwargs, sort_keys=True, default=str)
+            # Keep the full digest: a truncated hash invites key collisions
+            digest = hashlib.md5(params_str.encode()).hexdigest()
+            return f"{key}:{digest}"
+        return key
+
+    def _get_ttl(self, source: str) -> int:
+        """Get TTL for a source.
+
+        Known sources use their DEFAULT_TTLS value; unknown sources and the
+        generic "default" source use the TTL configured on this instance.
+        """
+        if source != "default" and source in self.DEFAULT_TTLS:
+            return self.DEFAULT_TTLS[source]
+        return self._default_ttl
+
+    def _get_rate_limit_state(self, source: str) -> RateLimitState:
+        """Get or create rate limit state for source."""
+        if source not in self._rate_limits:
+            limit, window = self.DEFAULT_RATE_LIMITS.get(
+                source,
+                self.DEFAULT_RATE_LIMITS["default"]
+            )
+            self._rate_limits[source] = RateLimitState(
+                source=source,
+                requests_limit=limit,
+                window_seconds=window
+            )
+        return self._rate_limits[source]
+
+    def get(
+        self,
+        key: str,
+        serve_stale_if_rate_limited: bool = True,
+        **kwargs
+    ) -> Tuple[Optional[Any], CacheStatus]:
+        """Get value from cache.
+
+        Args:
+            key: Cache key
+            serve_stale_if_rate_limited: Return expired value if rate limited
+            **kwargs: Additional key params
+
+        Returns:
+            Tuple of (value, status)
+        """
+        full_key = self._generate_key(key, **kwargs)
+
+        with self._lock:
+            entry = self._backend.get(full_key)
+
+            if entry is None:
+                self._stats.misses += 1
+                return None, CacheStatus.MISS
+
+            if not entry.is_expired:
+                entry.touch()
+                self._backend.set(entry)  # Update metadata
+                self._stats.hits += 1
+                return entry.value, CacheStatus.HIT
+
+            # Entry is expired
+            self._stats.expired += 1
+
+            # Check if we should serve stale
+            if serve_stale_if_rate_limited:
+                rate_state = self._get_rate_limit_state(entry.source)
+                if rate_state.is_rate_limited:
+                    self._stats.stale_served += 1
+                    return entry.value, CacheStatus.STALE
+
+            return None, CacheStatus.EXPIRED
+
+    def set(
+        self,
+        key: str,
+        value: Any,
+        ttl_seconds: Optional[int] = None,
+        source: str = "default",
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs
+    ) -> None:
+        """Set value in cache.
+
+        Args:
+            key: Cache key
+            value: Value to cache
+            ttl_seconds: TTL in seconds (uses source default if not specified)
+            source: Data source name
+            metadata: Optional metadata
+            **kwargs: Additional key params
+        """
+        full_key = self._generate_key(key, **kwargs)
+        actual_ttl = ttl_seconds if ttl_seconds is not None else self._get_ttl(source)
+        now = datetime.now()
+
+        entry = CacheEntry(
+            key=full_key,
+            value=value,
+            created_at=now,
+            expires_at=now + timedelta(seconds=actual_ttl),
+            source=source,
+            metadata=metadata or {}
+        )
+
+        with self._lock:
+            self._backend.set(entry)
+            self._stats.size = self._backend.size()
+
+    def delete(self, key: str, **kwargs) -> bool:
+        """Delete value from cache."""
+        full_key = self._generate_key(key, **kwargs)
+        with self._lock:
+            result = self._backend.delete(full_key)
+            self._stats.size = self._backend.size()
+            return result
+
+    def clear(self, source: Optional[str] = None) -> int:
+        """Clear cache entries.
+
+        Args:
+            source: Clear only entries from this source (None = all)
+
+        Returns:
+            Number of entries cleared
+        """
+        with self._lock:
+            if source is None:
+                count = self._backend.clear()
+            else:
+                count = 0
+                for key in self._backend.keys():
+                    entry = self._backend.get(key)
+                    if entry and entry.source == source:
+                        self._backend.delete(key)
+                        count += 1
+
+            self._stats.evictions += count
+            self._stats.size = self._backend.size()
+            return count
+
+    def record_rate_limit(self, source: str, backoff_seconds: int = 60) -> None:
+        """Record a rate limit hit for a source."""
+        with self._lock:
+            state = self._get_rate_limit_state(source)
+            state.record_rate_limit(backoff_seconds)
+
+    def record_request(self, source: str) -> None:
+        """Record a request for rate limit tracking."""
+        with self._lock:
+            state = self._get_rate_limit_state(source)
+            state.record_request()
+
+    def record_success(self, source: str) -> None:
+        """Record successful request."""
+        with self._lock:
+            state = self._get_rate_limit_state(source)
+            state.record_success()
+
+    def is_rate_limited(self, source: str) -> bool:
+        """Check if source is rate limited."""
+        with self._lock:
+            state = self._get_rate_limit_state(source)
+            return state.is_rate_limited
+
+    def get_stats(self) -> CacheStats:
+        """Get a snapshot of the cache statistics.
+
+        A copy is returned so callers can read it without holding the lock
+        and cannot alter the live counters.
+        """
+        with self._lock:
+            self._stats.size = self._backend.size()
+            return replace(self._stats)
+
+    def cached(
+        self,
+        ttl_seconds: Optional[int] = None,
+        source: str = "default",
+        key_prefix: str = ""
+    ) -> Callable:
+        """Decorator for caching function results.
+
+        The cache key combines ``key_prefix``, the function's module and
+        qualified name, and its arguments. A stale value is returned without
+        calling the function while ``source`` is rate limited.
+
+        Example:
+            @cache.cached(ttl_seconds=3600, source="fred")
+            def get_fred_data(series_id):
+                return fetch_from_api(series_id)
+        """
+        def decorator(func: Callable) -> Callable:
+            func_id = f"{func.__module__}.{func.__qualname__}"
+
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                # Generate cache key from function identity and args
+                key_parts = [key_prefix, func_id]
+                if args:
+                    key_parts.append(str(args))
+                if kwargs:
+                    key_parts.append(
+                        json.dumps(kwargs, sort_keys=True, default=str)
+                    )
+                cache_key = ":".join(filter(None, key_parts))
+
+                # Check cache
+                value, status = self.get(cache_key, serve_stale_if_rate_limited=True)
+                if status in (CacheStatus.HIT, CacheStatus.STALE):
+                    return value
+
+                # Execute function
+                result = func(*args, **kwargs)
+
+                # Cache result
+                self.set(cache_key, result, ttl_seconds=ttl_seconds, source=source)
+
+                return result
+
+            return wrapper
+        return decorator
+
+
+# Global cache instance
+_global_cache: Optional[DataCache] = None
+_global_cache_lock = threading.Lock()
+
+
+def get_cache() -> DataCache:
+    """Get the global cache instance."""
+    global _global_cache
+    with _global_cache_lock:
+        if _global_cache is None:
+            _global_cache = DataCache()
+        return _global_cache
+
+
+def reset_cache() -> None:
+    """Reset the global cache instance. For testing."""
+    global _global_cache
+    with _global_cache_lock:
+        _global_cache = None