"""Tests for data caching layer.

Issue #12: [DATA-11] Data caching layer - FRED rate limits
"""
|
|
|
|
import pytest
|
|
import time
|
|
import threading
|
|
import tempfile
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch
|
|
|
|
from tradingagents.dataflows.cache import (
|
|
CacheEntry,
|
|
CacheStats,
|
|
CacheStatus,
|
|
RateLimitState,
|
|
MemoryCache,
|
|
FileCache,
|
|
DataCache,
|
|
get_cache,
|
|
reset_cache,
|
|
)
|
|
|
|
pytestmark = pytest.mark.unit
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_global_cache():
    """Give every test a pristine global cache, and clean up after it."""
    reset_cache()
    yield
    reset_cache()
|
|
|
|
|
|
class TestCacheEntry:
    """Tests for CacheEntry dataclass."""

    def test_entry_creation(self):
        """Test creating a cache entry."""
        now = datetime.now()
        entry = CacheEntry(
            key="test_key",
            value={"data": "test"},
            created_at=now,
            expires_at=now + timedelta(hours=1),
            source="fred",
        )

        assert entry.key == "test_key"
        assert entry.value == {"data": "test"}
        assert entry.source == "fred"
        assert entry.access_count == 0

    def test_is_expired_false(self):
        """Test is_expired returns False for valid entry."""
        now = datetime.now()
        entry = CacheEntry(
            key="test",
            value="data",
            created_at=now,
            expires_at=now + timedelta(hours=1),
        )
        assert entry.is_expired is False

    def test_is_expired_true(self):
        """Test is_expired returns True for expired entry."""
        now = datetime.now()
        entry = CacheEntry(
            key="test",
            value="data",
            created_at=now - timedelta(hours=2),
            expires_at=now - timedelta(hours=1),
        )
        assert entry.is_expired is True

    def test_age_seconds(self):
        """Test age_seconds calculation."""
        now = datetime.now()
        entry = CacheEntry(
            key="test",
            value="data",
            created_at=now - timedelta(seconds=60),
            expires_at=now + timedelta(hours=1),
        )
        # Allow a little slack for clock movement between calls.
        assert 59 < entry.age_seconds < 61

    def test_ttl_remaining(self):
        """Test ttl_remaining_seconds calculation."""
        now = datetime.now()
        entry = CacheEntry(
            key="test",
            value="data",
            created_at=now,
            expires_at=now + timedelta(seconds=3600),
        )
        assert 3599 < entry.ttl_remaining_seconds <= 3600

    def test_touch_updates_metadata(self):
        """Test touch updates access metadata."""
        now = datetime.now()
        entry = CacheEntry(
            key="test",
            value="data",
            created_at=now,
            expires_at=now + timedelta(hours=1),
        )

        assert entry.access_count == 0
        assert entry.last_accessed is None

        entry.touch()

        assert entry.access_count == 1
        assert entry.last_accessed is not None
|
|
|
|
|
|
class TestCacheStats:
    """Tests for CacheStats dataclass."""

    def test_default_values(self):
        """Test default values."""
        stats = CacheStats()
        assert stats.hits == 0
        assert stats.misses == 0
        assert stats.hit_rate == 0.0

    def test_hit_rate_calculation(self):
        """Test hit rate calculation."""
        # 75 hits out of 100 total requests -> 75%.
        assert CacheStats(hits=75, misses=25).hit_rate == 75.0

    def test_hit_rate_no_requests(self):
        """Test hit rate with no requests."""
        assert CacheStats().hit_rate == 0.0

    def test_to_dict(self):
        """Test conversion to dictionary."""
        payload = CacheStats(hits=10, misses=5, evictions=2).to_dict()

        assert payload["hits"] == 10
        assert payload["misses"] == 5
        assert payload["evictions"] == 2
        assert "hit_rate" in payload
