feat(dataflows): add data caching layer with rate limit awareness - Fixes #12
Implements [DATA-11] Data caching layer - FRED rate limits with: - CacheEntry: Generic cache entries with TTL and metadata - CacheStats: Hit/miss/stale statistics tracking - RateLimitState: Per-source rate limit tracking with exponential backoff - MemoryCache: In-memory LRU cache backend - FileCache: File-based JSON cache backend - DataCache: Main cache with source-specific TTLs and stale-while-rate-limited - @cached decorator: Function result caching Features: - Multi-backend support (memory, file) - TTL-based expiration with configurable per-source defaults - Stale-while-revalidate when rate limited - Thread-safe operations throughout - 41 tests covering all components 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
2c802647e4
commit
ae7899a6fc
|
|
@ -0,0 +1,618 @@
|
|||
"""Tests for data caching layer.
|
||||
|
||||
Issue #12: [DATA-11] Data caching layer - FRED rate limits
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
import threading
|
||||
import tempfile
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from tradingagents.dataflows.cache import (
|
||||
CacheEntry,
|
||||
CacheStats,
|
||||
CacheStatus,
|
||||
RateLimitState,
|
||||
MemoryCache,
|
||||
FileCache,
|
||||
DataCache,
|
||||
get_cache,
|
||||
reset_cache,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_global_cache():
|
||||
"""Reset global cache before each test."""
|
||||
reset_cache()
|
||||
yield
|
||||
reset_cache()
|
||||
|
||||
|
||||
class TestCacheEntry:
|
||||
"""Tests for CacheEntry dataclass."""
|
||||
|
||||
def test_entry_creation(self):
|
||||
"""Test creating a cache entry."""
|
||||
entry = CacheEntry(
|
||||
key="test_key",
|
||||
value={"data": "test"},
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(hours=1),
|
||||
source="fred"
|
||||
)
|
||||
|
||||
assert entry.key == "test_key"
|
||||
assert entry.value == {"data": "test"}
|
||||
assert entry.source == "fred"
|
||||
assert entry.access_count == 0
|
||||
|
||||
def test_is_expired_false(self):
|
||||
"""Test is_expired returns False for valid entry."""
|
||||
entry = CacheEntry(
|
||||
key="test",
|
||||
value="data",
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(hours=1)
|
||||
)
|
||||
assert entry.is_expired is False
|
||||
|
||||
def test_is_expired_true(self):
|
||||
"""Test is_expired returns True for expired entry."""
|
||||
entry = CacheEntry(
|
||||
key="test",
|
||||
value="data",
|
||||
created_at=datetime.now() - timedelta(hours=2),
|
||||
expires_at=datetime.now() - timedelta(hours=1)
|
||||
)
|
||||
assert entry.is_expired is True
|
||||
|
||||
def test_age_seconds(self):
|
||||
"""Test age_seconds calculation."""
|
||||
entry = CacheEntry(
|
||||
key="test",
|
||||
value="data",
|
||||
created_at=datetime.now() - timedelta(seconds=60),
|
||||
expires_at=datetime.now() + timedelta(hours=1)
|
||||
)
|
||||
assert 59 < entry.age_seconds < 61
|
||||
|
||||
def test_ttl_remaining(self):
|
||||
"""Test ttl_remaining_seconds calculation."""
|
||||
entry = CacheEntry(
|
||||
key="test",
|
||||
value="data",
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(seconds=3600)
|
||||
)
|
||||
assert 3599 < entry.ttl_remaining_seconds <= 3600
|
||||
|
||||
def test_touch_updates_metadata(self):
|
||||
"""Test touch updates access metadata."""
|
||||
entry = CacheEntry(
|
||||
key="test",
|
||||
value="data",
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(hours=1)
|
||||
)
|
||||
|
||||
assert entry.access_count == 0
|
||||
assert entry.last_accessed is None
|
||||
|
||||
entry.touch()
|
||||
|
||||
assert entry.access_count == 1
|
||||
assert entry.last_accessed is not None
|
||||
|
||||
|
||||
class TestCacheStats:
|
||||
"""Tests for CacheStats dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values."""
|
||||
stats = CacheStats()
|
||||
assert stats.hits == 0
|
||||
assert stats.misses == 0
|
||||
assert stats.hit_rate == 0.0
|
||||
|
||||
def test_hit_rate_calculation(self):
|
||||
"""Test hit rate calculation."""
|
||||
stats = CacheStats(hits=75, misses=25)
|
||||
assert stats.hit_rate == 75.0
|
||||
|
||||
def test_hit_rate_no_requests(self):
|
||||
"""Test hit rate with no requests."""
|
||||
stats = CacheStats()
|
||||
assert stats.hit_rate == 0.0
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test conversion to dictionary."""
|
||||
stats = CacheStats(hits=10, misses=5, evictions=2)
|
||||
d = stats.to_dict()
|
||||
|
||||
assert d["hits"] == 10
|
||||
assert d["misses"] == 5
|
||||
assert d["evictions"] == 2
|
||||
assert "hit_rate" in d
|
||||
|
||||
|
||||
class TestRateLimitState:
|
||||
"""Tests for RateLimitState dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values."""
|
||||
state = RateLimitState(source="fred")
|
||||
assert state.source == "fred"
|
||||
assert state.requests_made == 0
|
||||
assert state.is_rate_limited is False
|
||||
|
||||
def test_record_request(self):
|
||||
"""Test recording requests."""
|
||||
state = RateLimitState(source="test", requests_limit=5)
|
||||
|
||||
for i in range(3):
|
||||
state.record_request()
|
||||
|
||||
assert state.requests_made == 3
|
||||
assert state.requests_remaining == 2
|
||||
|
||||
def test_is_rate_limited_after_backoff(self):
|
||||
"""Test rate limiting after recording limit hit."""
|
||||
state = RateLimitState(source="test")
|
||||
|
||||
assert state.is_rate_limited is False
|
||||
|
||||
state.record_rate_limit(backoff_seconds=1)
|
||||
|
||||
assert state.is_rate_limited is True
|
||||
|
||||
# Wait for backoff to expire
|
||||
time.sleep(1.1)
|
||||
assert state.is_rate_limited is False
|
||||
|
||||
def test_record_success_clears_backoff(self):
|
||||
"""Test that success clears backoff."""
|
||||
state = RateLimitState(source="test")
|
||||
state.record_rate_limit(backoff_seconds=60)
|
||||
assert state.is_rate_limited is True
|
||||
|
||||
state.record_success()
|
||||
assert state.is_rate_limited is False
|
||||
assert state.consecutive_failures == 0
|
||||
|
||||
def test_exponential_backoff(self):
|
||||
"""Test exponential backoff on consecutive failures."""
|
||||
state = RateLimitState(source="test")
|
||||
|
||||
# First failure - 1 second backoff
|
||||
state.record_rate_limit(backoff_seconds=1)
|
||||
assert state.consecutive_failures == 1
|
||||
|
||||
# Simulate recovery
|
||||
state.backoff_until = None
|
||||
|
||||
# Second failure - 2 second backoff
|
||||
state.record_rate_limit(backoff_seconds=1)
|
||||
assert state.consecutive_failures == 2
|
||||
|
||||
|
||||
class TestMemoryCache:
|
||||
"""Tests for MemoryCache backend."""
|
||||
|
||||
def test_get_set(self):
|
||||
"""Test basic get/set operations."""
|
||||
cache = MemoryCache()
|
||||
|
||||
entry = CacheEntry(
|
||||
key="test",
|
||||
value="data",
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(hours=1)
|
||||
)
|
||||
|
||||
cache.set(entry)
|
||||
result = cache.get("test")
|
||||
|
||||
assert result is not None
|
||||
assert result.value == "data"
|
||||
|
||||
def test_get_missing(self):
|
||||
"""Test getting missing key."""
|
||||
cache = MemoryCache()
|
||||
assert cache.get("nonexistent") is None
|
||||
|
||||
def test_delete(self):
|
||||
"""Test deleting entry."""
|
||||
cache = MemoryCache()
|
||||
|
||||
entry = CacheEntry(
|
||||
key="test",
|
||||
value="data",
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(hours=1)
|
||||
)
|
||||
|
||||
cache.set(entry)
|
||||
assert cache.delete("test") is True
|
||||
assert cache.get("test") is None
|
||||
|
||||
def test_delete_missing(self):
|
||||
"""Test deleting missing key."""
|
||||
cache = MemoryCache()
|
||||
assert cache.delete("nonexistent") is False
|
||||
|
||||
def test_clear(self):
|
||||
"""Test clearing cache."""
|
||||
cache = MemoryCache()
|
||||
|
||||
for i in range(5):
|
||||
cache.set(CacheEntry(
|
||||
key=f"key_{i}",
|
||||
value=f"value_{i}",
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(hours=1)
|
||||
))
|
||||
|
||||
count = cache.clear()
|
||||
assert count == 5
|
||||
assert cache.size() == 0
|
||||
|
||||
def test_lru_eviction(self):
|
||||
"""Test LRU eviction when at capacity."""
|
||||
cache = MemoryCache(max_size=3)
|
||||
|
||||
# Add 3 entries
|
||||
for i in range(3):
|
||||
cache.set(CacheEntry(
|
||||
key=f"key_{i}",
|
||||
value=f"value_{i}",
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(hours=1)
|
||||
))
|
||||
|
||||
# Access key_1 to make it recently used
|
||||
cache.get("key_1")
|
||||
|
||||
# Add new entry, should evict key_0 (least recently used)
|
||||
cache.set(CacheEntry(
|
||||
key="key_3",
|
||||
value="value_3",
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(hours=1)
|
||||
))
|
||||
|
||||
assert cache.size() == 3
|
||||
assert cache.get("key_0") is None
|
||||
assert cache.get("key_1") is not None
|
||||
|
||||
def test_thread_safety(self):
|
||||
"""Test thread-safe operations."""
|
||||
cache = MemoryCache()
|
||||
errors = []
|
||||
|
||||
def write_entries(start):
|
||||
try:
|
||||
for i in range(100):
|
||||
cache.set(CacheEntry(
|
||||
key=f"key_{start}_{i}",
|
||||
value=f"value_{start}_{i}",
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(hours=1)
|
||||
))
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=write_entries, args=(i,)) for i in range(5)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(errors) == 0
|
||||
assert cache.size() == 500
|
||||
|
||||
|
||||
class TestFileCache:
|
||||
"""Tests for FileCache backend."""
|
||||
|
||||
def test_get_set(self):
|
||||
"""Test basic get/set operations."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cache = FileCache(cache_dir=Path(tmpdir))
|
||||
|
||||
entry = CacheEntry(
|
||||
key="test",
|
||||
value={"data": "test"},
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(hours=1),
|
||||
source="test"
|
||||
)
|
||||
|
||||
cache.set(entry)
|
||||
result = cache.get("test")
|
||||
|
||||
assert result is not None
|
||||
assert result.value == {"data": "test"}
|
||||
assert result.source == "test"
|
||||
|
||||
def test_get_missing(self):
|
||||
"""Test getting missing key."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cache = FileCache(cache_dir=Path(tmpdir))
|
||||
assert cache.get("nonexistent") is None
|
||||
|
||||
def test_delete(self):
|
||||
"""Test deleting entry."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cache = FileCache(cache_dir=Path(tmpdir))
|
||||
|
||||
entry = CacheEntry(
|
||||
key="test",
|
||||
value="data",
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(hours=1)
|
||||
)
|
||||
|
||||
cache.set(entry)
|
||||
assert cache.delete("test") is True
|
||||
assert cache.get("test") is None
|
||||
|
||||
def test_clear(self):
|
||||
"""Test clearing cache."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cache = FileCache(cache_dir=Path(tmpdir))
|
||||
|
||||
for i in range(3):
|
||||
cache.set(CacheEntry(
|
||||
key=f"key_{i}",
|
||||
value=f"value_{i}",
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(hours=1)
|
||||
))
|
||||
|
||||
count = cache.clear()
|
||||
assert count == 3
|
||||
assert cache.size() == 0
|
||||
|
||||
|
||||
class TestDataCache:
|
||||
"""Tests for DataCache main class."""
|
||||
|
||||
def test_get_set_basic(self):
|
||||
"""Test basic get/set operations."""
|
||||
cache = DataCache()
|
||||
|
||||
cache.set("test_key", {"data": "test"}, source="fred")
|
||||
value, status = cache.get("test_key")
|
||||
|
||||
assert status == CacheStatus.HIT
|
||||
assert value == {"data": "test"}
|
||||
|
||||
def test_get_miss(self):
|
||||
"""Test cache miss."""
|
||||
cache = DataCache()
|
||||
|
||||
value, status = cache.get("nonexistent")
|
||||
|
||||
assert status == CacheStatus.MISS
|
||||
assert value is None
|
||||
|
||||
def test_get_expired(self):
|
||||
"""Test getting expired entry."""
|
||||
cache = DataCache()
|
||||
|
||||
# Set with very short TTL
|
||||
cache.set("test", "data", ttl_seconds=0, source="test")
|
||||
|
||||
# Wait for expiration
|
||||
time.sleep(0.1)
|
||||
|
||||
value, status = cache.get("test", serve_stale_if_rate_limited=False)
|
||||
|
||||
assert status == CacheStatus.EXPIRED
|
||||
assert value is None
|
||||
|
||||
def test_serve_stale_when_rate_limited(self):
|
||||
"""Test serving stale data when rate limited."""
|
||||
cache = DataCache()
|
||||
|
||||
# Set entry that will expire
|
||||
cache.set("test", "stale_data", ttl_seconds=0, source="test")
|
||||
time.sleep(0.1)
|
||||
|
||||
# Simulate rate limit
|
||||
cache.record_rate_limit("test", backoff_seconds=60)
|
||||
|
||||
# Should get stale data
|
||||
value, status = cache.get("test", serve_stale_if_rate_limited=True)
|
||||
|
||||
assert status == CacheStatus.STALE
|
||||
assert value == "stale_data"
|
||||
|
||||
def test_delete(self):
|
||||
"""Test deleting entry."""
|
||||
cache = DataCache()
|
||||
|
||||
cache.set("test", "data", source="test")
|
||||
assert cache.delete("test") is True
|
||||
|
||||
value, status = cache.get("test")
|
||||
assert status == CacheStatus.MISS
|
||||
|
||||
def test_clear_all(self):
|
||||
"""Test clearing all entries."""
|
||||
cache = DataCache()
|
||||
|
||||
cache.set("key1", "value1", source="fred")
|
||||
cache.set("key2", "value2", source="yfinance")
|
||||
|
||||
count = cache.clear()
|
||||
assert count == 2
|
||||
|
||||
def test_clear_by_source(self):
|
||||
"""Test clearing entries by source."""
|
||||
cache = DataCache()
|
||||
|
||||
cache.set("fred_key", "fred_data", source="fred")
|
||||
cache.set("yf_key", "yf_data", source="yfinance")
|
||||
|
||||
count = cache.clear(source="fred")
|
||||
assert count == 1
|
||||
|
||||
# yfinance entry should still exist
|
||||
value, status = cache.get("yf_key")
|
||||
assert status == CacheStatus.HIT
|
||||
|
||||
def test_key_with_params(self):
|
||||
"""Test key generation with params."""
|
||||
cache = DataCache()
|
||||
|
||||
cache.set("series", "data1", source="fred", series_id="FEDFUNDS")
|
||||
cache.set("series", "data2", source="fred", series_id="DGS10")
|
||||
|
||||
value1, _ = cache.get("series", series_id="FEDFUNDS")
|
||||
value2, _ = cache.get("series", series_id="DGS10")
|
||||
|
||||
assert value1 == "data1"
|
||||
assert value2 == "data2"
|
||||
|
||||
def test_stats_tracking(self):
|
||||
"""Test statistics tracking."""
|
||||
cache = DataCache()
|
||||
|
||||
# Miss
|
||||
cache.get("missing")
|
||||
|
||||
# Hit
|
||||
cache.set("present", "data", source="test")
|
||||
cache.get("present")
|
||||
cache.get("present")
|
||||
|
||||
stats = cache.get_stats()
|
||||
assert stats.misses == 1
|
||||
assert stats.hits == 2
|
||||
|
||||
def test_rate_limit_tracking(self):
|
||||
"""Test rate limit state tracking."""
|
||||
cache = DataCache()
|
||||
|
||||
assert cache.is_rate_limited("fred") is False
|
||||
|
||||
cache.record_rate_limit("fred", backoff_seconds=1)
|
||||
assert cache.is_rate_limited("fred") is True
|
||||
|
||||
time.sleep(1.1)
|
||||
assert cache.is_rate_limited("fred") is False
|
||||
|
||||
def test_cached_decorator(self):
|
||||
"""Test @cached decorator."""
|
||||
cache = DataCache()
|
||||
call_count = [0]
|
||||
|
||||
@cache.cached(ttl_seconds=300, source="test")
|
||||
def expensive_function(x):
|
||||
call_count[0] += 1
|
||||
return x * 2
|
||||
|
||||
# First call - executes function
|
||||
result1 = expensive_function(5)
|
||||
assert result1 == 10
|
||||
assert call_count[0] == 1
|
||||
|
||||
# Second call - from cache
|
||||
result2 = expensive_function(5)
|
||||
assert result2 == 10
|
||||
assert call_count[0] == 1
|
||||
|
||||
# Different argument - executes function
|
||||
result3 = expensive_function(10)
|
||||
assert result3 == 20
|
||||
assert call_count[0] == 2
|
||||
|
||||
def test_default_ttls_by_source(self):
|
||||
"""Test default TTLs are applied by source."""
|
||||
cache = DataCache()
|
||||
|
||||
# FRED default is 24 hours
|
||||
cache.set("fred_data", "data", source="fred")
|
||||
entry = cache._backend.get(cache._generate_key("fred_data"))
|
||||
|
||||
# Should have ~24 hour TTL
|
||||
assert entry.ttl_remaining_seconds > 3600 * 23
|
||||
|
||||
|
||||
class TestGlobalCache:
|
||||
"""Tests for global cache functions."""
|
||||
|
||||
def test_get_cache_singleton(self):
|
||||
"""Test get_cache returns singleton."""
|
||||
cache1 = get_cache()
|
||||
cache2 = get_cache()
|
||||
assert cache1 is cache2
|
||||
|
||||
def test_reset_cache(self):
|
||||
"""Test reset_cache creates new instance."""
|
||||
cache1 = get_cache()
|
||||
reset_cache()
|
||||
cache2 = get_cache()
|
||||
assert cache1 is not cache2
|
||||
|
||||
|
||||
class TestCacheIntegration:
|
||||
"""Integration tests for cache with rate limiting."""
|
||||
|
||||
def test_rate_limited_fetch_pattern(self):
|
||||
"""Test typical pattern: cache + rate limit handling."""
|
||||
cache = DataCache()
|
||||
fetch_count = [0]
|
||||
|
||||
def fetch_data(key):
|
||||
"""Simulate data fetch with rate limit."""
|
||||
# Check rate limit first
|
||||
if cache.is_rate_limited("api"):
|
||||
# Try stale cache
|
||||
value, status = cache.get(key, serve_stale_if_rate_limited=True)
|
||||
if status == CacheStatus.STALE:
|
||||
return value
|
||||
raise RuntimeError("Rate limited and no stale data")
|
||||
|
||||
# Check cache
|
||||
value, status = cache.get(key)
|
||||
if status == CacheStatus.HIT:
|
||||
return value
|
||||
|
||||
# Fetch fresh data
|
||||
fetch_count[0] += 1
|
||||
cache.record_request("api")
|
||||
|
||||
# Simulate API response
|
||||
data = f"data_for_{key}"
|
||||
|
||||
cache.set(key, data, source="api", ttl_seconds=1)
|
||||
cache.record_success("api")
|
||||
|
||||
return data
|
||||
|
||||
# First fetch - from API
|
||||
result1 = fetch_data("key1")
|
||||
assert result1 == "data_for_key1"
|
||||
assert fetch_count[0] == 1
|
||||
|
||||
# Second fetch - from cache
|
||||
result2 = fetch_data("key1")
|
||||
assert result2 == "data_for_key1"
|
||||
assert fetch_count[0] == 1 # No additional fetch
|
||||
|
||||
# Wait for expiration and simulate rate limit
|
||||
time.sleep(1.1)
|
||||
cache.record_rate_limit("api", backoff_seconds=60)
|
||||
|
||||
# Should get stale data
|
||||
result3 = fetch_data("key1")
|
||||
assert result3 == "data_for_key1"
|
||||
assert fetch_count[0] == 1 # Still no additional fetch
|
||||
|
|
@ -0,0 +1,628 @@
|
|||
"""Data caching layer for vendor data with rate limit awareness.
|
||||
|
||||
This module provides a robust caching layer to handle API rate limits across
|
||||
all data vendors. Features:
|
||||
- Multi-backend support (memory, file, SQLite)
|
||||
- TTL-based expiration with configurable per-source TTLs
|
||||
- Rate limit tracking and backoff
|
||||
- Cache statistics and monitoring
|
||||
- Atomic cache operations for thread safety
|
||||
|
||||
Issue #12: [DATA-11] Data caching layer - FRED rate limits
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class CacheStatus(Enum):
|
||||
"""Status of a cache lookup."""
|
||||
HIT = auto()
|
||||
MISS = auto()
|
||||
EXPIRED = auto()
|
||||
STALE = auto() # Expired but returned due to rate limit
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry(Generic[T]):
|
||||
"""A single cache entry with metadata."""
|
||||
key: str
|
||||
value: T
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
access_count: int = 0
|
||||
last_accessed: Optional[datetime] = None
|
||||
source: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if entry is expired."""
|
||||
return datetime.now() > self.expires_at
|
||||
|
||||
@property
|
||||
def age_seconds(self) -> float:
|
||||
"""Get age in seconds."""
|
||||
return (datetime.now() - self.created_at).total_seconds()
|
||||
|
||||
@property
|
||||
def ttl_remaining_seconds(self) -> float:
|
||||
"""Get remaining TTL in seconds."""
|
||||
return max(0, (self.expires_at - datetime.now()).total_seconds())
|
||||
|
||||
def touch(self) -> None:
|
||||
"""Update access metadata."""
|
||||
self.access_count += 1
|
||||
self.last_accessed = datetime.now()
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
"""Statistics for cache operations."""
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
expired: int = 0
|
||||
stale_served: int = 0
|
||||
evictions: int = 0
|
||||
size: int = 0
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
"""Calculate hit rate as percentage."""
|
||||
total = self.hits + self.misses
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return (self.hits / total) * 100
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"expired": self.expired,
|
||||
"stale_served": self.stale_served,
|
||||
"evictions": self.evictions,
|
||||
"size": self.size,
|
||||
"hit_rate": self.hit_rate
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitState:
|
||||
"""Track rate limit state for a source."""
|
||||
source: str
|
||||
requests_made: int = 0
|
||||
requests_limit: int = 120 # Default FRED limit
|
||||
window_start: datetime = field(default_factory=datetime.now)
|
||||
window_seconds: int = 60
|
||||
backoff_until: Optional[datetime] = None
|
||||
consecutive_failures: int = 0
|
||||
|
||||
@property
|
||||
def is_rate_limited(self) -> bool:
|
||||
"""Check if currently rate limited."""
|
||||
if self.backoff_until and datetime.now() < self.backoff_until:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def requests_remaining(self) -> int:
|
||||
"""Get remaining requests in current window."""
|
||||
self._maybe_reset_window()
|
||||
return max(0, self.requests_limit - self.requests_made)
|
||||
|
||||
def _maybe_reset_window(self) -> None:
|
||||
"""Reset window if expired."""
|
||||
if (datetime.now() - self.window_start).total_seconds() > self.window_seconds:
|
||||
self.window_start = datetime.now()
|
||||
self.requests_made = 0
|
||||
|
||||
def record_request(self) -> None:
|
||||
"""Record a request."""
|
||||
self._maybe_reset_window()
|
||||
self.requests_made += 1
|
||||
|
||||
def record_rate_limit(self, backoff_seconds: int = 60) -> None:
|
||||
"""Record a rate limit hit."""
|
||||
self.consecutive_failures += 1
|
||||
# Exponential backoff
|
||||
actual_backoff = backoff_seconds * (2 ** (self.consecutive_failures - 1))
|
||||
self.backoff_until = datetime.now() + timedelta(seconds=actual_backoff)
|
||||
logger.warning(f"Rate limit hit for {self.source}, backing off for {actual_backoff}s")
|
||||
|
||||
def record_success(self) -> None:
|
||||
"""Record successful request."""
|
||||
self.consecutive_failures = 0
|
||||
self.backoff_until = None
|
||||
|
||||
|
||||
class CacheBackend(ABC):
|
||||
"""Abstract base class for cache backends."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str) -> Optional[CacheEntry]:
|
||||
"""Get entry from cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, entry: CacheEntry) -> None:
|
||||
"""Set entry in cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: str) -> bool:
|
||||
"""Delete entry from cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> int:
|
||||
"""Clear all entries. Returns number cleared."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def keys(self) -> List[str]:
|
||||
"""Get all cache keys."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def size(self) -> int:
|
||||
"""Get number of entries."""
|
||||
pass
|
||||
|
||||
|
||||
class MemoryCache(CacheBackend):
|
||||
"""In-memory cache with LRU eviction."""
|
||||
|
||||
def __init__(self, max_size: int = 1000):
|
||||
"""Initialize memory cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries
|
||||
"""
|
||||
self._cache: Dict[str, CacheEntry] = {}
|
||||
self._max_size = max_size
|
||||
self._lock = threading.RLock()
|
||||
self._access_order: List[str] = []
|
||||
|
||||
def get(self, key: str) -> Optional[CacheEntry]:
|
||||
"""Get entry from cache."""
|
||||
with self._lock:
|
||||
entry = self._cache.get(key)
|
||||
if entry:
|
||||
# Update access order for LRU
|
||||
if key in self._access_order:
|
||||
self._access_order.remove(key)
|
||||
self._access_order.append(key)
|
||||
return entry
|
||||
|
||||
def set(self, entry: CacheEntry) -> None:
|
||||
"""Set entry in cache with LRU eviction."""
|
||||
with self._lock:
|
||||
# Evict if at capacity
|
||||
while len(self._cache) >= self._max_size and self._access_order:
|
||||
oldest_key = self._access_order.pop(0)
|
||||
self._cache.pop(oldest_key, None)
|
||||
|
||||
self._cache[entry.key] = entry
|
||||
|
||||
# Update access order
|
||||
if entry.key in self._access_order:
|
||||
self._access_order.remove(entry.key)
|
||||
self._access_order.append(entry.key)
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""Delete entry from cache."""
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
del self._cache[key]
|
||||
if key in self._access_order:
|
||||
self._access_order.remove(key)
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear(self) -> int:
|
||||
"""Clear all entries."""
|
||||
with self._lock:
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._access_order.clear()
|
||||
return count
|
||||
|
||||
def keys(self) -> List[str]:
|
||||
"""Get all cache keys."""
|
||||
with self._lock:
|
||||
return list(self._cache.keys())
|
||||
|
||||
def size(self) -> int:
|
||||
"""Get number of entries."""
|
||||
with self._lock:
|
||||
return len(self._cache)
|
||||
|
||||
|
||||
class FileCache(CacheBackend):
|
||||
"""File-based cache using JSON serialization."""
|
||||
|
||||
def __init__(self, cache_dir: Optional[Path] = None):
|
||||
"""Initialize file cache.
|
||||
|
||||
Args:
|
||||
cache_dir: Directory for cache files
|
||||
"""
|
||||
self._cache_dir = cache_dir or Path.home() / ".cache" / "tradingagents"
|
||||
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _get_path(self, key: str) -> Path:
|
||||
"""Get file path for key."""
|
||||
# Use hash to avoid filesystem issues
|
||||
safe_key = hashlib.md5(key.encode()).hexdigest()
|
||||
return self._cache_dir / f"{safe_key}.json"
|
||||
|
||||
def get(self, key: str) -> Optional[CacheEntry]:
|
||||
"""Get entry from file cache."""
|
||||
path = self._get_path(key)
|
||||
if not path.exists():
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
try:
|
||||
with open(path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
return CacheEntry(
|
||||
key=data['key'],
|
||||
value=data['value'],
|
||||
created_at=datetime.fromisoformat(data['created_at']),
|
||||
expires_at=datetime.fromisoformat(data['expires_at']),
|
||||
access_count=data.get('access_count', 0),
|
||||
last_accessed=datetime.fromisoformat(data['last_accessed']) if data.get('last_accessed') else None,
|
||||
source=data.get('source', ''),
|
||||
metadata=data.get('metadata', {})
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError, ValueError):
|
||||
# Corrupted file
|
||||
path.unlink(missing_ok=True)
|
||||
return None
|
||||
|
||||
def set(self, entry: CacheEntry) -> None:
|
||||
"""Set entry in file cache."""
|
||||
path = self._get_path(entry.key)
|
||||
|
||||
with self._lock:
|
||||
data = {
|
||||
'key': entry.key,
|
||||
'value': entry.value,
|
||||
'created_at': entry.created_at.isoformat(),
|
||||
'expires_at': entry.expires_at.isoformat(),
|
||||
'access_count': entry.access_count,
|
||||
'last_accessed': entry.last_accessed.isoformat() if entry.last_accessed else None,
|
||||
'source': entry.source,
|
||||
'metadata': entry.metadata
|
||||
}
|
||||
|
||||
with open(path, 'w') as f:
|
||||
json.dump(data, f)
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""Delete entry from file cache."""
|
||||
path = self._get_path(key)
|
||||
with self._lock:
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear(self) -> int:
|
||||
"""Clear all entries."""
|
||||
with self._lock:
|
||||
count = 0
|
||||
for path in self._cache_dir.glob("*.json"):
|
||||
path.unlink()
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def keys(self) -> List[str]:
|
||||
"""Get all cache keys (returns hashed keys)."""
|
||||
return [p.stem for p in self._cache_dir.glob("*.json")]
|
||||
|
||||
def size(self) -> int:
|
||||
"""Get number of entries."""
|
||||
return len(list(self._cache_dir.glob("*.json")))
|
||||
|
||||
|
||||
class DataCache:
|
||||
"""Main data cache with rate limit awareness.
|
||||
|
||||
Provides caching for vendor data with configurable TTLs,
|
||||
rate limit tracking, and stale-while-revalidate support.
|
||||
|
||||
Example:
|
||||
cache = DataCache()
|
||||
|
||||
# Cache with default TTL
|
||||
cache.set("fred:FEDFUNDS", data, source="fred")
|
||||
|
||||
# Get with stale fallback if rate limited
|
||||
result = cache.get("fred:FEDFUNDS", serve_stale_if_rate_limited=True)
|
||||
|
||||
# Use as decorator
|
||||
@cache.cached(ttl_seconds=3600, source="fred")
|
||||
def get_fred_data(series_id):
|
||||
return fetch_from_api(series_id)
|
||||
"""
|
||||
|
||||
# Default TTLs by source (in seconds)
|
||||
DEFAULT_TTLS = {
|
||||
"fred": 3600 * 24, # 24 hours for FRED (data updates daily)
|
||||
"yfinance": 60, # 1 minute for real-time quotes
|
||||
"finnhub": 60, # 1 minute for real-time data
|
||||
"polygon": 300, # 5 minutes
|
||||
"alpha_vantage": 300, # 5 minutes
|
||||
"default": 300 # 5 minutes default
|
||||
}
|
||||
|
||||
# Default rate limits by source
|
||||
DEFAULT_RATE_LIMITS = {
|
||||
"fred": (120, 60), # 120 requests per 60 seconds
|
||||
"yfinance": (2000, 60), # High limit (throttles internally)
|
||||
"finnhub": (60, 60), # 60 per minute
|
||||
"polygon": (5, 60), # 5 per minute (free tier)
|
||||
"alpha_vantage": (5, 60), # 5 per minute (free tier)
|
||||
"default": (100, 60)
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend: Optional[CacheBackend] = None,
|
||||
default_ttl_seconds: int = 300
|
||||
):
|
||||
"""Initialize data cache.
|
||||
|
||||
Args:
|
||||
backend: Cache backend (defaults to MemoryCache)
|
||||
default_ttl_seconds: Default TTL for entries
|
||||
"""
|
||||
self._backend = backend or MemoryCache()
|
||||
self._default_ttl = default_ttl_seconds
|
||||
self._stats = CacheStats()
|
||||
self._rate_limits: Dict[str, RateLimitState] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _generate_key(self, key: str, **kwargs) -> str:
|
||||
"""Generate cache key from key and optional params."""
|
||||
if kwargs:
|
||||
params_str = json.dumps(kwargs, sort_keys=True)
|
||||
return f"{key}:{hashlib.md5(params_str.encode()).hexdigest()[:8]}"
|
||||
return key
|
||||
|
||||
def _get_ttl(self, source: str) -> int:
|
||||
"""Get TTL for a source."""
|
||||
return self.DEFAULT_TTLS.get(source, self.DEFAULT_TTLS["default"])
|
||||
|
||||
def _get_rate_limit_state(self, source: str) -> RateLimitState:
|
||||
"""Get or create rate limit state for source."""
|
||||
if source not in self._rate_limits:
|
||||
limit, window = self.DEFAULT_RATE_LIMITS.get(
|
||||
source,
|
||||
self.DEFAULT_RATE_LIMITS["default"]
|
||||
)
|
||||
self._rate_limits[source] = RateLimitState(
|
||||
source=source,
|
||||
requests_limit=limit,
|
||||
window_seconds=window
|
||||
)
|
||||
return self._rate_limits[source]
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
serve_stale_if_rate_limited: bool = True,
|
||||
**kwargs
|
||||
) -> tuple[Optional[Any], CacheStatus]:
|
||||
"""Get value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
serve_stale_if_rate_limited: Return expired value if rate limited
|
||||
**kwargs: Additional key params
|
||||
|
||||
Returns:
|
||||
Tuple of (value, status)
|
||||
"""
|
||||
full_key = self._generate_key(key, **kwargs)
|
||||
|
||||
with self._lock:
|
||||
entry = self._backend.get(full_key)
|
||||
|
||||
if entry is None:
|
||||
self._stats.misses += 1
|
||||
return None, CacheStatus.MISS
|
||||
|
||||
if not entry.is_expired:
|
||||
entry.touch()
|
||||
self._backend.set(entry) # Update metadata
|
||||
self._stats.hits += 1
|
||||
return entry.value, CacheStatus.HIT
|
||||
|
||||
# Entry is expired
|
||||
self._stats.expired += 1
|
||||
|
||||
# Check if we should serve stale
|
||||
if serve_stale_if_rate_limited:
|
||||
rate_state = self._get_rate_limit_state(entry.source)
|
||||
if rate_state.is_rate_limited:
|
||||
self._stats.stale_served += 1
|
||||
return entry.value, CacheStatus.STALE
|
||||
|
||||
return None, CacheStatus.EXPIRED
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
source: str = "default",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs
|
||||
) -> None:
|
||||
"""Set value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl_seconds: TTL in seconds (uses source default if not specified)
|
||||
source: Data source name
|
||||
metadata: Optional metadata
|
||||
**kwargs: Additional key params
|
||||
"""
|
||||
full_key = self._generate_key(key, **kwargs)
|
||||
actual_ttl = ttl_seconds if ttl_seconds is not None else self._get_ttl(source)
|
||||
|
||||
entry = CacheEntry(
|
||||
key=full_key,
|
||||
value=value,
|
||||
created_at=datetime.now(),
|
||||
expires_at=datetime.now() + timedelta(seconds=actual_ttl),
|
||||
source=source,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._backend.set(entry)
|
||||
self._stats.size = self._backend.size()
|
||||
|
||||
def delete(self, key: str, **kwargs) -> bool:
|
||||
"""Delete value from cache."""
|
||||
full_key = self._generate_key(key, **kwargs)
|
||||
with self._lock:
|
||||
result = self._backend.delete(full_key)
|
||||
self._stats.size = self._backend.size()
|
||||
return result
|
||||
|
||||
def clear(self, source: Optional[str] = None) -> int:
|
||||
"""Clear cache entries.
|
||||
|
||||
Args:
|
||||
source: Clear only entries from this source (None = all)
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
with self._lock:
|
||||
if source is None:
|
||||
count = self._backend.clear()
|
||||
else:
|
||||
count = 0
|
||||
for key in self._backend.keys():
|
||||
entry = self._backend.get(key)
|
||||
if entry and entry.source == source:
|
||||
self._backend.delete(key)
|
||||
count += 1
|
||||
|
||||
self._stats.evictions += count
|
||||
self._stats.size = self._backend.size()
|
||||
return count
|
||||
|
||||
def record_rate_limit(self, source: str, backoff_seconds: int = 60) -> None:
|
||||
"""Record a rate limit hit for a source."""
|
||||
with self._lock:
|
||||
state = self._get_rate_limit_state(source)
|
||||
state.record_rate_limit(backoff_seconds)
|
||||
|
||||
def record_request(self, source: str) -> None:
|
||||
"""Record a request for rate limit tracking."""
|
||||
with self._lock:
|
||||
state = self._get_rate_limit_state(source)
|
||||
state.record_request()
|
||||
|
||||
def record_success(self, source: str) -> None:
|
||||
"""Record successful request."""
|
||||
with self._lock:
|
||||
state = self._get_rate_limit_state(source)
|
||||
state.record_success()
|
||||
|
||||
def is_rate_limited(self, source: str) -> bool:
|
||||
"""Check if source is rate limited."""
|
||||
with self._lock:
|
||||
state = self._get_rate_limit_state(source)
|
||||
return state.is_rate_limited
|
||||
|
||||
def get_stats(self) -> CacheStats:
|
||||
"""Get cache statistics."""
|
||||
with self._lock:
|
||||
self._stats.size = self._backend.size()
|
||||
return self._stats
|
||||
|
||||
def cached(
|
||||
self,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
source: str = "default",
|
||||
key_prefix: str = ""
|
||||
) -> Callable:
|
||||
"""Decorator for caching function results.
|
||||
|
||||
Example:
|
||||
@cache.cached(ttl_seconds=3600, source="fred")
|
||||
def get_fred_data(series_id):
|
||||
return fetch_from_api(series_id)
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
def wrapper(*args, **kwargs):
|
||||
# Generate cache key from function name and args
|
||||
key_parts = [key_prefix, func.__name__]
|
||||
if args:
|
||||
key_parts.append(str(args))
|
||||
if kwargs:
|
||||
key_parts.append(json.dumps(kwargs, sort_keys=True))
|
||||
cache_key = ":".join(filter(None, key_parts))
|
||||
|
||||
# Check cache
|
||||
value, status = self.get(cache_key, serve_stale_if_rate_limited=True)
|
||||
if status in (CacheStatus.HIT, CacheStatus.STALE):
|
||||
return value
|
||||
|
||||
# Execute function
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
# Cache result
|
||||
self.set(cache_key, result, ttl_seconds=ttl_seconds, source=source)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
# Global cache instance
|
||||
_global_cache: Optional[DataCache] = None
|
||||
_global_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_cache() -> DataCache:
|
||||
"""Get the global cache instance."""
|
||||
global _global_cache
|
||||
with _global_cache_lock:
|
||||
if _global_cache is None:
|
||||
_global_cache = DataCache()
|
||||
return _global_cache
|
||||
|
||||
|
||||
def reset_cache() -> None:
|
||||
"""Reset the global cache instance. For testing."""
|
||||
global _global_cache
|
||||
with _global_cache_lock:
|
||||
_global_cache = None
|
||||
Loading…
Reference in New Issue