#!/usr/bin/env python3
"""
Adaptive cache system.

Automatically selects the best caching strategy (Redis, MongoDB, or local
pickle files) based on database availability, with an optional file-cache
fallback when the primary backend fails.
"""

import hashlib
import io
import json
import logging
import os
import pickle
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, Optional, Union

import pandas as pd

from ..config.database_manager import get_database_manager


class AdaptiveCacheSystem:
    """Adaptive cache that routes reads/writes to a configured primary
    backend (``redis`` / ``mongodb`` / ``file``) and can degrade to the
    local file cache when the primary backend is unavailable."""

    def __init__(self, cache_dir: str = "data/cache"):
        """Initialize the cache system.

        Args:
            cache_dir: Directory for the file-cache backend; created if missing.
        """
        self.logger = logging.getLogger(__name__)

        # Database manager decides which backends are actually reachable.
        self.db_manager = get_database_manager()

        # File-cache directory (also used as the fallback backend).
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Cache configuration comes from the database manager's config dict.
        self.config = self.db_manager.get_config()
        self.cache_config = self.config["cache"]

        self.primary_backend = self.cache_config["primary_backend"]
        self.fallback_enabled = self.cache_config["fallback_enabled"]

        self.logger.info(f"自适应缓存系统初始化 - 主要后端: {self.primary_backend}")

    def _get_cache_key(self, symbol: str, start_date: str = "", end_date: str = "",
                       data_source: str = "default", data_type: str = "stock_data") -> str:
        """Build a deterministic cache key from the request parameters.

        MD5 is used only as a compact fingerprint, not for security.
        """
        key_data = f"{symbol}_{start_date}_{end_date}_{data_source}_{data_type}"
        return hashlib.md5(key_data.encode()).hexdigest()

    def _get_ttl_seconds(self, symbol: str, data_type: str = "stock_data") -> int:
        """Return the TTL (seconds) for a symbol/data-type combination.

        A 6-digit numeric symbol is treated as a China A-share ticker;
        everything else is treated as a US ticker. Falls back to 7200s
        when no TTL is configured for the derived key.
        """
        if len(symbol) == 6 and symbol.isdigit():
            market = "china"
        else:
            market = "us"

        ttl_key = f"{market}_{data_type}"
        return self.cache_config["ttl_settings"].get(ttl_key, 7200)

    def _is_cache_valid(self, cache_time: datetime, ttl_seconds: int) -> bool:
        """Return True if a cache entry written at ``cache_time`` is still fresh."""
        if cache_time is None:
            return False
        expiry_time = cache_time + timedelta(seconds=ttl_seconds)
        return datetime.now() < expiry_time

    def _save_to_file(self, cache_key: str, data: Any, metadata: Dict) -> bool:
        """Persist a cache entry as a pickle file. Returns True on success."""
        try:
            cache_file = self.cache_dir / f"{cache_key}.pkl"
            cache_data = {
                'data': data,
                'metadata': metadata,
                'timestamp': datetime.now(),
                'backend': 'file',
            }
            with open(cache_file, 'wb') as f:
                pickle.dump(cache_data, f)
            self.logger.debug(f"文件缓存保存成功: {cache_key}")
            return True
        except Exception as e:
            self.logger.error(f"文件缓存保存失败: {e}")
            return False

    def _load_from_file(self, cache_key: str) -> Optional[Dict]:
        """Load a cache entry from the file backend; None if absent or unreadable."""
        try:
            cache_file = self.cache_dir / f"{cache_key}.pkl"
            if not cache_file.exists():
                return None
            # SECURITY NOTE: pickle.load executes arbitrary code if the cache
            # directory can be written by an untrusted party — keep it trusted.
            with open(cache_file, 'rb') as f:
                cache_data = pickle.load(f)
            self.logger.debug(f"文件缓存加载成功: {cache_key}")
            return cache_data
        except Exception as e:
            self.logger.error(f"文件缓存加载失败: {e}")
            return None

    def _save_to_redis(self, cache_key: str, data: Any, metadata: Dict,
                       ttl_seconds: int) -> bool:
        """Persist a cache entry in Redis with a server-side TTL."""
        redis_client = self.db_manager.get_redis_client()
        if not redis_client:
            return False
        try:
            cache_data = {
                'data': data,
                'metadata': metadata,
                # ISO string: Redis payloads round-trip through pickle but the
                # timestamp is normalized back to datetime on load.
                'timestamp': datetime.now().isoformat(),
                'backend': 'redis',
            }
            serialized_data = pickle.dumps(cache_data)
            # setex lets Redis expire the key itself — no manual TTL check needed.
            redis_client.setex(cache_key, ttl_seconds, serialized_data)
            self.logger.debug(f"Redis缓存保存成功: {cache_key}")
            return True
        except Exception as e:
            self.logger.error(f"Redis缓存保存失败: {e}")
            return False

    def _load_from_redis(self, cache_key: str) -> Optional[Dict]:
        """Load a cache entry from Redis; None if absent or on error."""
        redis_client = self.db_manager.get_redis_client()
        if not redis_client:
            return None
        try:
            serialized_data = redis_client.get(cache_key)
            if not serialized_data:
                return None
            # SECURITY NOTE: pickle.loads on data from Redis — trusted store only.
            cache_data = pickle.loads(serialized_data)
            # Convert the ISO timestamp string back to a datetime object.
            if isinstance(cache_data['timestamp'], str):
                cache_data['timestamp'] = datetime.fromisoformat(cache_data['timestamp'])
            self.logger.debug(f"Redis缓存加载成功: {cache_key}")
            return cache_data
        except Exception as e:
            self.logger.error(f"Redis缓存加载失败: {e}")
            return None

    def _save_to_mongodb(self, cache_key: str, data: Any, metadata: Dict,
                         ttl_seconds: int) -> bool:
        """Persist a cache entry in MongoDB (collection ``tradingagents.cache``).

        DataFrames are stored as JSON; anything else as hex-encoded pickle.
        Expiry is tracked explicitly via the ``expires_at`` field.
        """
        mongodb_client = self.db_manager.get_mongodb_client()
        if not mongodb_client:
            return False
        try:
            db = mongodb_client.tradingagents
            collection = db.cache

            if isinstance(data, pd.DataFrame):
                serialized_data = data.to_json()
                data_type = 'dataframe'
            else:
                serialized_data = pickle.dumps(data).hex()
                data_type = 'pickle'

            cache_doc = {
                '_id': cache_key,
                'data': serialized_data,
                'data_type': data_type,
                'metadata': metadata,
                'timestamp': datetime.now(),
                'expires_at': datetime.now() + timedelta(seconds=ttl_seconds),
                'backend': 'mongodb',
            }
            # Upsert so repeated saves for the same key overwrite in place.
            collection.replace_one({'_id': cache_key}, cache_doc, upsert=True)
            self.logger.debug(f"MongoDB缓存保存成功: {cache_key}")
            return True
        except Exception as e:
            self.logger.error(f"MongoDB缓存保存失败: {e}")
            return False

    def _load_from_mongodb(self, cache_key: str) -> Optional[Dict]:
        """Load a cache entry from MongoDB; deletes and skips expired docs."""
        mongodb_client = self.db_manager.get_mongodb_client()
        if not mongodb_client:
            return None
        try:
            db = mongodb_client.tradingagents
            collection = db.cache

            doc = collection.find_one({'_id': cache_key})
            if not doc:
                return None

            # Lazy expiry: drop the document if it is past its expires_at.
            if doc.get('expires_at') and doc['expires_at'] < datetime.now():
                collection.delete_one({'_id': cache_key})
                return None

            if doc['data_type'] == 'dataframe':
                # StringIO wrapper: passing a literal JSON string to read_json
                # is deprecated in modern pandas.
                data = pd.read_json(io.StringIO(doc['data']))
            else:
                # SECURITY NOTE: pickle on data from MongoDB — trusted store only.
                data = pickle.loads(bytes.fromhex(doc['data']))

            cache_data = {
                'data': data,
                'metadata': doc['metadata'],
                'timestamp': doc['timestamp'],
                'backend': 'mongodb',
            }
            self.logger.debug(f"MongoDB缓存加载成功: {cache_key}")
            return cache_data
        except Exception as e:
            self.logger.error(f"MongoDB缓存加载失败: {e}")
            return None

    def save_data(self, symbol: str, data: Any, start_date: str = "",
                  end_date: str = "", data_source: str = "default",
                  data_type: str = "stock_data") -> str:
        """Save data through the primary backend, degrading to file cache
        when enabled.

        Returns:
            The cache key (returned even when every backend fails, so callers
            can retry a lookup later).
        """
        cache_key = self._get_cache_key(symbol, start_date, end_date,
                                        data_source, data_type)
        metadata = {
            'symbol': symbol,
            'start_date': start_date,
            'end_date': end_date,
            'data_source': data_source,
            'data_type': data_type,
        }
        ttl_seconds = self._get_ttl_seconds(symbol, data_type)

        # Track which backend actually accepted the write so logging is accurate.
        backend_used = self.primary_backend
        success = False
        if self.primary_backend == "redis":
            success = self._save_to_redis(cache_key, data, metadata, ttl_seconds)
        elif self.primary_backend == "mongodb":
            success = self._save_to_mongodb(cache_key, data, metadata, ttl_seconds)
        elif self.primary_backend == "file":
            success = self._save_to_file(cache_key, data, metadata)

        # Fallback: only meaningful when the primary backend is not already file.
        if not success and self.fallback_enabled and self.primary_backend != "file":
            self.logger.warning(f"主要后端({self.primary_backend})保存失败,使用文件缓存降级")
            success = self._save_to_file(cache_key, data, metadata)
            if success:
                backend_used = "file"

        if success:
            # Bug fix: previously logged the primary backend even when the
            # write actually went to the file-cache fallback.
            self.logger.info(f"数据缓存成功: {symbol} -> {cache_key} (后端: {backend_used})")
        else:
            self.logger.error(f"数据缓存失败: {symbol}")

        return cache_key

    def load_data(self, cache_key: str) -> Optional[Any]:
        """Load cached data by key; returns the payload or None on miss/expiry."""
        cache_data = None

        if self.primary_backend == "redis":
            cache_data = self._load_from_redis(cache_key)
        elif self.primary_backend == "mongodb":
            cache_data = self._load_from_mongodb(cache_key)
        elif self.primary_backend == "file":
            cache_data = self._load_from_file(cache_key)

        # Fallback read — skipped when the primary backend is already file.
        if not cache_data and self.fallback_enabled and self.primary_backend != "file":
            self.logger.debug(f"主要后端({self.primary_backend})加载失败,尝试文件缓存")
            cache_data = self._load_from_file(cache_key)

        if not cache_data:
            return None

        # TTL is enforced here only for file-cache entries; Redis and MongoDB
        # expire entries on their own (setex / expires_at).
        if cache_data.get('backend') == 'file':
            symbol = cache_data['metadata'].get('symbol', '')
            data_type = cache_data['metadata'].get('data_type', 'stock_data')
            ttl_seconds = self._get_ttl_seconds(symbol, data_type)
            if not self._is_cache_valid(cache_data['timestamp'], ttl_seconds):
                self.logger.debug(f"文件缓存已过期: {cache_key}")
                return None

        return cache_data['data']

    def find_cached_data(self, symbol: str, start_date: str = "",
                         end_date: str = "", data_source: str = "default",
                         data_type: str = "stock_data") -> Optional[str]:
        """Return the cache key for the given parameters if a valid entry
        exists, else None."""
        cache_key = self._get_cache_key(symbol, start_date, end_date,
                                        data_source, data_type)
        if self.load_data(cache_key) is not None:
            return cache_key
        return None

    def get_cache_stats(self) -> Dict[str, Any]:
        """Collect backend availability and usage statistics (best effort)."""
        stats = {
            'primary_backend': self.primary_backend,
            'fallback_enabled': self.fallback_enabled,
            'database_available': self.db_manager.is_database_available(),
            'mongodb_available': self.db_manager.is_mongodb_available(),
            'redis_available': self.db_manager.is_redis_available(),
            'file_cache_directory': str(self.cache_dir),
            'file_cache_count': len(list(self.cache_dir.glob("*.pkl"))),
        }

        # Redis stats — narrowed from a bare except so interrupts propagate.
        redis_client = self.db_manager.get_redis_client()
        if redis_client:
            try:
                redis_info = redis_client.info()
                stats['redis_memory_used'] = redis_info.get('used_memory_human', 'N/A')
                stats['redis_keys'] = redis_client.dbsize()
            except Exception:
                stats['redis_status'] = 'Error'

        # MongoDB stats — same narrowing applied.
        mongodb_client = self.db_manager.get_mongodb_client()
        if mongodb_client:
            try:
                db = mongodb_client.tradingagents
                stats['mongodb_cache_count'] = db.cache.count_documents({})
            except Exception:
                stats['mongodb_status'] = 'Error'

        return stats

    def clear_expired_cache(self):
        """Delete expired file-cache entries.

        Redis and MongoDB are left alone: Redis expires keys natively and
        MongoDB entries are pruned lazily on read via ``expires_at``.
        """
        self.logger.info("开始清理过期缓存...")

        cleared_files = 0
        for cache_file in self.cache_dir.glob("*.pkl"):
            try:
                with open(cache_file, 'rb') as f:
                    cache_data = pickle.load(f)
                symbol = cache_data['metadata'].get('symbol', '')
                data_type = cache_data['metadata'].get('data_type', 'stock_data')
                ttl_seconds = self._get_ttl_seconds(symbol, data_type)
                if not self._is_cache_valid(cache_data['timestamp'], ttl_seconds):
                    cache_file.unlink()
                    cleared_files += 1
            except Exception as e:
                self.logger.error(f"清理缓存文件失败 {cache_file}: {e}")

        self.logger.info(f"文件缓存清理完成,删除 {cleared_files} 个过期文件")


# Module-level singleton instance.
_cache_system = None


def get_cache_system() -> AdaptiveCacheSystem:
    """Return the global AdaptiveCacheSystem, creating it on first use."""
    global _cache_system
    if _cache_system is None:
        _cache_system = AdaptiveCacheSystem()
    return _cache_system