feat(dataflows): add data caching layer with rate limit awareness - Fixes #12
Implements [DATA-11] Data caching layer - FRED rate limits with: - CacheEntry: Generic cache entries with TTL and metadata - CacheStats: Hit/miss/stale statistics tracking - RateLimitState: Per-source rate limit tracking with exponential backoff - MemoryCache: In-memory LRU cache backend - FileCache: File-based JSON cache backend - DataCache: Main cache with source-specific TTLs and stale-while-rate-limited - @cached decorator: Function result caching Features: - Multi-backend support (memory, file) - TTL-based expiration with configurable per-source defaults - Stale-while-revalidate when rate limited - Thread-safe operations throughout - 41 tests covering all components 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
2c802647e4
commit
ae7899a6fc
|
|
@ -0,0 +1,618 @@
|
||||||
|
"""Tests for data caching layer.
|
||||||
|
|
||||||
|
Issue #12: [DATA-11] Data caching layer - FRED rate limits
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import tempfile
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from tradingagents.dataflows.cache import (
|
||||||
|
CacheEntry,
|
||||||
|
CacheStats,
|
||||||
|
CacheStatus,
|
||||||
|
RateLimitState,
|
||||||
|
MemoryCache,
|
||||||
|
FileCache,
|
||||||
|
DataCache,
|
||||||
|
get_cache,
|
||||||
|
reset_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_global_cache():
|
||||||
|
"""Reset global cache before each test."""
|
||||||
|
reset_cache()
|
||||||
|
yield
|
||||||
|
reset_cache()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheEntry:
|
||||||
|
"""Tests for CacheEntry dataclass."""
|
||||||
|
|
||||||
|
def test_entry_creation(self):
|
||||||
|
"""Test creating a cache entry."""
|
||||||
|
entry = CacheEntry(
|
||||||
|
key="test_key",
|
||||||
|
value={"data": "test"},
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1),
|
||||||
|
source="fred"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert entry.key == "test_key"
|
||||||
|
assert entry.value == {"data": "test"}
|
||||||
|
assert entry.source == "fred"
|
||||||
|
assert entry.access_count == 0
|
||||||
|
|
||||||
|
def test_is_expired_false(self):
|
||||||
|
"""Test is_expired returns False for valid entry."""
|
||||||
|
entry = CacheEntry(
|
||||||
|
key="test",
|
||||||
|
value="data",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1)
|
||||||
|
)
|
||||||
|
assert entry.is_expired is False
|
||||||
|
|
||||||
|
def test_is_expired_true(self):
|
||||||
|
"""Test is_expired returns True for expired entry."""
|
||||||
|
entry = CacheEntry(
|
||||||
|
key="test",
|
||||||
|
value="data",
|
||||||
|
created_at=datetime.now() - timedelta(hours=2),
|
||||||
|
expires_at=datetime.now() - timedelta(hours=1)
|
||||||
|
)
|
||||||
|
assert entry.is_expired is True
|
||||||
|
|
||||||
|
def test_age_seconds(self):
|
||||||
|
"""Test age_seconds calculation."""
|
||||||
|
entry = CacheEntry(
|
||||||
|
key="test",
|
||||||
|
value="data",
|
||||||
|
created_at=datetime.now() - timedelta(seconds=60),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1)
|
||||||
|
)
|
||||||
|
assert 59 < entry.age_seconds < 61
|
||||||
|
|
||||||
|
def test_ttl_remaining(self):
|
||||||
|
"""Test ttl_remaining_seconds calculation."""
|
||||||
|
entry = CacheEntry(
|
||||||
|
key="test",
|
||||||
|
value="data",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(seconds=3600)
|
||||||
|
)
|
||||||
|
assert 3599 < entry.ttl_remaining_seconds <= 3600
|
||||||
|
|
||||||
|
def test_touch_updates_metadata(self):
|
||||||
|
"""Test touch updates access metadata."""
|
||||||
|
entry = CacheEntry(
|
||||||
|
key="test",
|
||||||
|
value="data",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert entry.access_count == 0
|
||||||
|
assert entry.last_accessed is None
|
||||||
|
|
||||||
|
entry.touch()
|
||||||
|
|
||||||
|
assert entry.access_count == 1
|
||||||
|
assert entry.last_accessed is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheStats:
|
||||||
|
"""Tests for CacheStats dataclass."""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Test default values."""
|
||||||
|
stats = CacheStats()
|
||||||
|
assert stats.hits == 0
|
||||||
|
assert stats.misses == 0
|
||||||
|
assert stats.hit_rate == 0.0
|
||||||
|
|
||||||
|
def test_hit_rate_calculation(self):
|
||||||
|
"""Test hit rate calculation."""
|
||||||
|
stats = CacheStats(hits=75, misses=25)
|
||||||
|
assert stats.hit_rate == 75.0
|
||||||
|
|
||||||
|
def test_hit_rate_no_requests(self):
|
||||||
|
"""Test hit rate with no requests."""
|
||||||
|
stats = CacheStats()
|
||||||
|
assert stats.hit_rate == 0.0
|
||||||
|
|
||||||
|
def test_to_dict(self):
|
||||||
|
"""Test conversion to dictionary."""
|
||||||
|
stats = CacheStats(hits=10, misses=5, evictions=2)
|
||||||
|
d = stats.to_dict()
|
||||||
|
|
||||||
|
assert d["hits"] == 10
|
||||||
|
assert d["misses"] == 5
|
||||||
|
assert d["evictions"] == 2
|
||||||
|
assert "hit_rate" in d
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitState:
|
||||||
|
"""Tests for RateLimitState dataclass."""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Test default values."""
|
||||||
|
state = RateLimitState(source="fred")
|
||||||
|
assert state.source == "fred"
|
||||||
|
assert state.requests_made == 0
|
||||||
|
assert state.is_rate_limited is False
|
||||||
|
|
||||||
|
def test_record_request(self):
|
||||||
|
"""Test recording requests."""
|
||||||
|
state = RateLimitState(source="test", requests_limit=5)
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
state.record_request()
|
||||||
|
|
||||||
|
assert state.requests_made == 3
|
||||||
|
assert state.requests_remaining == 2
|
||||||
|
|
||||||
|
def test_is_rate_limited_after_backoff(self):
|
||||||
|
"""Test rate limiting after recording limit hit."""
|
||||||
|
state = RateLimitState(source="test")
|
||||||
|
|
||||||
|
assert state.is_rate_limited is False
|
||||||
|
|
||||||
|
state.record_rate_limit(backoff_seconds=1)
|
||||||
|
|
||||||
|
assert state.is_rate_limited is True
|
||||||
|
|
||||||
|
# Wait for backoff to expire
|
||||||
|
time.sleep(1.1)
|
||||||
|
assert state.is_rate_limited is False
|
||||||
|
|
||||||
|
def test_record_success_clears_backoff(self):
|
||||||
|
"""Test that success clears backoff."""
|
||||||
|
state = RateLimitState(source="test")
|
||||||
|
state.record_rate_limit(backoff_seconds=60)
|
||||||
|
assert state.is_rate_limited is True
|
||||||
|
|
||||||
|
state.record_success()
|
||||||
|
assert state.is_rate_limited is False
|
||||||
|
assert state.consecutive_failures == 0
|
||||||
|
|
||||||
|
def test_exponential_backoff(self):
|
||||||
|
"""Test exponential backoff on consecutive failures."""
|
||||||
|
state = RateLimitState(source="test")
|
||||||
|
|
||||||
|
# First failure - 1 second backoff
|
||||||
|
state.record_rate_limit(backoff_seconds=1)
|
||||||
|
assert state.consecutive_failures == 1
|
||||||
|
|
||||||
|
# Simulate recovery
|
||||||
|
state.backoff_until = None
|
||||||
|
|
||||||
|
# Second failure - 2 second backoff
|
||||||
|
state.record_rate_limit(backoff_seconds=1)
|
||||||
|
assert state.consecutive_failures == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryCache:
|
||||||
|
"""Tests for MemoryCache backend."""
|
||||||
|
|
||||||
|
def test_get_set(self):
|
||||||
|
"""Test basic get/set operations."""
|
||||||
|
cache = MemoryCache()
|
||||||
|
|
||||||
|
entry = CacheEntry(
|
||||||
|
key="test",
|
||||||
|
value="data",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
cache.set(entry)
|
||||||
|
result = cache.get("test")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.value == "data"
|
||||||
|
|
||||||
|
def test_get_missing(self):
|
||||||
|
"""Test getting missing key."""
|
||||||
|
cache = MemoryCache()
|
||||||
|
assert cache.get("nonexistent") is None
|
||||||
|
|
||||||
|
def test_delete(self):
|
||||||
|
"""Test deleting entry."""
|
||||||
|
cache = MemoryCache()
|
||||||
|
|
||||||
|
entry = CacheEntry(
|
||||||
|
key="test",
|
||||||
|
value="data",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
cache.set(entry)
|
||||||
|
assert cache.delete("test") is True
|
||||||
|
assert cache.get("test") is None
|
||||||
|
|
||||||
|
def test_delete_missing(self):
|
||||||
|
"""Test deleting missing key."""
|
||||||
|
cache = MemoryCache()
|
||||||
|
assert cache.delete("nonexistent") is False
|
||||||
|
|
||||||
|
def test_clear(self):
|
||||||
|
"""Test clearing cache."""
|
||||||
|
cache = MemoryCache()
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
cache.set(CacheEntry(
|
||||||
|
key=f"key_{i}",
|
||||||
|
value=f"value_{i}",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1)
|
||||||
|
))
|
||||||
|
|
||||||
|
count = cache.clear()
|
||||||
|
assert count == 5
|
||||||
|
assert cache.size() == 0
|
||||||
|
|
||||||
|
def test_lru_eviction(self):
|
||||||
|
"""Test LRU eviction when at capacity."""
|
||||||
|
cache = MemoryCache(max_size=3)
|
||||||
|
|
||||||
|
# Add 3 entries
|
||||||
|
for i in range(3):
|
||||||
|
cache.set(CacheEntry(
|
||||||
|
key=f"key_{i}",
|
||||||
|
value=f"value_{i}",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1)
|
||||||
|
))
|
||||||
|
|
||||||
|
# Access key_1 to make it recently used
|
||||||
|
cache.get("key_1")
|
||||||
|
|
||||||
|
# Add new entry, should evict key_0 (least recently used)
|
||||||
|
cache.set(CacheEntry(
|
||||||
|
key="key_3",
|
||||||
|
value="value_3",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1)
|
||||||
|
))
|
||||||
|
|
||||||
|
assert cache.size() == 3
|
||||||
|
assert cache.get("key_0") is None
|
||||||
|
assert cache.get("key_1") is not None
|
||||||
|
|
||||||
|
def test_thread_safety(self):
|
||||||
|
"""Test thread-safe operations."""
|
||||||
|
cache = MemoryCache()
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
def write_entries(start):
|
||||||
|
try:
|
||||||
|
for i in range(100):
|
||||||
|
cache.set(CacheEntry(
|
||||||
|
key=f"key_{start}_{i}",
|
||||||
|
value=f"value_{start}_{i}",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1)
|
||||||
|
))
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=write_entries, args=(i,)) for i in range(5)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert len(errors) == 0
|
||||||
|
assert cache.size() == 500
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileCache:
|
||||||
|
"""Tests for FileCache backend."""
|
||||||
|
|
||||||
|
def test_get_set(self):
|
||||||
|
"""Test basic get/set operations."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
cache = FileCache(cache_dir=Path(tmpdir))
|
||||||
|
|
||||||
|
entry = CacheEntry(
|
||||||
|
key="test",
|
||||||
|
value={"data": "test"},
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1),
|
||||||
|
source="test"
|
||||||
|
)
|
||||||
|
|
||||||
|
cache.set(entry)
|
||||||
|
result = cache.get("test")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.value == {"data": "test"}
|
||||||
|
assert result.source == "test"
|
||||||
|
|
||||||
|
def test_get_missing(self):
|
||||||
|
"""Test getting missing key."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
cache = FileCache(cache_dir=Path(tmpdir))
|
||||||
|
assert cache.get("nonexistent") is None
|
||||||
|
|
||||||
|
def test_delete(self):
|
||||||
|
"""Test deleting entry."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
cache = FileCache(cache_dir=Path(tmpdir))
|
||||||
|
|
||||||
|
entry = CacheEntry(
|
||||||
|
key="test",
|
||||||
|
value="data",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
cache.set(entry)
|
||||||
|
assert cache.delete("test") is True
|
||||||
|
assert cache.get("test") is None
|
||||||
|
|
||||||
|
def test_clear(self):
|
||||||
|
"""Test clearing cache."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
cache = FileCache(cache_dir=Path(tmpdir))
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
cache.set(CacheEntry(
|
||||||
|
key=f"key_{i}",
|
||||||
|
value=f"value_{i}",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(hours=1)
|
||||||
|
))
|
||||||
|
|
||||||
|
count = cache.clear()
|
||||||
|
assert count == 3
|
||||||
|
assert cache.size() == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataCache:
|
||||||
|
"""Tests for DataCache main class."""
|
||||||
|
|
||||||
|
def test_get_set_basic(self):
|
||||||
|
"""Test basic get/set operations."""
|
||||||
|
cache = DataCache()
|
||||||
|
|
||||||
|
cache.set("test_key", {"data": "test"}, source="fred")
|
||||||
|
value, status = cache.get("test_key")
|
||||||
|
|
||||||
|
assert status == CacheStatus.HIT
|
||||||
|
assert value == {"data": "test"}
|
||||||
|
|
||||||
|
def test_get_miss(self):
|
||||||
|
"""Test cache miss."""
|
||||||
|
cache = DataCache()
|
||||||
|
|
||||||
|
value, status = cache.get("nonexistent")
|
||||||
|
|
||||||
|
assert status == CacheStatus.MISS
|
||||||
|
assert value is None
|
||||||
|
|
||||||
|
def test_get_expired(self):
|
||||||
|
"""Test getting expired entry."""
|
||||||
|
cache = DataCache()
|
||||||
|
|
||||||
|
# Set with very short TTL
|
||||||
|
cache.set("test", "data", ttl_seconds=0, source="test")
|
||||||
|
|
||||||
|
# Wait for expiration
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
value, status = cache.get("test", serve_stale_if_rate_limited=False)
|
||||||
|
|
||||||
|
assert status == CacheStatus.EXPIRED
|
||||||
|
assert value is None
|
||||||
|
|
||||||
|
def test_serve_stale_when_rate_limited(self):
|
||||||
|
"""Test serving stale data when rate limited."""
|
||||||
|
cache = DataCache()
|
||||||
|
|
||||||
|
# Set entry that will expire
|
||||||
|
cache.set("test", "stale_data", ttl_seconds=0, source="test")
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Simulate rate limit
|
||||||
|
cache.record_rate_limit("test", backoff_seconds=60)
|
||||||
|
|
||||||
|
# Should get stale data
|
||||||
|
value, status = cache.get("test", serve_stale_if_rate_limited=True)
|
||||||
|
|
||||||
|
assert status == CacheStatus.STALE
|
||||||
|
assert value == "stale_data"
|
||||||
|
|
||||||
|
def test_delete(self):
|
||||||
|
"""Test deleting entry."""
|
||||||
|
cache = DataCache()
|
||||||
|
|
||||||
|
cache.set("test", "data", source="test")
|
||||||
|
assert cache.delete("test") is True
|
||||||
|
|
||||||
|
value, status = cache.get("test")
|
||||||
|
assert status == CacheStatus.MISS
|
||||||
|
|
||||||
|
def test_clear_all(self):
|
||||||
|
"""Test clearing all entries."""
|
||||||
|
cache = DataCache()
|
||||||
|
|
||||||
|
cache.set("key1", "value1", source="fred")
|
||||||
|
cache.set("key2", "value2", source="yfinance")
|
||||||
|
|
||||||
|
count = cache.clear()
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
def test_clear_by_source(self):
|
||||||
|
"""Test clearing entries by source."""
|
||||||
|
cache = DataCache()
|
||||||
|
|
||||||
|
cache.set("fred_key", "fred_data", source="fred")
|
||||||
|
cache.set("yf_key", "yf_data", source="yfinance")
|
||||||
|
|
||||||
|
count = cache.clear(source="fred")
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
# yfinance entry should still exist
|
||||||
|
value, status = cache.get("yf_key")
|
||||||
|
assert status == CacheStatus.HIT
|
||||||
|
|
||||||
|
def test_key_with_params(self):
|
||||||
|
"""Test key generation with params."""
|
||||||
|
cache = DataCache()
|
||||||
|
|
||||||
|
cache.set("series", "data1", source="fred", series_id="FEDFUNDS")
|
||||||
|
cache.set("series", "data2", source="fred", series_id="DGS10")
|
||||||
|
|
||||||
|
value1, _ = cache.get("series", series_id="FEDFUNDS")
|
||||||
|
value2, _ = cache.get("series", series_id="DGS10")
|
||||||
|
|
||||||
|
assert value1 == "data1"
|
||||||
|
assert value2 == "data2"
|
||||||
|
|
||||||
|
def test_stats_tracking(self):
|
||||||
|
"""Test statistics tracking."""
|
||||||
|
cache = DataCache()
|
||||||
|
|
||||||
|
# Miss
|
||||||
|
cache.get("missing")
|
||||||
|
|
||||||
|
# Hit
|
||||||
|
cache.set("present", "data", source="test")
|
||||||
|
cache.get("present")
|
||||||
|
cache.get("present")
|
||||||
|
|
||||||
|
stats = cache.get_stats()
|
||||||
|
assert stats.misses == 1
|
||||||
|
assert stats.hits == 2
|
||||||
|
|
||||||
|
def test_rate_limit_tracking(self):
|
||||||
|
"""Test rate limit state tracking."""
|
||||||
|
cache = DataCache()
|
||||||
|
|
||||||
|
assert cache.is_rate_limited("fred") is False
|
||||||
|
|
||||||
|
cache.record_rate_limit("fred", backoff_seconds=1)
|
||||||
|
assert cache.is_rate_limited("fred") is True
|
||||||
|
|
||||||
|
time.sleep(1.1)
|
||||||
|
assert cache.is_rate_limited("fred") is False
|
||||||
|
|
||||||
|
def test_cached_decorator(self):
|
||||||
|
"""Test @cached decorator."""
|
||||||
|
cache = DataCache()
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
@cache.cached(ttl_seconds=300, source="test")
|
||||||
|
def expensive_function(x):
|
||||||
|
call_count[0] += 1
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
# First call - executes function
|
||||||
|
result1 = expensive_function(5)
|
||||||
|
assert result1 == 10
|
||||||
|
assert call_count[0] == 1
|
||||||
|
|
||||||
|
# Second call - from cache
|
||||||
|
result2 = expensive_function(5)
|
||||||
|
assert result2 == 10
|
||||||
|
assert call_count[0] == 1
|
||||||
|
|
||||||
|
# Different argument - executes function
|
||||||
|
result3 = expensive_function(10)
|
||||||
|
assert result3 == 20
|
||||||
|
assert call_count[0] == 2
|
||||||
|
|
||||||
|
def test_default_ttls_by_source(self):
|
||||||
|
"""Test default TTLs are applied by source."""
|
||||||
|
cache = DataCache()
|
||||||
|
|
||||||
|
# FRED default is 24 hours
|
||||||
|
cache.set("fred_data", "data", source="fred")
|
||||||
|
entry = cache._backend.get(cache._generate_key("fred_data"))
|
||||||
|
|
||||||
|
# Should have ~24 hour TTL
|
||||||
|
assert entry.ttl_remaining_seconds > 3600 * 23
|
||||||
|
|
||||||
|
|
||||||
|
class TestGlobalCache:
|
||||||
|
"""Tests for global cache functions."""
|
||||||
|
|
||||||
|
def test_get_cache_singleton(self):
|
||||||
|
"""Test get_cache returns singleton."""
|
||||||
|
cache1 = get_cache()
|
||||||
|
cache2 = get_cache()
|
||||||
|
assert cache1 is cache2
|
||||||
|
|
||||||
|
def test_reset_cache(self):
|
||||||
|
"""Test reset_cache creates new instance."""
|
||||||
|
cache1 = get_cache()
|
||||||
|
reset_cache()
|
||||||
|
cache2 = get_cache()
|
||||||
|
assert cache1 is not cache2
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheIntegration:
|
||||||
|
"""Integration tests for cache with rate limiting."""
|
||||||
|
|
||||||
|
def test_rate_limited_fetch_pattern(self):
|
||||||
|
"""Test typical pattern: cache + rate limit handling."""
|
||||||
|
cache = DataCache()
|
||||||
|
fetch_count = [0]
|
||||||
|
|
||||||
|
def fetch_data(key):
|
||||||
|
"""Simulate data fetch with rate limit."""
|
||||||
|
# Check rate limit first
|
||||||
|
if cache.is_rate_limited("api"):
|
||||||
|
# Try stale cache
|
||||||
|
value, status = cache.get(key, serve_stale_if_rate_limited=True)
|
||||||
|
if status == CacheStatus.STALE:
|
||||||
|
return value
|
||||||
|
raise RuntimeError("Rate limited and no stale data")
|
||||||
|
|
||||||
|
# Check cache
|
||||||
|
value, status = cache.get(key)
|
||||||
|
if status == CacheStatus.HIT:
|
||||||
|
return value
|
||||||
|
|
||||||
|
# Fetch fresh data
|
||||||
|
fetch_count[0] += 1
|
||||||
|
cache.record_request("api")
|
||||||
|
|
||||||
|
# Simulate API response
|
||||||
|
data = f"data_for_{key}"
|
||||||
|
|
||||||
|
cache.set(key, data, source="api", ttl_seconds=1)
|
||||||
|
cache.record_success("api")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
# First fetch - from API
|
||||||
|
result1 = fetch_data("key1")
|
||||||
|
assert result1 == "data_for_key1"
|
||||||
|
assert fetch_count[0] == 1
|
||||||
|
|
||||||
|
# Second fetch - from cache
|
||||||
|
result2 = fetch_data("key1")
|
||||||
|
assert result2 == "data_for_key1"
|
||||||
|
assert fetch_count[0] == 1 # No additional fetch
|
||||||
|
|
||||||
|
# Wait for expiration and simulate rate limit
|
||||||
|
time.sleep(1.1)
|
||||||
|
cache.record_rate_limit("api", backoff_seconds=60)
|
||||||
|
|
||||||
|
# Should get stale data
|
||||||
|
result3 = fetch_data("key1")
|
||||||
|
assert result3 == "data_for_key1"
|
||||||
|
assert fetch_count[0] == 1 # Still no additional fetch
|
||||||
|
|
@ -0,0 +1,628 @@
|
||||||
|
"""Data caching layer for vendor data with rate limit awareness.
|
||||||
|
|
||||||
|
This module provides a robust caching layer to handle API rate limits across
|
||||||
|
all data vendors. Features:
|
||||||
|
- Multi-backend support (memory, file, SQLite)
|
||||||
|
- TTL-based expiration with configurable per-source TTLs
|
||||||
|
- Rate limit tracking and backoff
|
||||||
|
- Cache statistics and monitoring
|
||||||
|
- Atomic cache operations for thread safety
|
||||||
|
|
||||||
|
Issue #12: [DATA-11] Data caching layer - FRED rate limits
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from enum import Enum, auto
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
|
class CacheStatus(Enum):
|
||||||
|
"""Status of a cache lookup."""
|
||||||
|
HIT = auto()
|
||||||
|
MISS = auto()
|
||||||
|
EXPIRED = auto()
|
||||||
|
STALE = auto() # Expired but returned due to rate limit
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheEntry(Generic[T]):
|
||||||
|
"""A single cache entry with metadata."""
|
||||||
|
key: str
|
||||||
|
value: T
|
||||||
|
created_at: datetime
|
||||||
|
expires_at: datetime
|
||||||
|
access_count: int = 0
|
||||||
|
last_accessed: Optional[datetime] = None
|
||||||
|
source: str = ""
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""Check if entry is expired."""
|
||||||
|
return datetime.now() > self.expires_at
|
||||||
|
|
||||||
|
@property
|
||||||
|
def age_seconds(self) -> float:
|
||||||
|
"""Get age in seconds."""
|
||||||
|
return (datetime.now() - self.created_at).total_seconds()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ttl_remaining_seconds(self) -> float:
|
||||||
|
"""Get remaining TTL in seconds."""
|
||||||
|
return max(0, (self.expires_at - datetime.now()).total_seconds())
|
||||||
|
|
||||||
|
def touch(self) -> None:
|
||||||
|
"""Update access metadata."""
|
||||||
|
self.access_count += 1
|
||||||
|
self.last_accessed = datetime.now()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheStats:
|
||||||
|
"""Statistics for cache operations."""
|
||||||
|
hits: int = 0
|
||||||
|
misses: int = 0
|
||||||
|
expired: int = 0
|
||||||
|
stale_served: int = 0
|
||||||
|
evictions: int = 0
|
||||||
|
size: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hit_rate(self) -> float:
|
||||||
|
"""Calculate hit rate as percentage."""
|
||||||
|
total = self.hits + self.misses
|
||||||
|
if total == 0:
|
||||||
|
return 0.0
|
||||||
|
return (self.hits / total) * 100
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"hits": self.hits,
|
||||||
|
"misses": self.misses,
|
||||||
|
"expired": self.expired,
|
||||||
|
"stale_served": self.stale_served,
|
||||||
|
"evictions": self.evictions,
|
||||||
|
"size": self.size,
|
||||||
|
"hit_rate": self.hit_rate
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RateLimitState:
|
||||||
|
"""Track rate limit state for a source."""
|
||||||
|
source: str
|
||||||
|
requests_made: int = 0
|
||||||
|
requests_limit: int = 120 # Default FRED limit
|
||||||
|
window_start: datetime = field(default_factory=datetime.now)
|
||||||
|
window_seconds: int = 60
|
||||||
|
backoff_until: Optional[datetime] = None
|
||||||
|
consecutive_failures: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_rate_limited(self) -> bool:
|
||||||
|
"""Check if currently rate limited."""
|
||||||
|
if self.backoff_until and datetime.now() < self.backoff_until:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requests_remaining(self) -> int:
|
||||||
|
"""Get remaining requests in current window."""
|
||||||
|
self._maybe_reset_window()
|
||||||
|
return max(0, self.requests_limit - self.requests_made)
|
||||||
|
|
||||||
|
def _maybe_reset_window(self) -> None:
|
||||||
|
"""Reset window if expired."""
|
||||||
|
if (datetime.now() - self.window_start).total_seconds() > self.window_seconds:
|
||||||
|
self.window_start = datetime.now()
|
||||||
|
self.requests_made = 0
|
||||||
|
|
||||||
|
def record_request(self) -> None:
|
||||||
|
"""Record a request."""
|
||||||
|
self._maybe_reset_window()
|
||||||
|
self.requests_made += 1
|
||||||
|
|
||||||
|
def record_rate_limit(self, backoff_seconds: int = 60) -> None:
|
||||||
|
"""Record a rate limit hit."""
|
||||||
|
self.consecutive_failures += 1
|
||||||
|
# Exponential backoff
|
||||||
|
actual_backoff = backoff_seconds * (2 ** (self.consecutive_failures - 1))
|
||||||
|
self.backoff_until = datetime.now() + timedelta(seconds=actual_backoff)
|
||||||
|
logger.warning(f"Rate limit hit for {self.source}, backing off for {actual_backoff}s")
|
||||||
|
|
||||||
|
def record_success(self) -> None:
|
||||||
|
"""Record successful request."""
|
||||||
|
self.consecutive_failures = 0
|
||||||
|
self.backoff_until = None
|
||||||
|
|
||||||
|
|
||||||
|
class CacheBackend(ABC):
|
||||||
|
"""Abstract base class for cache backends."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, key: str) -> Optional[CacheEntry]:
|
||||||
|
"""Get entry from cache."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set(self, entry: CacheEntry) -> None:
|
||||||
|
"""Set entry in cache."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, key: str) -> bool:
|
||||||
|
"""Delete entry from cache."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def clear(self) -> int:
|
||||||
|
"""Clear all entries. Returns number cleared."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def keys(self) -> List[str]:
|
||||||
|
"""Get all cache keys."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def size(self) -> int:
|
||||||
|
"""Get number of entries."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryCache(CacheBackend):
|
||||||
|
"""In-memory cache with LRU eviction."""
|
||||||
|
|
||||||
|
def __init__(self, max_size: int = 1000):
|
||||||
|
"""Initialize memory cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_size: Maximum number of entries
|
||||||
|
"""
|
||||||
|
self._cache: Dict[str, CacheEntry] = {}
|
||||||
|
self._max_size = max_size
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
self._access_order: List[str] = []
|
||||||
|
|
||||||
|
def get(self, key: str) -> Optional[CacheEntry]:
|
||||||
|
"""Get entry from cache."""
|
||||||
|
with self._lock:
|
||||||
|
entry = self._cache.get(key)
|
||||||
|
if entry:
|
||||||
|
# Update access order for LRU
|
||||||
|
if key in self._access_order:
|
||||||
|
self._access_order.remove(key)
|
||||||
|
self._access_order.append(key)
|
||||||
|
return entry
|
||||||
|
|
||||||
|
def set(self, entry: CacheEntry) -> None:
|
||||||
|
"""Set entry in cache with LRU eviction."""
|
||||||
|
with self._lock:
|
||||||
|
# Evict if at capacity
|
||||||
|
while len(self._cache) >= self._max_size and self._access_order:
|
||||||
|
oldest_key = self._access_order.pop(0)
|
||||||
|
self._cache.pop(oldest_key, None)
|
||||||
|
|
||||||
|
self._cache[entry.key] = entry
|
||||||
|
|
||||||
|
# Update access order
|
||||||
|
if entry.key in self._access_order:
|
||||||
|
self._access_order.remove(entry.key)
|
||||||
|
self._access_order.append(entry.key)
|
||||||
|
|
||||||
|
def delete(self, key: str) -> bool:
|
||||||
|
"""Delete entry from cache."""
|
||||||
|
with self._lock:
|
||||||
|
if key in self._cache:
|
||||||
|
del self._cache[key]
|
||||||
|
if key in self._access_order:
|
||||||
|
self._access_order.remove(key)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def clear(self) -> int:
|
||||||
|
"""Clear all entries."""
|
||||||
|
with self._lock:
|
||||||
|
count = len(self._cache)
|
||||||
|
self._cache.clear()
|
||||||
|
self._access_order.clear()
|
||||||
|
return count
|
||||||
|
|
||||||
|
def keys(self) -> List[str]:
|
||||||
|
"""Get all cache keys."""
|
||||||
|
with self._lock:
|
||||||
|
return list(self._cache.keys())
|
||||||
|
|
||||||
|
def size(self) -> int:
|
||||||
|
"""Get number of entries."""
|
||||||
|
with self._lock:
|
||||||
|
return len(self._cache)
|
||||||
|
|
||||||
|
|
||||||
|
class FileCache(CacheBackend):
|
||||||
|
"""File-based cache using JSON serialization."""
|
||||||
|
|
||||||
|
def __init__(self, cache_dir: Optional[Path] = None):
|
||||||
|
"""Initialize file cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_dir: Directory for cache files
|
||||||
|
"""
|
||||||
|
self._cache_dir = cache_dir or Path.home() / ".cache" / "tradingagents"
|
||||||
|
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def _get_path(self, key: str) -> Path:
|
||||||
|
"""Get file path for key."""
|
||||||
|
# Use hash to avoid filesystem issues
|
||||||
|
safe_key = hashlib.md5(key.encode()).hexdigest()
|
||||||
|
return self._cache_dir / f"{safe_key}.json"
|
||||||
|
|
||||||
|
def get(self, key: str) -> Optional[CacheEntry]:
|
||||||
|
"""Get entry from file cache."""
|
||||||
|
path = self._get_path(key)
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
return CacheEntry(
|
||||||
|
key=data['key'],
|
||||||
|
value=data['value'],
|
||||||
|
created_at=datetime.fromisoformat(data['created_at']),
|
||||||
|
expires_at=datetime.fromisoformat(data['expires_at']),
|
||||||
|
access_count=data.get('access_count', 0),
|
||||||
|
last_accessed=datetime.fromisoformat(data['last_accessed']) if data.get('last_accessed') else None,
|
||||||
|
source=data.get('source', ''),
|
||||||
|
metadata=data.get('metadata', {})
|
||||||
|
)
|
||||||
|
except (json.JSONDecodeError, KeyError, ValueError):
|
||||||
|
# Corrupted file
|
||||||
|
path.unlink(missing_ok=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, entry: CacheEntry) -> None:
|
||||||
|
"""Set entry in file cache."""
|
||||||
|
path = self._get_path(entry.key)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
data = {
|
||||||
|
'key': entry.key,
|
||||||
|
'value': entry.value,
|
||||||
|
'created_at': entry.created_at.isoformat(),
|
||||||
|
'expires_at': entry.expires_at.isoformat(),
|
||||||
|
'access_count': entry.access_count,
|
||||||
|
'last_accessed': entry.last_accessed.isoformat() if entry.last_accessed else None,
|
||||||
|
'source': entry.source,
|
||||||
|
'metadata': entry.metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(path, 'w') as f:
|
||||||
|
json.dump(data, f)
|
||||||
|
|
||||||
|
def delete(self, key: str) -> bool:
|
||||||
|
"""Delete entry from file cache."""
|
||||||
|
path = self._get_path(key)
|
||||||
|
with self._lock:
|
||||||
|
if path.exists():
|
||||||
|
path.unlink()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def clear(self) -> int:
|
||||||
|
"""Clear all entries."""
|
||||||
|
with self._lock:
|
||||||
|
count = 0
|
||||||
|
for path in self._cache_dir.glob("*.json"):
|
||||||
|
path.unlink()
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
|
def keys(self) -> List[str]:
|
||||||
|
"""Get all cache keys (returns hashed keys)."""
|
||||||
|
return [p.stem for p in self._cache_dir.glob("*.json")]
|
||||||
|
|
||||||
|
def size(self) -> int:
|
||||||
|
"""Get number of entries."""
|
||||||
|
return len(list(self._cache_dir.glob("*.json")))
|
||||||
|
|
||||||
|
|
||||||
|
class DataCache:
|
||||||
|
"""Main data cache with rate limit awareness.
|
||||||
|
|
||||||
|
Provides caching for vendor data with configurable TTLs,
|
||||||
|
rate limit tracking, and stale-while-revalidate support.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
cache = DataCache()
|
||||||
|
|
||||||
|
# Cache with default TTL
|
||||||
|
cache.set("fred:FEDFUNDS", data, source="fred")
|
||||||
|
|
||||||
|
# Get with stale fallback if rate limited
|
||||||
|
result = cache.get("fred:FEDFUNDS", serve_stale_if_rate_limited=True)
|
||||||
|
|
||||||
|
# Use as decorator
|
||||||
|
@cache.cached(ttl_seconds=3600, source="fred")
|
||||||
|
def get_fred_data(series_id):
|
||||||
|
return fetch_from_api(series_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default TTLs by source (in seconds)
|
||||||
|
DEFAULT_TTLS = {
|
||||||
|
"fred": 3600 * 24, # 24 hours for FRED (data updates daily)
|
||||||
|
"yfinance": 60, # 1 minute for real-time quotes
|
||||||
|
"finnhub": 60, # 1 minute for real-time data
|
||||||
|
"polygon": 300, # 5 minutes
|
||||||
|
"alpha_vantage": 300, # 5 minutes
|
||||||
|
"default": 300 # 5 minutes default
|
||||||
|
}
|
||||||
|
|
||||||
|
# Default rate limits by source
|
||||||
|
DEFAULT_RATE_LIMITS = {
|
||||||
|
"fred": (120, 60), # 120 requests per 60 seconds
|
||||||
|
"yfinance": (2000, 60), # High limit (throttles internally)
|
||||||
|
"finnhub": (60, 60), # 60 per minute
|
||||||
|
"polygon": (5, 60), # 5 per minute (free tier)
|
||||||
|
"alpha_vantage": (5, 60), # 5 per minute (free tier)
|
||||||
|
"default": (100, 60)
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
backend: Optional[CacheBackend] = None,
|
||||||
|
default_ttl_seconds: int = 300
|
||||||
|
):
|
||||||
|
"""Initialize data cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend: Cache backend (defaults to MemoryCache)
|
||||||
|
default_ttl_seconds: Default TTL for entries
|
||||||
|
"""
|
||||||
|
self._backend = backend or MemoryCache()
|
||||||
|
self._default_ttl = default_ttl_seconds
|
||||||
|
self._stats = CacheStats()
|
||||||
|
self._rate_limits: Dict[str, RateLimitState] = {}
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def _generate_key(self, key: str, **kwargs) -> str:
|
||||||
|
"""Generate cache key from key and optional params."""
|
||||||
|
if kwargs:
|
||||||
|
params_str = json.dumps(kwargs, sort_keys=True)
|
||||||
|
return f"{key}:{hashlib.md5(params_str.encode()).hexdigest()[:8]}"
|
||||||
|
return key
|
||||||
|
|
||||||
|
def _get_ttl(self, source: str) -> int:
|
||||||
|
"""Get TTL for a source."""
|
||||||
|
return self.DEFAULT_TTLS.get(source, self.DEFAULT_TTLS["default"])
|
||||||
|
|
||||||
|
def _get_rate_limit_state(self, source: str) -> RateLimitState:
|
||||||
|
"""Get or create rate limit state for source."""
|
||||||
|
if source not in self._rate_limits:
|
||||||
|
limit, window = self.DEFAULT_RATE_LIMITS.get(
|
||||||
|
source,
|
||||||
|
self.DEFAULT_RATE_LIMITS["default"]
|
||||||
|
)
|
||||||
|
self._rate_limits[source] = RateLimitState(
|
||||||
|
source=source,
|
||||||
|
requests_limit=limit,
|
||||||
|
window_seconds=window
|
||||||
|
)
|
||||||
|
return self._rate_limits[source]
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
serve_stale_if_rate_limited: bool = True,
|
||||||
|
**kwargs
|
||||||
|
) -> tuple[Optional[Any], CacheStatus]:
|
||||||
|
"""Get value from cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
serve_stale_if_rate_limited: Return expired value if rate limited
|
||||||
|
**kwargs: Additional key params
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (value, status)
|
||||||
|
"""
|
||||||
|
full_key = self._generate_key(key, **kwargs)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
entry = self._backend.get(full_key)
|
||||||
|
|
||||||
|
if entry is None:
|
||||||
|
self._stats.misses += 1
|
||||||
|
return None, CacheStatus.MISS
|
||||||
|
|
||||||
|
if not entry.is_expired:
|
||||||
|
entry.touch()
|
||||||
|
self._backend.set(entry) # Update metadata
|
||||||
|
self._stats.hits += 1
|
||||||
|
return entry.value, CacheStatus.HIT
|
||||||
|
|
||||||
|
# Entry is expired
|
||||||
|
self._stats.expired += 1
|
||||||
|
|
||||||
|
# Check if we should serve stale
|
||||||
|
if serve_stale_if_rate_limited:
|
||||||
|
rate_state = self._get_rate_limit_state(entry.source)
|
||||||
|
if rate_state.is_rate_limited:
|
||||||
|
self._stats.stale_served += 1
|
||||||
|
return entry.value, CacheStatus.STALE
|
||||||
|
|
||||||
|
return None, CacheStatus.EXPIRED
|
||||||
|
|
||||||
|
def set(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: Any,
|
||||||
|
ttl_seconds: Optional[int] = None,
|
||||||
|
source: str = "default",
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
"""Set value in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
value: Value to cache
|
||||||
|
ttl_seconds: TTL in seconds (uses source default if not specified)
|
||||||
|
source: Data source name
|
||||||
|
metadata: Optional metadata
|
||||||
|
**kwargs: Additional key params
|
||||||
|
"""
|
||||||
|
full_key = self._generate_key(key, **kwargs)
|
||||||
|
actual_ttl = ttl_seconds if ttl_seconds is not None else self._get_ttl(source)
|
||||||
|
|
||||||
|
entry = CacheEntry(
|
||||||
|
key=full_key,
|
||||||
|
value=value,
|
||||||
|
created_at=datetime.now(),
|
||||||
|
expires_at=datetime.now() + timedelta(seconds=actual_ttl),
|
||||||
|
source=source,
|
||||||
|
metadata=metadata or {}
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
self._backend.set(entry)
|
||||||
|
self._stats.size = self._backend.size()
|
||||||
|
|
||||||
|
def delete(self, key: str, **kwargs) -> bool:
|
||||||
|
"""Delete value from cache."""
|
||||||
|
full_key = self._generate_key(key, **kwargs)
|
||||||
|
with self._lock:
|
||||||
|
result = self._backend.delete(full_key)
|
||||||
|
self._stats.size = self._backend.size()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def clear(self, source: Optional[str] = None) -> int:
|
||||||
|
"""Clear cache entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: Clear only entries from this source (None = all)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries cleared
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if source is None:
|
||||||
|
count = self._backend.clear()
|
||||||
|
else:
|
||||||
|
count = 0
|
||||||
|
for key in self._backend.keys():
|
||||||
|
entry = self._backend.get(key)
|
||||||
|
if entry and entry.source == source:
|
||||||
|
self._backend.delete(key)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
self._stats.evictions += count
|
||||||
|
self._stats.size = self._backend.size()
|
||||||
|
return count
|
||||||
|
|
||||||
|
def record_rate_limit(self, source: str, backoff_seconds: int = 60) -> None:
|
||||||
|
"""Record a rate limit hit for a source."""
|
||||||
|
with self._lock:
|
||||||
|
state = self._get_rate_limit_state(source)
|
||||||
|
state.record_rate_limit(backoff_seconds)
|
||||||
|
|
||||||
|
def record_request(self, source: str) -> None:
|
||||||
|
"""Record a request for rate limit tracking."""
|
||||||
|
with self._lock:
|
||||||
|
state = self._get_rate_limit_state(source)
|
||||||
|
state.record_request()
|
||||||
|
|
||||||
|
def record_success(self, source: str) -> None:
|
||||||
|
"""Record successful request."""
|
||||||
|
with self._lock:
|
||||||
|
state = self._get_rate_limit_state(source)
|
||||||
|
state.record_success()
|
||||||
|
|
||||||
|
def is_rate_limited(self, source: str) -> bool:
|
||||||
|
"""Check if source is rate limited."""
|
||||||
|
with self._lock:
|
||||||
|
state = self._get_rate_limit_state(source)
|
||||||
|
return state.is_rate_limited
|
||||||
|
|
||||||
|
def get_stats(self) -> CacheStats:
|
||||||
|
"""Get cache statistics."""
|
||||||
|
with self._lock:
|
||||||
|
self._stats.size = self._backend.size()
|
||||||
|
return self._stats
|
||||||
|
|
||||||
|
def cached(
|
||||||
|
self,
|
||||||
|
ttl_seconds: Optional[int] = None,
|
||||||
|
source: str = "default",
|
||||||
|
key_prefix: str = ""
|
||||||
|
) -> Callable:
|
||||||
|
"""Decorator for caching function results.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@cache.cached(ttl_seconds=3600, source="fred")
|
||||||
|
def get_fred_data(series_id):
|
||||||
|
return fetch_from_api(series_id)
|
||||||
|
"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# Generate cache key from function name and args
|
||||||
|
key_parts = [key_prefix, func.__name__]
|
||||||
|
if args:
|
||||||
|
key_parts.append(str(args))
|
||||||
|
if kwargs:
|
||||||
|
key_parts.append(json.dumps(kwargs, sort_keys=True))
|
||||||
|
cache_key = ":".join(filter(None, key_parts))
|
||||||
|
|
||||||
|
# Check cache
|
||||||
|
value, status = self.get(cache_key, serve_stale_if_rate_limited=True)
|
||||||
|
if status in (CacheStatus.HIT, CacheStatus.STALE):
|
||||||
|
return value
|
||||||
|
|
||||||
|
# Execute function
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Cache result
|
||||||
|
self.set(cache_key, result, ttl_seconds=ttl_seconds, source=source)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
# Global cache instance
|
||||||
|
_global_cache: Optional[DataCache] = None
|
||||||
|
_global_cache_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache() -> DataCache:
|
||||||
|
"""Get the global cache instance."""
|
||||||
|
global _global_cache
|
||||||
|
with _global_cache_lock:
|
||||||
|
if _global_cache is None:
|
||||||
|
_global_cache = DataCache()
|
||||||
|
return _global_cache
|
||||||
|
|
||||||
|
|
||||||
|
def reset_cache() -> None:
|
||||||
|
"""Reset the global cache instance. For testing."""
|
||||||
|
global _global_cache
|
||||||
|
with _global_cache_lock:
|
||||||
|
_global_cache = None
|
||||||
Loading…
Reference in New Issue