|
|
|
|
|
|
class TestRateLimitState:
    """Tests for RateLimitState dataclass."""

    def test_default_values(self):
        """Test default values."""
        state = RateLimitState(source="fred")
        assert state.source == "fred"
        assert state.requests_made == 0
        assert state.is_rate_limited is False

    def test_record_request(self):
        """Test recording requests."""
        state = RateLimitState(source="test", requests_limit=5)

        for _ in range(3):
            state.record_request()

        assert state.requests_made == 3
        assert state.requests_remaining == 2

    def test_is_rate_limited_after_backoff(self):
        """Test rate limiting after recording limit hit."""
        state = RateLimitState(source="test")
        assert state.is_rate_limited is False

        state.record_rate_limit(backoff_seconds=1)
        assert state.is_rate_limited is True

        # Once the backoff window elapses the limit lifts on its own.
        time.sleep(1.1)
        assert state.is_rate_limited is False

    def test_record_success_clears_backoff(self):
        """Test that success clears backoff."""
        state = RateLimitState(source="test")
        state.record_rate_limit(backoff_seconds=60)
        assert state.is_rate_limited is True

        state.record_success()
        assert state.is_rate_limited is False
        assert state.consecutive_failures == 0

    def test_exponential_backoff(self):
        """Test exponential backoff on consecutive failures."""
        state = RateLimitState(source="test")

        # First failure.
        state.record_rate_limit(backoff_seconds=1)
        assert state.consecutive_failures == 1

        # Simulate the backoff window passing without a success.
        state.backoff_until = None

        # Second failure increments the consecutive-failure counter.
        state.record_rate_limit(backoff_seconds=1)
        assert state.consecutive_failures == 2
