"""
|
|
Redis Cache Layer
|
|
=================
|
|
|
|
High-performance caching with Redis for market data, AI decisions,
|
|
and expensive computations.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import pickle
|
|
from typing import Any, Optional, Dict, List, Union
|
|
from datetime import datetime, timedelta
|
|
from decimal import Decimal
|
|
import hashlib
|
|
import asyncio
|
|
|
|
import redis.asyncio as redis
|
|
from redis.asyncio.lock import Lock
|
|
from redis.exceptions import RedisError, ConnectionError
|
|
|
|
logger = logging.getLogger(__name__)


class CacheKey:
    """Cache key builder with namespacing"""

    # Namespaces
    MARKET_DATA = "market"
    AI_DECISION = "ai"
    SIGNAL = "signal"
    CONGRESSIONAL = "congress"
    NEWS = "news"
    TECHNICAL = "tech"
    POSITION = "position"
    METRICS = "metrics"

    @staticmethod
    def build(*parts: Union[str, int, float]) -> str:
        """Build a colon-delimited cache key from parts"""
        return ":".join(str(p) for p in parts)

    @staticmethod
    def market_data(ticker: str) -> str:
        """Key for market data"""
        return CacheKey.build(CacheKey.MARKET_DATA, ticker)

    @staticmethod
    def ai_decision(ticker: str, date: str) -> str:
        """Key for an AI trading decision"""
        return CacheKey.build(CacheKey.AI_DECISION, ticker, date)

    @staticmethod
    def signal(ticker: str, signal_type: str) -> str:
        """Key for a trading signal"""
        return CacheKey.build(CacheKey.SIGNAL, ticker, signal_type)

    @staticmethod
    def congressional_trades(days_back: int) -> str:
        """Key for congressional trades"""
        return CacheKey.build(CacheKey.CONGRESSIONAL, f"last_{days_back}_days")

    @staticmethod
    def news_sentiment(ticker: str) -> str:
        """Key for news sentiment"""
        return CacheKey.build(CacheKey.NEWS, ticker, "sentiment")

    @staticmethod
    def technical_indicators(ticker: str) -> str:
        """Key for technical indicators"""
        return CacheKey.build(CacheKey.TECHNICAL, ticker)

    @staticmethod
    def risk_metrics() -> str:
        """Key for risk metrics"""
        return CacheKey.build(CacheKey.METRICS, "risk")
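
# Illustrative key shapes produced by the builders above:
#   CacheKey.market_data("AAPL")               -> "market:AAPL"
#   CacheKey.ai_decision("NVDA", "2024-01-01") -> "ai:NVDA:2024-01-01"
#   CacheKey.congressional_trades(30)          -> "congress:last_30_days"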


class RedisCache:
    """
    Redis cache manager with connection pooling and error handling
    """

    def __init__(self,
                 host: str = "localhost",
                 port: int = 6379,
                 db: int = 0,
                 password: Optional[str] = None,
                 max_connections: int = 50,
                 socket_timeout: int = 5,
                 retry_on_timeout: bool = True):
        """
        Initialize the Redis cache.

        Args:
            host: Redis host
            port: Redis port
            db: Redis database number
            password: Redis password
            max_connections: Maximum connection pool size
            socket_timeout: Socket timeout in seconds
            retry_on_timeout: Retry on timeout errors
        """
        self.host = host
        self.port = port
        self.db = db

        # Connection pool for better performance
        self.pool = redis.ConnectionPool(
            host=host,
            port=port,
            db=db,
            password=password,
            max_connections=max_connections,
            socket_timeout=socket_timeout,
            socket_connect_timeout=socket_timeout,
            retry_on_timeout=retry_on_timeout,
            decode_responses=False  # Handle encoding ourselves
        )

        self.redis: Optional[redis.Redis] = None
        self.connected = False

        # Default TTLs per namespace (seconds)
        self.default_ttls = {
            CacheKey.MARKET_DATA: 60,      # 1 minute for market data
            CacheKey.AI_DECISION: 3600,    # 1 hour for AI decisions
            CacheKey.SIGNAL: 300,          # 5 minutes for signals
            CacheKey.CONGRESSIONAL: 3600,  # 1 hour for congressional trades
            CacheKey.NEWS: 1800,           # 30 minutes for news
            CacheKey.TECHNICAL: 300,       # 5 minutes for technicals
            CacheKey.METRICS: 60,          # 1 minute for metrics
        }

        # Local cache statistics
        self.stats = {
            'hits': 0,
            'misses': 0,
            'errors': 0,
            'evictions': 0,
        }

    async def connect(self) -> bool:
        """
        Connect to Redis.

        Returns:
            True if connected successfully
        """
        try:
            self.redis = redis.Redis(connection_pool=self.pool)

            # Test the connection
            await self.redis.ping()

            self.connected = True
            logger.info(f"Connected to Redis at {self.host}:{self.port}")

            # Set the memory eviction policy
            await self._configure_memory_policy()

            return True

        except (RedisError, ConnectionError) as e:
            logger.error(f"Failed to connect to Redis: {e}")
            self.connected = False
            return False

    async def disconnect(self):
        """Disconnect from Redis"""
        if self.redis:
            await self.redis.close()
            await self.pool.disconnect()
            self.connected = False
            logger.info("Disconnected from Redis")

    async def _configure_memory_policy(self):
        """Configure the Redis memory eviction policy"""
        try:
            # Evict least-recently-used keys across the whole keyspace
            await self.redis.config_set('maxmemory-policy', 'allkeys-lru')

            # Set max memory (optional, depends on your setup)
            # await self.redis.config_set('maxmemory', '1gb')

        except Exception as e:
            # CONFIG SET is often disabled on managed Redis; degrade gracefully
            logger.warning(f"Could not configure memory policy: {e}")

    # === Core Cache Operations ===

    async def get(self, key: str) -> Optional[Any]:
        """
        Get a value from the cache.

        Args:
            key: Cache key

        Returns:
            Cached value, or None if not found
        """
        if not self.connected:
            return None

        try:
            value = await self.redis.get(key)

            if value is None:
                self.stats['misses'] += 1
                return None

            self.stats['hits'] += 1

            # Deserialize based on how the value was stored
            return self._deserialize(value)

        except Exception as e:
            logger.error(f"Cache get error for {key}: {e}")
            self.stats['errors'] += 1
            return None

    async def set(self,
                  key: str,
                  value: Any,
                  ttl: Optional[int] = None,
                  nx: bool = False) -> bool:
        """
        Set a value in the cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time to live in seconds; defaults to the key's namespace TTL
            nx: Only set if the key doesn't already exist

        Returns:
            True if set successfully
        """
        if not self.connected:
            return False

        try:
            # Fall back to the namespace default TTL (300s if unknown)
            if ttl is None:
                namespace = key.split(':')[0]
                ttl = self.default_ttls.get(namespace, 300)

            # Serialize the value
            serialized = self._serialize(value)

            # Set with TTL
            if nx:
                result = await self.redis.set(key, serialized, ex=ttl, nx=True)
            else:
                result = await self.redis.setex(key, ttl, serialized)

            return bool(result)

        except Exception as e:
            logger.error(f"Cache set error for {key}: {e}")
            self.stats['errors'] += 1
            return False

    async def delete(self, *keys: str) -> int:
        """
        Delete keys from the cache.

        Args:
            *keys: Keys to delete

        Returns:
            Number of keys deleted
        """
        if not self.connected or not keys:
            return 0

        try:
            return await self.redis.delete(*keys)
        except Exception as e:
            logger.error(f"Cache delete error: {e}")
            return 0

    async def exists(self, key: str) -> bool:
        """
        Check whether a key exists.

        Args:
            key: Cache key

        Returns:
            True if the key exists
        """
        if not self.connected:
            return False

        try:
            return bool(await self.redis.exists(key))
        except Exception:
            return False

    async def expire(self, key: str, ttl: int) -> bool:
        """
        Set an expiration on a key.

        Args:
            key: Cache key
            ttl: Time to live in seconds

        Returns:
            True if the expiration was set
        """
        if not self.connected:
            return False

        try:
            return bool(await self.redis.expire(key, ttl))
        except Exception:
            return False

    # === Batch Operations ===

    async def mget(self, keys: List[str]) -> Dict[str, Any]:
        """
        Get multiple values at once.

        Args:
            keys: List of keys

        Returns:
            Dictionary of the key-value pairs that were found
        """
        if not self.connected or not keys:
            return {}

        try:
            values = await self.redis.mget(keys)
            result = {}

            for key, value in zip(keys, values):
                if value is not None:
                    result[key] = self._deserialize(value)
                    self.stats['hits'] += 1
                else:
                    self.stats['misses'] += 1

            return result

        except Exception as e:
            logger.error(f"Cache mget error: {e}")
            return {}

    async def mset(self,
                   data: Dict[str, Any],
                   ttl: Optional[int] = None) -> bool:
        """
        Set multiple values at once.

        Args:
            data: Dictionary of key-value pairs
            ttl: Time to live in seconds; defaults to each key's namespace TTL

        Returns:
            True if all keys were set successfully
        """
        if not self.connected or not data:
            return False

        try:
            # Pipeline the writes into a single atomic round trip
            pipe = self.redis.pipeline()

            for key, value in data.items():
                serialized = self._serialize(value)

                if ttl is None:
                    namespace = key.split(':')[0]
                    key_ttl = self.default_ttls.get(namespace, 300)
                else:
                    key_ttl = ttl

                pipe.setex(key, key_ttl, serialized)

            results = await pipe.execute()
            return all(results)

        except Exception as e:
            logger.error(f"Cache mset error: {e}")
            return False
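
    # Batch usage sketch (hypothetical caller; keys follow the
    # "namespace:identifier" convention so per-namespace TTLs apply):
    #   await cache.mset({"market:AAPL": quote_a, "market:MSFT": quote_m})
    #   quotes = await cache.mget(["market:AAPL", "market:MSFT"])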

    # === Pattern Operations ===

    async def keys_pattern(self, pattern: str) -> List[str]:
        """
        Get keys matching a pattern.

        Args:
            pattern: Redis glob pattern (e.g., "market:*")

        Returns:
            List of matching keys
        """
        if not self.connected:
            return []

        try:
            # SCAN rather than KEYS so large keyspaces don't block the server
            keys = [key async for key in self.redis.scan_iter(match=pattern)]
            return [k.decode('utf-8') if isinstance(k, bytes) else k for k in keys]
        except Exception as e:
            logger.error(f"Cache keys error: {e}")
            return []

    async def delete_pattern(self, pattern: str) -> int:
        """
        Delete keys matching a pattern.

        Args:
            pattern: Redis glob pattern

        Returns:
            Number of keys deleted
        """
        keys = await self.keys_pattern(pattern)
        if keys:
            return await self.delete(*keys)
        return 0

    # === Distributed Locking ===

    async def acquire_lock(self,
                           name: str,
                           timeout: int = 10,
                           blocking: bool = True) -> Optional[Lock]:
        """
        Acquire a distributed lock.

        Args:
            name: Lock name
            timeout: Lock timeout in seconds (auto-release deadline)
            blocking: Whether to block waiting for the lock

        Returns:
            The acquired Lock, or None if it could not be acquired
        """
        if not self.connected:
            return None

        try:
            lock = Lock(
                self.redis,
                f"lock:{name}",
                timeout=timeout,
                blocking=blocking,
                blocking_timeout=5 if blocking else 0
            )

            if await lock.acquire():
                return lock
            return None

        except Exception as e:
            logger.error(f"Lock acquire error: {e}")
            return None
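
    # Usage sketch (hypothetical caller): hold the lock only for the critical
    # section, and always release it so other workers aren't left waiting for
    # the timeout to expire.
    #   lock = await cache.acquire_lock("rebalance", timeout=30)
    #   if lock:
    #       try:
    #           ...  # critical section
    #       finally:
    #           await lock.release()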

    # === Specialized Cache Methods ===

    async def cache_market_data(self,
                                ticker: str,
                                data: Dict[str, Any],
                                ttl: int = 60) -> bool:
        """
        Cache market data with automatic expiration.

        Args:
            ticker: Stock ticker
            data: Market data dictionary
            ttl: Time to live in seconds

        Returns:
            True if cached successfully
        """
        key = CacheKey.market_data(ticker)

        # Stamp the entry so readers can detect staleness
        data['cached_at'] = datetime.now().isoformat()

        return await self.set(key, data, ttl)

    async def get_market_data(self, ticker: str) -> Optional[Dict[str, Any]]:
        """
        Get cached market data.

        Args:
            ticker: Stock ticker

        Returns:
            Market data, or None if missing or stale
        """
        key = CacheKey.market_data(ticker)
        data = await self.get(key)

        if data and isinstance(data, dict):
            # Treat data older than 2 minutes as stale
            cached_at = data.get('cached_at')
            if cached_at:
                cache_time = datetime.fromisoformat(cached_at)
                # total_seconds(), not .seconds: .seconds drops the days component
                if (datetime.now() - cache_time).total_seconds() > 120:
                    # Stale; evict and report a miss
                    await self.delete(key)
                    return None

        return data
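
    # Note: market data has two freshness gates: the Redis TTL (60s by
    # default) and the 120s application-level staleness check above, which
    # also covers entries written with a longer explicit TTL.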

    async def cache_ai_decision(self,
                                ticker: str,
                                date: str,
                                decision: str,
                                ttl: int = 3600) -> bool:
        """
        Cache an AI trading decision.

        Args:
            ticker: Stock ticker
            date: Analysis date
            decision: AI decision
            ttl: Time to live (default 1 hour)

        Returns:
            True if cached successfully
        """
        key = CacheKey.ai_decision(ticker, date)

        data = {
            'decision': decision,
            'ticker': ticker,
            'date': date,
            'cached_at': datetime.now().isoformat()
        }

        return await self.set(key, data, ttl)

    async def get_ai_decision(self,
                              ticker: str,
                              date: str) -> Optional[str]:
        """
        Get a cached AI decision.

        Args:
            ticker: Stock ticker
            date: Analysis date

        Returns:
            AI decision, or None
        """
        key = CacheKey.ai_decision(ticker, date)
        data = await self.get(key)

        if data and isinstance(data, dict):
            return data.get('decision')
        return None

    async def invalidate_ticker(self, ticker: str):
        """
        Invalidate all cache entries for a ticker.

        Args:
            ticker: Stock ticker
        """
        patterns = [
            f"{CacheKey.MARKET_DATA}:{ticker}*",
            f"{CacheKey.AI_DECISION}:{ticker}*",
            f"{CacheKey.SIGNAL}:{ticker}*",
            f"{CacheKey.TECHNICAL}:{ticker}*",
            f"{CacheKey.NEWS}:{ticker}*",
        ]

        for pattern in patterns:
            await self.delete_pattern(pattern)

    # === Helper Methods ===

    def _serialize(self, value: Any) -> bytes:
        """
        Serialize a value for storage.

        Args:
            value: Value to serialize

        Returns:
            Serialized bytes
        """
        # Prefer JSON for simple types so values stay portable and inspectable
        if isinstance(value, (str, int, float)):
            return json.dumps(value).encode('utf-8')
        elif isinstance(value, Decimal):
            # Store Decimals as strings to avoid float precision loss
            return json.dumps(str(value)).encode('utf-8')
        elif isinstance(value, (dict, list)):
            return json.dumps(value, default=str).encode('utf-8')
        else:
            # Fall back to pickle for complex objects
            return pickle.dumps(value)

    def _deserialize(self, value: bytes) -> Any:
        """
        Deserialize a value from storage.

        Args:
            value: Serialized bytes

        Returns:
            Deserialized value
        """
        if not value:
            return None

        # Try JSON first
        try:
            decoded = value.decode('utf-8')
            return json.loads(decoded)
        except (json.JSONDecodeError, UnicodeDecodeError):
            pass

        # Fall back to pickle. Unpickling is only safe because this cache is
        # the sole writer; never unpickle data from untrusted sources.
        try:
            return pickle.loads(value)
        except Exception as e:
            logger.error(f"Deserialization error: {e}")
            return None
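
    # Round-trip caveat: Decimal values are stored as JSON strings, so
    # _deserialize returns str (e.g., "1.23"), not Decimal; callers that
    # need a Decimal back must convert explicitly.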

    async def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Statistics dictionary
        """
        info = {}

        if self.connected:
            try:
                # Server-side stats
                redis_info = await self.redis.info('stats')
                info['redis'] = {
                    'total_connections': redis_info.get('total_connections_received', 0),
                    'commands_processed': redis_info.get('total_commands_processed', 0),
                    'keyspace_hits': redis_info.get('keyspace_hits', 0),
                    'keyspace_misses': redis_info.get('keyspace_misses', 0),
                }

                # Memory usage
                memory_info = await self.redis.info('memory')
                info['memory'] = {
                    'used_memory': memory_info.get('used_memory_human', '0'),
                    'peak_memory': memory_info.get('used_memory_peak_human', '0'),
                }

                # Number of keys in the current database
                info['db_size'] = await self.redis.dbsize()

            except Exception as e:
                logger.error(f"Error getting Redis stats: {e}")

        # Client-side stats tracked by this process
        total_requests = self.stats['hits'] + self.stats['misses']
        hit_rate = (self.stats['hits'] / total_requests * 100) if total_requests > 0 else 0

        info['cache_stats'] = {
            **self.stats,
            'hit_rate': f"{hit_rate:.1f}%",
            'total_requests': total_requests
        }

        return info

    async def clear_all(self) -> bool:
        """
        Clear the entire cache database (USE WITH CAUTION).

        Returns:
            True if cleared successfully
        """
        if not self.connected:
            return False

        try:
            await self.redis.flushdb()
            logger.warning("Cache cleared - all data deleted")
            return True
        except Exception as e:
            logger.error(f"Error clearing cache: {e}")
            return False


# === Cache Decorator ===

def cached(ttl: int = 300, key_prefix: str = ""):
    """
    Decorator for caching async function results.

    The decorated function must have a RedisCache attached as its `_cache`
    attribute; without one, the function runs uncached.

    Args:
        ttl: Time to live in seconds
        key_prefix: Prefix for the cache key (defaults to the function name)
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Build the cache key from the function name and arguments
            cache_key_parts = [key_prefix or func.__name__]

            # Add positional args to the key
            for arg in args:
                if hasattr(arg, '__dict__'):
                    # Skip object instances (e.g., `self`)
                    continue
                cache_key_parts.append(str(arg))

            # Add keyword args to the key, sorted for a stable ordering
            for k, v in sorted(kwargs.items()):
                cache_key_parts.append(f"{k}={v}")

            cache_key = ":".join(cache_key_parts)

            # Try the cache first
            if hasattr(wrapper, '_cache'):
                cached_value = await wrapper._cache.get(cache_key)
                if cached_value is not None:
                    return cached_value

            # Cache miss: call the function
            result = await func(*args, **kwargs)

            # Cache the result
            if hasattr(wrapper, '_cache') and result is not None:
                await wrapper._cache.set(cache_key, result, ttl)

            return result

        return wrapper
    return decorator
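
# Usage sketch (hypothetical function; attach the cache once the manager has
# been initialized, e.g. at application startup):
#   @cached(ttl=600, key_prefix="fundamentals")
#   async def fetch_fundamentals(ticker: str) -> dict:
#       ...
#   fetch_fundamentals._cache = CacheManager.get_instance()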


# === Cache Manager Singleton ===

class CacheManager:
    """Singleton cache manager"""

    _instance: Optional[RedisCache] = None

    @classmethod
    def get_instance(cls) -> RedisCache:
        """Get or create the cache instance"""
        if cls._instance is None:
            cls._instance = RedisCache()
        return cls._instance

    @classmethod
    async def initialize(cls,
                         host: str = "localhost",
                         port: int = 6379,
                         **kwargs) -> RedisCache:
        """
        Initialize the cache manager.

        Args:
            host: Redis host
            port: Redis port
            **kwargs: Additional RedisCache options

        Returns:
            The connected RedisCache instance
        """
        cls._instance = RedisCache(host, port, **kwargs)
        await cls._instance.connect()
        return cls._instance


# Example usage
async def main():
    """Example of using the cache layer"""

    # Initialize the cache
    cache = await CacheManager.initialize()

    # Cache market data
    market_data = {
        'last': 150.25,
        'bid': 150.20,
        'ask': 150.30,
        'volume': 1000000
    }
    await cache.cache_market_data("AAPL", market_data)

    # Get the cached data (renamed so it doesn't shadow the `cached` decorator)
    cached_quote = await cache.get_market_data("AAPL")
    print(f"Cached market data: {cached_quote}")

    # Cache an AI decision
    await cache.cache_ai_decision("NVDA", "2024-01-01", "BUY with high confidence")

    # Get the cached decision
    decision = await cache.get_ai_decision("NVDA", "2024-01-01")
    print(f"Cached AI decision: {decision}")

    # Get cache statistics
    stats = await cache.get_stats()
    print(f"Cache stats: {stats}")

    # Clean up
    await cache.disconnect()


if __name__ == "__main__":
    asyncio.run(main())