#!/usr/bin/env python3
"""
Adaptive cache system.

Automatically selects the best caching strategy based on database availability.
"""

import os
import json
import pickle
import hashlib
import logging
from datetime import datetime, timedelta
from io import StringIO
from pathlib import Path
from typing import Any, Dict, Optional, Union

import pandas as pd

from ..config.database_manager import get_database_manager


class AdaptiveCacheSystem:
    """Adaptive cache system."""

    def __init__(self, cache_dir: str = "data/cache"):
        self.logger = logging.getLogger(__name__)

        # Database manager provides connections and the merged config
        self.db_manager = get_database_manager()

        # File cache directory (also serves as the fallback backend)
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Cache configuration
        self.config = self.db_manager.get_config()
        self.cache_config = self.config["cache"]

        # Primary cache backend and fallback policy
        self.primary_backend = self.cache_config["primary_backend"]
        self.fallback_enabled = self.cache_config["fallback_enabled"]

        self.logger.info(f"Adaptive cache system initialized - primary backend: {self.primary_backend}")

    def _get_cache_key(self, symbol: str, start_date: str = "", end_date: str = "",
                       data_source: str = "default", data_type: str = "stock_data") -> str:
        """Build a deterministic cache key from the query parameters."""
        key_data = f"{symbol}_{start_date}_{end_date}_{data_source}_{data_type}"
        # MD5 is used only as a compact fingerprint here, not for security
        return hashlib.md5(key_data.encode()).hexdigest()
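
    # For example, _get_cache_key("000001", "2024-01-01", "2024-06-30",
    # "tushare", "stock_data") hashes the string
    # "000001_2024-01-01_2024-06-30_tushare_stock_data"
    # (the argument values here are illustrative only)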

    def _get_ttl_seconds(self, symbol: str, data_type: str = "stock_data") -> int:
        """Return the TTL in seconds for a given symbol and data type."""
        # Heuristic market detection: 6-digit numeric symbols are treated as
        # China A-share codes, everything else as US tickers
        if len(symbol) == 6 and symbol.isdigit():
            market = "china"
        else:
            market = "us"

        # Look up the TTL for this market/data-type pair, defaulting to 2 hours
        ttl_key = f"{market}_{data_type}"
        ttl_seconds = self.cache_config["ttl_settings"].get(ttl_key, 7200)
        return ttl_seconds
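
    # A sketch of the "ttl_settings" mapping the lookup above assumes (keys are
    # "<market>_<data_type>", values in seconds). The concrete entries live in
    # database_manager's config; these values are illustrative, not canonical:
    #
    #   "ttl_settings": {
    #       "china_stock_data": 3600,
    #       "us_stock_data": 7200
    #   }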

    def _is_cache_valid(self, cache_time: datetime, ttl_seconds: int) -> bool:
        """Check whether a cache entry written at cache_time is still fresh."""
        if cache_time is None:
            return False

        expiry_time = cache_time + timedelta(seconds=ttl_seconds)
        return datetime.now() < expiry_time

    def _save_to_file(self, cache_key: str, data: Any, metadata: Dict) -> bool:
        """Save an entry to the file cache."""
        try:
            cache_file = self.cache_dir / f"{cache_key}.pkl"
            cache_data = {
                'data': data,
                'metadata': metadata,
                'timestamp': datetime.now(),
                'backend': 'file'
            }

            with open(cache_file, 'wb') as f:
                pickle.dump(cache_data, f)

            self.logger.debug(f"File cache write succeeded: {cache_key}")
            return True

        except Exception as e:
            self.logger.error(f"File cache write failed: {e}")
            return False

    def _load_from_file(self, cache_key: str) -> Optional[Dict]:
        """Load an entry from the file cache."""
        try:
            cache_file = self.cache_dir / f"{cache_key}.pkl"
            if not cache_file.exists():
                return None

            with open(cache_file, 'rb') as f:
                cache_data = pickle.load(f)

            self.logger.debug(f"File cache read succeeded: {cache_key}")
            return cache_data

        except Exception as e:
            self.logger.error(f"File cache read failed: {e}")
            return None

    def _save_to_redis(self, cache_key: str, data: Any, metadata: Dict, ttl_seconds: int) -> bool:
        """Save an entry to the Redis cache."""
        redis_client = self.db_manager.get_redis_client()
        if not redis_client:
            return False

        try:
            cache_data = {
                'data': data,
                'metadata': metadata,
                'timestamp': datetime.now().isoformat(),
                'backend': 'redis'
            }

            # Pickle the whole entry and let Redis expire it via SETEX
            serialized_data = pickle.dumps(cache_data)
            redis_client.setex(cache_key, ttl_seconds, serialized_data)

            self.logger.debug(f"Redis cache write succeeded: {cache_key}")
            return True

        except Exception as e:
            self.logger.error(f"Redis cache write failed: {e}")
            return False

    def _load_from_redis(self, cache_key: str) -> Optional[Dict]:
        """Load an entry from the Redis cache."""
        redis_client = self.db_manager.get_redis_client()
        if not redis_client:
            return None

        try:
            serialized_data = redis_client.get(cache_key)
            if not serialized_data:
                return None

            cache_data = pickle.loads(serialized_data)

            # Timestamps are stored as ISO strings in Redis; convert back
            if isinstance(cache_data['timestamp'], str):
                cache_data['timestamp'] = datetime.fromisoformat(cache_data['timestamp'])

            self.logger.debug(f"Redis cache read succeeded: {cache_key}")
            return cache_data

        except Exception as e:
            self.logger.error(f"Redis cache read failed: {e}")
            return None

    def _save_to_mongodb(self, cache_key: str, data: Any, metadata: Dict, ttl_seconds: int) -> bool:
        """Save an entry to the MongoDB cache."""
        mongodb_client = self.db_manager.get_mongodb_client()
        if not mongodb_client:
            return False

        try:
            db = mongodb_client.tradingagents
            collection = db.cache

            # Serialize: DataFrames as JSON, everything else as hex-encoded pickle
            if isinstance(data, pd.DataFrame):
                serialized_data = data.to_json()
                data_type = 'dataframe'
            else:
                serialized_data = pickle.dumps(data).hex()
                data_type = 'pickle'

            cache_doc = {
                '_id': cache_key,
                'data': serialized_data,
                'data_type': data_type,
                'metadata': metadata,
                'timestamp': datetime.now(),
                'expires_at': datetime.now() + timedelta(seconds=ttl_seconds),
                'backend': 'mongodb'
            }

            collection.replace_one({'_id': cache_key}, cache_doc, upsert=True)

            self.logger.debug(f"MongoDB cache write succeeded: {cache_key}")
            return True

        except Exception as e:
            self.logger.error(f"MongoDB cache write failed: {e}")
            return False
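
    # Note: MongoDB only deletes documents by expires_at automatically if a TTL
    # index exists on that field; this module does not create one (reads also
    # check expires_at defensively in _load_from_mongodb). A one-time setup
    # sketch, assuming direct access to the same collection:
    #
    #   db.cache.create_index("expires_at", expireAfterSeconds=0)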

    def _load_from_mongodb(self, cache_key: str) -> Optional[Dict]:
        """Load an entry from the MongoDB cache."""
        mongodb_client = self.db_manager.get_mongodb_client()
        if not mongodb_client:
            return None

        try:
            db = mongodb_client.tradingagents
            collection = db.cache

            doc = collection.find_one({'_id': cache_key})
            if not doc:
                return None

            # Drop the entry if it has expired
            if doc.get('expires_at') and doc['expires_at'] < datetime.now():
                collection.delete_one({'_id': cache_key})
                return None

            # Deserialize according to how it was stored
            if doc['data_type'] == 'dataframe':
                data = pd.read_json(StringIO(doc['data']))
            else:
                data = pickle.loads(bytes.fromhex(doc['data']))

            cache_data = {
                'data': data,
                'metadata': doc['metadata'],
                'timestamp': doc['timestamp'],
                'backend': 'mongodb'
            }

            self.logger.debug(f"MongoDB cache read succeeded: {cache_key}")
            return cache_data

        except Exception as e:
            self.logger.error(f"MongoDB cache read failed: {e}")
            return None

    def save_data(self, symbol: str, data: Any, start_date: str = "", end_date: str = "",
                  data_source: str = "default", data_type: str = "stock_data") -> str:
        """Save data to the cache and return its cache key."""
        # Build the cache key
        cache_key = self._get_cache_key(symbol, start_date, end_date, data_source, data_type)

        # Metadata travels with the entry so reads can recompute TTLs
        metadata = {
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date,
            'data_source': data_source,
            'data_type': data_type
        }

        # Resolve the TTL for this symbol/data type
        ttl_seconds = self._get_ttl_seconds(symbol, data_type)

        # Write via the primary backend
        success = False

        if self.primary_backend == "redis":
            success = self._save_to_redis(cache_key, data, metadata, ttl_seconds)
        elif self.primary_backend == "mongodb":
            success = self._save_to_mongodb(cache_key, data, metadata, ttl_seconds)
        elif self.primary_backend == "file":
            success = self._save_to_file(cache_key, data, metadata)

        # If the primary backend failed, fall back to the file cache
        if not success and self.fallback_enabled:
            self.logger.warning(f"Primary backend ({self.primary_backend}) write failed, falling back to file cache")
            success = self._save_to_file(cache_key, data, metadata)

        if success:
            self.logger.info(f"Data cached: {symbol} -> {cache_key} (backend: {self.primary_backend})")
        else:
            self.logger.error(f"Data caching failed: {symbol}")

        return cache_key

    def load_data(self, cache_key: str) -> Optional[Any]:
        """Load data from the cache, or return None on a miss."""
        cache_data = None

        # Read via the primary backend
        if self.primary_backend == "redis":
            cache_data = self._load_from_redis(cache_key)
        elif self.primary_backend == "mongodb":
            cache_data = self._load_from_mongodb(cache_key)
        elif self.primary_backend == "file":
            cache_data = self._load_from_file(cache_key)

        # On a miss or failure, try the file cache as a fallback
        if not cache_data and self.fallback_enabled:
            self.logger.debug(f"Primary backend ({self.primary_backend}) read failed, trying file cache")
            cache_data = self._load_from_file(cache_key)

        if not cache_data:
            return None

        # TTL check applies only to the file cache; Redis and MongoDB expire
        # entries through their own mechanisms
        if cache_data.get('backend') == 'file':
            symbol = cache_data['metadata'].get('symbol', '')
            data_type = cache_data['metadata'].get('data_type', 'stock_data')
            ttl_seconds = self._get_ttl_seconds(symbol, data_type)

            if not self._is_cache_valid(cache_data['timestamp'], ttl_seconds):
                self.logger.debug(f"File cache entry expired: {cache_key}")
                return None

        return cache_data['data']

    def find_cached_data(self, symbol: str, start_date: str = "", end_date: str = "",
                         data_source: str = "default", data_type: str = "stock_data") -> Optional[str]:
        """Return the cache key for a query if a valid cached entry exists."""
        cache_key = self._get_cache_key(symbol, start_date, end_date, data_source, data_type)

        # Only report a hit if the entry exists and has not expired
        if self.load_data(cache_key) is not None:
            return cache_key

        return None

    def get_cache_stats(self) -> Dict[str, Any]:
        """Collect cache statistics across all backends."""
        stats = {
            'primary_backend': self.primary_backend,
            'fallback_enabled': self.fallback_enabled,
            'database_available': self.db_manager.is_database_available(),
            'mongodb_available': self.db_manager.is_mongodb_available(),
            'redis_available': self.db_manager.is_redis_available(),
            'file_cache_directory': str(self.cache_dir),
            'file_cache_count': len(list(self.cache_dir.glob("*.pkl"))),
        }

        # Redis statistics
        redis_client = self.db_manager.get_redis_client()
        if redis_client:
            try:
                redis_info = redis_client.info()
                stats['redis_memory_used'] = redis_info.get('used_memory_human', 'N/A')
                stats['redis_keys'] = redis_client.dbsize()
            except Exception:
                stats['redis_status'] = 'Error'

        # MongoDB statistics
        mongodb_client = self.db_manager.get_mongodb_client()
        if mongodb_client:
            try:
                db = mongodb_client.tradingagents
                stats['mongodb_cache_count'] = db.cache.count_documents({})
            except Exception:
                stats['mongodb_status'] = 'Error'

        return stats

    def clear_expired_cache(self):
        """Remove expired entries from the file cache."""
        self.logger.info("Starting expired cache cleanup...")

        # Sweep the file cache
        cleared_files = 0
        for cache_file in self.cache_dir.glob("*.pkl"):
            try:
                with open(cache_file, 'rb') as f:
                    cache_data = pickle.load(f)

                symbol = cache_data['metadata'].get('symbol', '')
                data_type = cache_data['metadata'].get('data_type', 'stock_data')
                ttl_seconds = self._get_ttl_seconds(symbol, data_type)

                if not self._is_cache_valid(cache_data['timestamp'], ttl_seconds):
                    cache_file.unlink()
                    cleared_files += 1

            except Exception as e:
                self.logger.error(f"Failed to clean cache file {cache_file}: {e}")

        self.logger.info(f"File cache cleanup finished, removed {cleared_files} expired files")

        # MongoDB removes expired documents itself via the expires_at field
        # (given a TTL index; see the note after _save_to_mongodb)
        # Redis expires keys on its own


# Global cache system instance
_cache_system = None


def get_cache_system() -> AdaptiveCacheSystem:
    """Return the global AdaptiveCacheSystem instance, creating it on first use."""
    global _cache_system
    if _cache_system is None:
        _cache_system = AdaptiveCacheSystem()
    return _cache_system
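

# A minimal usage sketch (assumes the surrounding package is importable and
# that database_manager supplies a valid "cache" config; the module path below
# is hypothetical):
#
#   from tradingagents.dataflows.adaptive_cache import get_cache_system
#
#   cache = get_cache_system()
#   key = cache.save_data("000001", df, "2024-01-01", "2024-06-30", "tushare")
#   cached_df = cache.load_data(key)     # None on a miss or expiry
#   print(cache.get_cache_stats())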