|
|
|
|
|
|
class TestMemoryCache:
    """Tests for MemoryCache backend."""

    def test_get_set(self):
        """Test basic get/set operations."""
        cache = MemoryCache()
        now = datetime.now()

        cache.set(CacheEntry(
            key="test",
            value="data",
            created_at=now,
            expires_at=now + timedelta(hours=1),
        ))

        fetched = cache.get("test")
        assert fetched is not None
        assert fetched.value == "data"

    def test_get_missing(self):
        """Test getting missing key."""
        assert MemoryCache().get("nonexistent") is None

    def test_delete(self):
        """Test deleting entry."""
        cache = MemoryCache()
        now = datetime.now()

        cache.set(CacheEntry(
            key="test",
            value="data",
            created_at=now,
            expires_at=now + timedelta(hours=1),
        ))

        assert cache.delete("test") is True
        assert cache.get("test") is None

    def test_delete_missing(self):
        """Test deleting missing key."""
        assert MemoryCache().delete("nonexistent") is False

    def test_clear(self):
        """Test clearing cache."""
        cache = MemoryCache()
        now = datetime.now()

        for idx in range(5):
            cache.set(CacheEntry(
                key=f"key_{idx}",
                value=f"value_{idx}",
                created_at=now,
                expires_at=now + timedelta(hours=1),
            ))

        assert cache.clear() == 5
        assert cache.size() == 0

    def test_lru_eviction(self):
        """Test LRU eviction when at capacity."""
        cache = MemoryCache(max_size=3)
        now = datetime.now()

        # Fill the cache to capacity.
        for idx in range(3):
            cache.set(CacheEntry(
                key=f"key_{idx}",
                value=f"value_{idx}",
                created_at=now,
                expires_at=now + timedelta(hours=1),
            ))

        # Touch key_1 so key_0 becomes the least recently used entry.
        cache.get("key_1")

        # Inserting a fourth entry should evict key_0.
        cache.set(CacheEntry(
            key="key_3",
            value="value_3",
            created_at=now,
            expires_at=now + timedelta(hours=1),
        ))

        assert cache.size() == 3
        assert cache.get("key_0") is None
        assert cache.get("key_1") is not None

    def test_thread_safety(self):
        """Test thread-safe operations."""
        cache = MemoryCache()
        errors = []
        now = datetime.now()

        def writer(prefix):
            # Each worker writes 100 distinct keys; any exception is recorded.
            try:
                for idx in range(100):
                    cache.set(CacheEntry(
                        key=f"key_{prefix}_{idx}",
                        value=f"value_{prefix}_{idx}",
                        created_at=now,
                        expires_at=now + timedelta(hours=1),
                    ))
            except Exception as exc:
                errors.append(exc)

        workers = [threading.Thread(target=writer, args=(p,)) for p in range(5)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        assert not errors
        assert cache.size() == 500
|
|
|
|
|
|
class TestFileCache:
    """Tests for FileCache backend."""

    def test_get_set(self):
        """Test basic get/set operations."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = FileCache(cache_dir=Path(tmpdir))
            now = datetime.now()

            cache.set(CacheEntry(
                key="test",
                value={"data": "test"},
                created_at=now,
                expires_at=now + timedelta(hours=1),
                source="test",
            ))

            fetched = cache.get("test")
            assert fetched is not None
            assert fetched.value == {"data": "test"}
            assert fetched.source == "test"

    def test_get_missing(self):
        """Test getting missing key."""
        with tempfile.TemporaryDirectory() as tmpdir:
            assert FileCache(cache_dir=Path(tmpdir)).get("nonexistent") is None

    def test_delete(self):
        """Test deleting entry."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = FileCache(cache_dir=Path(tmpdir))
            now = datetime.now()

            cache.set(CacheEntry(
                key="test",
                value="data",
                created_at=now,
                expires_at=now + timedelta(hours=1),
            ))

            assert cache.delete("test") is True
            assert cache.get("test") is None

    def test_clear(self):
        """Test clearing cache."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = FileCache(cache_dir=Path(tmpdir))
            now = datetime.now()

            for idx in range(3):
                cache.set(CacheEntry(
                    key=f"key_{idx}",
                    value=f"value_{idx}",
                    created_at=now,
                    expires_at=now + timedelta(hours=1),
                ))

            assert cache.clear() == 3
            assert cache.size() == 0
|
|
|
|
|
|
class TestDataCache:
    """Tests for DataCache main class."""

    def test_get_set_basic(self):
        """Test basic get/set operations."""
        cache = DataCache()

        cache.set("test_key", {"data": "test"}, source="fred")
        value, status = cache.get("test_key")

        assert status == CacheStatus.HIT
        assert value == {"data": "test"}

    def test_get_miss(self):
        """Test cache miss."""
        value, status = DataCache().get("nonexistent")

        assert status == CacheStatus.MISS
        assert value is None

    def test_get_expired(self):
        """Test getting expired entry."""
        cache = DataCache()

        # Zero TTL makes the entry expire immediately.
        cache.set("test", "data", ttl_seconds=0, source="test")
        time.sleep(0.1)

        value, status = cache.get("test", serve_stale_if_rate_limited=False)

        assert status == CacheStatus.EXPIRED
        assert value is None

    def test_serve_stale_when_rate_limited(self):
        """Test serving stale data when rate limited."""
        cache = DataCache()

        # Entry that expires right away.
        cache.set("test", "stale_data", ttl_seconds=0, source="test")
        time.sleep(0.1)

        # With the source rate limited, the expired entry is served stale.
        cache.record_rate_limit("test", backoff_seconds=60)
        value, status = cache.get("test", serve_stale_if_rate_limited=True)

        assert status == CacheStatus.STALE
        assert value == "stale_data"

    def test_delete(self):
        """Test deleting entry."""
        cache = DataCache()

        cache.set("test", "data", source="test")
        assert cache.delete("test") is True

        _, status = cache.get("test")
        assert status == CacheStatus.MISS

    def test_clear_all(self):
        """Test clearing all entries."""
        cache = DataCache()

        cache.set("key1", "value1", source="fred")
        cache.set("key2", "value2", source="yfinance")

        assert cache.clear() == 2

    def test_clear_by_source(self):
        """Test clearing entries by source."""
        cache = DataCache()

        cache.set("fred_key", "fred_data", source="fred")
        cache.set("yf_key", "yf_data", source="yfinance")

        assert cache.clear(source="fred") == 1

        # The yfinance entry must survive a fred-only clear.
        _, status = cache.get("yf_key")
        assert status == CacheStatus.HIT

    def test_key_with_params(self):
        """Test key generation with params."""
        cache = DataCache()

        cache.set("series", "data1", source="fred", series_id="FEDFUNDS")
        cache.set("series", "data2", source="fred", series_id="DGS10")

        first, _ = cache.get("series", series_id="FEDFUNDS")
        second, _ = cache.get("series", series_id="DGS10")

        assert first == "data1"
        assert second == "data2"

    def test_stats_tracking(self):
        """Test statistics tracking."""
        cache = DataCache()

        # One miss.
        cache.get("missing")

        # Two hits.
        cache.set("present", "data", source="test")
        cache.get("present")
        cache.get("present")

        stats = cache.get_stats()
        assert stats.misses == 1
        assert stats.hits == 2

    def test_rate_limit_tracking(self):
        """Test rate limit state tracking."""
        cache = DataCache()

        assert cache.is_rate_limited("fred") is False

        cache.record_rate_limit("fred", backoff_seconds=1)
        assert cache.is_rate_limited("fred") is True

        time.sleep(1.1)
        assert cache.is_rate_limited("fred") is False

    def test_cached_decorator(self):
        """Test @cached decorator."""
        cache = DataCache()
        calls = {"count": 0}

        @cache.cached(ttl_seconds=300, source="test")
        def expensive_function(x):
            calls["count"] += 1
            return x * 2

        # First call executes the wrapped function.
        assert expensive_function(5) == 10
        assert calls["count"] == 1

        # Identical call is served from cache.
        assert expensive_function(5) == 10
        assert calls["count"] == 1

        # Different argument misses and executes again.
        assert expensive_function(10) == 20
        assert calls["count"] == 2

    def test_default_ttls_by_source(self):
        """Test default TTLs are applied by source."""
        cache = DataCache()

        # FRED default is 24 hours.
        cache.set("fred_data", "data", source="fred")
        entry = cache._backend.get(cache._generate_key("fred_data"))

        # Remaining TTL should be close to the full 24 hours.
        assert entry.ttl_remaining_seconds > 3600 * 23
|
|
|
|
|
|
class TestGlobalCache:
    """Tests for global cache functions."""

    def test_get_cache_singleton(self):
        """Test get_cache returns singleton."""
        first = get_cache()
        second = get_cache()
        assert first is second

    def test_reset_cache(self):
        """Test reset_cache creates new instance."""
        before = get_cache()
        reset_cache()
        after = get_cache()
        assert before is not after
|
|
|
|
|
|
class TestCacheIntegration:
    """Integration tests for cache with rate limiting."""

    def test_rate_limited_fetch_pattern(self):
        """Test typical pattern: cache + rate limit handling."""
        cache = DataCache()
        fetches = {"count": 0}

        def fetch_data(key):
            """Simulate data fetch with rate limit."""
            # While backed off, fall back to stale cached data if available.
            if cache.is_rate_limited("api"):
                value, status = cache.get(key, serve_stale_if_rate_limited=True)
                if status == CacheStatus.STALE:
                    return value
                raise RuntimeError("Rate limited and no stale data")

            # Fresh cache hit short-circuits the API call.
            value, status = cache.get(key)
            if status == CacheStatus.HIT:
                return value

            # Otherwise hit the (simulated) API and cache the result.
            fetches["count"] += 1
            cache.record_request("api")

            data = f"data_for_{key}"
            cache.set(key, data, source="api", ttl_seconds=1)
            cache.record_success("api")
            return data

        # First fetch goes to the API.
        assert fetch_data("key1") == "data_for_key1"
        assert fetches["count"] == 1

        # Second fetch is served from cache -- no extra API call.
        assert fetch_data("key1") == "data_for_key1"
        assert fetches["count"] == 1

        # Let the entry expire, then simulate a rate limit.
        time.sleep(1.1)
        cache.record_rate_limit("api", backoff_seconds=60)

        # Expired + rate limited: the stale value is returned, still no fetch.
        assert fetch_data("key1") == "data_for_key1"
        assert fetches["count"] == 1