TradingAgents/tradingagents/dataflows/cache_manager.py.diff


# File diff report
# Current file: tradingagents\dataflows\cache_manager.py
# Chinese version file: TradingAgentsCN\tradingagents\dataflows\cache_manager.py
# Generated: Sunday, 2025/07/06
--- current/cache_manager.py
+++ chinese_version/cache_manager.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
-Stock Data Cache Manager
-Supports local caching of stock data to reduce API calls and improve response speed
+Stock Data Cache Manager
+Supports local caching of stock data to reduce API calls and improve response speed
"""
import os
@@ -15,24 +15,24 @@
class StockDataCache:
- """Stock Data Cache Manager - Supports optimized caching for US and Chinese stock data"""
+ """股票数据缓存管理器 - 支持美股和A股数据缓存优化"""
def __init__(self, cache_dir: str = None):
"""
- Initialize cache manager
+ Initialize the cache manager
Args:
- cache_dir: Cache directory path, defaults to tradingagents/dataflows/data_cache
+ cache_dir: cache directory path, defaults to tradingagents/dataflows/data_cache
"""
if cache_dir is None:
- # Get current file directory
+ # Get the directory containing the current file
current_dir = Path(__file__).parent
cache_dir = current_dir / "data_cache"
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(exist_ok=True)
- # Create subdirectories - categorized by market
+ # Create subdirectories - categorized by market
self.us_stock_dir = self.cache_dir / "us_stocks"
self.china_stock_dir = self.cache_dir / "china_stocks"
self.us_news_dir = self.cache_dir / "us_news"
@@ -41,81 +41,81 @@
self.china_fundamentals_dir = self.cache_dir / "china_fundamentals"
self.metadata_dir = self.cache_dir / "metadata"
- # Create all directories
+ # Create all directories
for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
self.china_news_dir, self.us_fundamentals_dir,
self.china_fundamentals_dir, self.metadata_dir]:
dir_path.mkdir(exist_ok=True)
- # Cache configuration - different TTL settings for different markets
+ # Cache configuration - different TTLs for different markets
self.cache_config = {
'us_stock_data': {
- 'ttl_hours': 2, # US stock data cached for 2 hours (considering API limits)
+ 'ttl_hours': 2, # US stock data cached for 2 hours (considering API limits)
'max_files': 1000,
- 'description': 'US stock historical data'
+ 'description': 'US stock historical data'
},
'china_stock_data': {
- 'ttl_hours': 1, # A-share data cached for 1 hour (high real-time requirement)
+ 'ttl_hours': 1, # A-share data cached for 1 hour (high real-time requirements)
'max_files': 1000,
- 'description': 'A-share historical data'
+ 'description': 'A-share historical data'
},
'us_news': {
- 'ttl_hours': 6, # US stock news cached for 6 hours
+ 'ttl_hours': 6, # US stock news cached for 6 hours
'max_files': 500,
- 'description': 'US stock news data'
+ 'description': 'US stock news data'
},
'china_news': {
- 'ttl_hours': 4, # A-share news cached for 4 hours
+ 'ttl_hours': 4, # A-share news cached for 4 hours
'max_files': 500,
- 'description': 'A-share news data'
+ 'description': 'A-share news data'
},
'us_fundamentals': {
- 'ttl_hours': 24, # US stock fundamentals cached for 24 hours
+ 'ttl_hours': 24, # US stock fundamentals cached for 24 hours
'max_files': 200,
- 'description': 'US stock fundamentals data'
+ 'description': 'US stock fundamentals data'
},
'china_fundamentals': {
- 'ttl_hours': 12, # A-share fundamentals cached for 12 hours
+ 'ttl_hours': 12, # A-share fundamentals cached for 12 hours
'max_files': 200,
- 'description': 'A-share fundamentals data'
+ 'description': 'A-share fundamentals data'
}
}
- print(f"📁 Cache manager initialized, cache directory: {self.cache_dir}")
- print(f"🗄️ Database cache manager initialized")
- print(f" US stock data: ✅ Configured")
- print(f" A-share data: ✅ Configured")
+ print(f"📁 缓存管理器初始化完成,缓存目录: {self.cache_dir}")
+ print(f"🗄️ 数据库缓存管理器初始化完成")
+ print(f" 美股数据: ✅ 已配置")
+ print(f" A股数据: ✅ 已配置")
def _determine_market_type(self, symbol: str) -> str:
- """Determine market type based on stock symbol"""
+ """根据股票代码确定市场类型"""
import re
- # Check if it's Chinese A-share (6-digit number)
+ # Check for Chinese A-shares (6-digit code)
if re.match(r'^\d{6}$', str(symbol)):
return 'china'
else:
return 'us'
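Both versions share this classifier; a quick sketch of how the 6-digit rule behaves on a few symbols:

    import re

    def determine_market_type(symbol: str) -> str:
        # Exactly six digits => Chinese A-share; everything else => US
        return 'china' if re.match(r'^\d{6}$', str(symbol)) else 'us'

    assert determine_market_type('000001') == 'china'
    assert determine_market_type('AAPL') == 'us'
    assert determine_market_type('00700') == 'us'  # five digits does not match the A-share rule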
def _generate_cache_key(self, data_type: str, symbol: str, **kwargs) -> str:
- """Generate cache key"""
- # Create a string containing all parameters
+ """生成缓存键"""
+ # 创建一个包含所有参数的字符串
params_str = f"{data_type}_{symbol}"
for key, value in sorted(kwargs.items()):
params_str += f"_{key}_{value}"
- # Use MD5 to generate short unique identifier
+ # Use MD5 to generate a short unique identifier
cache_key = hashlib.md5(params_str.encode()).hexdigest()[:12]
return f"{symbol}_{data_type}_{cache_key}"
def _get_cache_path(self, data_type: str, cache_key: str, file_format: str = "json", symbol: str = None) -> Path:
- """Get cache file path - supports market classification"""
+ """获取缓存文件路径 - 支持市场分类"""
if symbol:
market_type = self._determine_market_type(symbol)
else:
- # Try to extract market type from cache key
+ # Try to infer the market type from the cache key
market_type = 'us' if not cache_key.startswith(('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')) else 'china'
- # Select directory based on data type and market type
+ # Select the directory based on data type and market type
if data_type == "stock_data":
base_dir = self.china_stock_dir if market_type == 'china' else self.us_stock_dir
elif data_type == "news":
@@ -128,20 +128,19 @@
return base_dir / f"{cache_key}.{file_format}"
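Because every cache key starts with the symbol, the digit-prefix fallback above is usually right; the routing itself is plain pathlib. A sketch under a hypothetical cache root:

    from pathlib import Path

    base = Path('data_cache')  # hypothetical root; the real one lives next to this module
    stock_dirs = {'china': base / 'china_stocks', 'us': base / 'us_stocks'}

    def stock_cache_path(market: str, cache_key: str, fmt: str = 'csv') -> Path:
        return stock_dirs[market] / f"{cache_key}.{fmt}"

    print(stock_cache_path('china', '000001_stock_data_ab12cd34ef56'))
    # data_cache/china_stocks/000001_stock_data_ab12cd34ef56.csv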
def _get_metadata_path(self, cache_key: str) -> Path:
- """Get metadata file path"""
+ """获取元数据文件路径"""
return self.metadata_dir / f"{cache_key}_meta.json"
def _save_metadata(self, cache_key: str, metadata: Dict[str, Any]):
- """Save cache metadata"""
+ """保存元数据"""
metadata_path = self._get_metadata_path(cache_key)
- try:
- with open(metadata_path, 'w', encoding='utf-8') as f:
- json.dump(metadata, f, ensure_ascii=False, indent=2, default=str)
- except Exception as e:
- print(f"⚠️ Failed to save metadata: {e}")
+ metadata['cached_at'] = datetime.now().isoformat()
+
+ with open(metadata_path, 'w', encoding='utf-8') as f:
+ json.dump(metadata, f, ensure_ascii=False, indent=2)
def _load_metadata(self, cache_key: str) -> Optional[Dict[str, Any]]:
- """Load cache metadata"""
+ """加载元数据"""
metadata_path = self._get_metadata_path(cache_key)
if not metadata_path.exists():
return None
@@ -150,349 +149,355 @@
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
- print(f"⚠️ Failed to load metadata: {e}")
- return None
-
- def save_stock_data(self, symbol: str, data: Union[str, pd.DataFrame],
- start_date: str, end_date: str, data_source: str = "unknown") -> str:
- """
- Save stock data to cache
+ print(f"⚠️ 加载元数据失败: {e}")
+ return None
+
+ def is_cache_valid(self, cache_key: str, max_age_hours: int = None, symbol: str = None, data_type: str = None) -> bool:
+ """检查缓存是否有效 - 支持智能TTL配置"""
+ metadata = self._load_metadata(cache_key)
+ if not metadata:
+ return False
+
+ # If no TTL is given, determine it automatically from the data type and market
+ if max_age_hours is None:
+ if symbol and data_type:
+ market_type = self._determine_market_type(symbol)
+ cache_type = f"{market_type}_{data_type}"
+ max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
+ else:
+ # Get the information from the metadata
+ symbol = metadata.get('symbol', '')
+ data_type = metadata.get('data_type', 'stock_data')
+ market_type = self._determine_market_type(symbol)
+ cache_type = f"{market_type}_{data_type}"
+ max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
+
+ cached_at = datetime.fromisoformat(metadata['cached_at'])
+ age = datetime.now() - cached_at
+
+ is_valid = age.total_seconds() < max_age_hours * 3600
+
+ if is_valid:
+ market_type = self._determine_market_type(metadata.get('symbol', ''))
+ cache_type = f"{market_type}_{metadata.get('data_type', 'stock_data')}"
+ desc = self.cache_config.get(cache_type, {}).get('description', 'data')
+ print(f"✅ Cache valid: {desc} - {metadata.get('symbol')} ({max_age_hours - age.total_seconds()/3600:.1f}h remaining)")
+
+ return is_valid
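The validity test reduces to an age comparison against the resolved TTL; a self-contained sketch of just that core:

    from datetime import datetime, timedelta

    def is_fresh(cached_at_iso: str, max_age_hours: float) -> bool:
        age = datetime.now() - datetime.fromisoformat(cached_at_iso)
        return age.total_seconds() < max_age_hours * 3600

    stamp = (datetime.now() - timedelta(minutes=30)).isoformat()
    assert is_fresh(stamp, 1)         # 30 minutes old, 1h TTL: valid
    assert not is_fresh(stamp, 0.25)  # 15-minute TTL: expired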
+
+ def save_stock_data(self, symbol: str, data: Union[pd.DataFrame, str],
+ start_date: str = None, end_date: str = None,
+ data_source: str = "unknown") -> str:
+ """
+ Save stock data to the cache - supports separate storage for US and A-share markets
Args:
- symbol: Stock symbol
- data: Stock data (string or DataFrame)
- start_date: Start date
- end_date: End date
- data_source: Data source name
+ symbol: stock symbol
+ data: stock data (DataFrame or string)
+ start_date: start date
+ end_date: end date
+ data_source: data source (e.g. "tdx", "yfinance", "finnhub")
Returns:
- Cache key
- """
+ cache_key: the cache key
+ """
+ market_type = self._determine_market_type(symbol)
+ cache_key = self._generate_cache_key("stock_data", symbol,
+ start_date=start_date,
+ end_date=end_date,
+ source=data_source,
+ market=market_type)
+
+ # Save the data
+ if isinstance(data, pd.DataFrame):
+ cache_path = self._get_cache_path("stock_data", cache_key, "csv", symbol)
+ data.to_csv(cache_path, index=True)
+ else:
+ cache_path = self._get_cache_path("stock_data", cache_key, "txt", symbol)
+ with open(cache_path, 'w', encoding='utf-8') as f:
+ f.write(str(data))
+
+ # Save the metadata
+ metadata = {
+ 'symbol': symbol,
+ 'data_type': 'stock_data',
+ 'market_type': market_type,
+ 'start_date': start_date,
+ 'end_date': end_date,
+ 'data_source': data_source,
+ 'file_path': str(cache_path),
+ 'file_format': 'csv' if isinstance(data, pd.DataFrame) else 'txt'
+ }
+ self._save_metadata(cache_key, metadata)
+
+ # Get the description
+ cache_type = f"{market_type}_stock_data"
+ desc = self.cache_config.get(cache_type, {}).get('description', 'stock data')
+ print(f"💾 {desc} cached: {symbol} ({data_source}) -> {cache_key}")
+ return cache_key
+
+ def load_stock_data(self, cache_key: str) -> Optional[Union[pd.DataFrame, str]]:
+ """从缓存加载股票数据"""
+ metadata = self._load_metadata(cache_key)
+ if not metadata:
+ return None
+
+ cache_path = Path(metadata['file_path'])
+ if not cache_path.exists():
+ return None
+
try:
- # Generate cache key
- cache_key = self._generate_cache_key(
- "stock_data", symbol,
- start_date=start_date,
- end_date=end_date,
- data_source=data_source
- )
-
- # Determine file format and save data
- if isinstance(data, pd.DataFrame):
- # Save DataFrame as pickle for better performance
- cache_path = self._get_cache_path("stock_data", cache_key, "pkl", symbol)
- data.to_pickle(cache_path)
- data_type = "dataframe"
- else:
- # Save string data as JSON
- cache_path = self._get_cache_path("stock_data", cache_key, "json", symbol)
- with open(cache_path, 'w', encoding='utf-8') as f:
- json.dump({"data": data}, f, ensure_ascii=False, indent=2)
- data_type = "string"
-
- # Save metadata
- market_type = self._determine_market_type(symbol)
- metadata = {
- "symbol": symbol,
- "data_type": data_type,
- "start_date": start_date,
- "end_date": end_date,
- "data_source": data_source,
- "market_type": market_type,
- "cache_time": datetime.now().isoformat(),
- "file_path": str(cache_path),
- "cache_key": cache_key
- }
- self._save_metadata(cache_key, metadata)
-
- print(f"💾 Stock data cached: {symbol} ({market_type.upper()}) -> {cache_key}")
- return cache_key
-
- except Exception as e:
- print(f"❌ Failed to save stock data cache: {e}")
- return None
-
- def load_stock_data(self, cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
- """
- Load stock data from cache
-
- Args:
- cache_key: Cache key
-
- Returns:
- Stock data or None if not found
- """
- try:
- # Load metadata
- metadata = self._load_metadata(cache_key)
- if not metadata:
- print(f"⚠️ Cache metadata not found: {cache_key}")
- return None
-
- # Get file path
- cache_path = Path(metadata["file_path"])
- if not cache_path.exists():
- print(f"⚠️ Cache file not found: {cache_path}")
- return None
-
- # Load data based on type
- if metadata["data_type"] == "dataframe":
- data = pd.read_pickle(cache_path)
+ if metadata['file_format'] == 'csv':
+ return pd.read_csv(cache_path, index_col=0)
else:
with open(cache_path, 'r', encoding='utf-8') as f:
- json_data = json.load(f)
- data = json_data["data"]
-
- print(f"📖 Stock data loaded from cache: {metadata['symbol']} -> {cache_key}")
- return data
-
+ return f.read()
except Exception as e:
- print(f"❌ Failed to load stock data from cache: {e}")
- return None
-
- def find_cached_stock_data(self, symbol: str, start_date: str, end_date: str,
- data_source: str = "unknown") -> Optional[str]:
- """
- Find cached stock data
+ print(f"⚠️ 加载缓存数据失败: {e}")
+ return None
+
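A round-trip sketch against the Chinese version's API (the DataFrame content is made up, and the import path assumes the TradingAgentsCN layout):

    import pandas as pd
    from tradingagents.dataflows.cache_manager import StockDataCache

    cache = StockDataCache()
    df = pd.DataFrame({'close': [10.1, 10.4]}, index=['2025-07-01', '2025-07-02'])
    key = cache.save_stock_data('600036', df, start_date='2025-07-01',
                                end_date='2025-07-02', data_source='tdx')
    loaded = cache.load_stock_data(key)  # DataFrames come back via pd.read_csv
    assert loaded is not None

Note the storage trade-off versus the current version: CSV/txt files are human-inspectable, while pickle preserves dtypes and the index exactly.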
+ def find_cached_stock_data(self, symbol: str, start_date: str = None,
+ end_date: str = None, data_source: str = None,
+ max_age_hours: int = None) -> Optional[str]:
+ """
+ Find matching cached data - supports smart market-aware lookup
Args:
- symbol: Stock symbol
- start_date: Start date
- end_date: End date
- data_source: Data source name
+ symbol: stock symbol
+ start_date: start date
+ end_date: end date
+ data_source: data source
+ max_age_hours: maximum cache age in hours (None uses the smart configuration)
Returns:
- Cache key if found, None otherwise
- """
- # Generate expected cache key
- cache_key = self._generate_cache_key(
- "stock_data", symbol,
- start_date=start_date,
- end_date=end_date,
- data_source=data_source
- )
-
- # Check if metadata exists
+ cache_key: the cache key if a valid cache is found, otherwise None
+ """
+ market_type = self._determine_market_type(symbol)
+
+ # If no TTL is specified, use the smart configuration
+ if max_age_hours is None:
+ cache_type = f"{market_type}_stock_data"
+ max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
+
+ # Generate the lookup key
+ search_key = self._generate_cache_key("stock_data", symbol,
+ start_date=start_date,
+ end_date=end_date,
+ source=data_source,
+ market=market_type)
+
+ # Check for an exact match
+ if self.is_cache_valid(search_key, max_age_hours, symbol, 'stock_data'):
+ desc = self.cache_config.get(f"{market_type}_stock_data", {}).get('description', '数据')
+ print(f"🎯 找到精确匹配的{desc}: {symbol} -> {search_key}")
+ return search_key
+
+ # No exact match: look for a partial match (other caches for the same symbol)
+ for metadata_file in self.metadata_dir.glob(f"*_meta.json"):
+ try:
+ with open(metadata_file, 'r', encoding='utf-8') as f:
+ metadata = json.load(f)
+
+ if (metadata.get('symbol') == symbol and
+ metadata.get('data_type') == 'stock_data' and
+ metadata.get('market_type') == market_type and
+ (data_source is None or metadata.get('data_source') == data_source)):
+
+ cache_key = metadata_file.stem.replace('_meta', '')
+ if self.is_cache_valid(cache_key, max_age_hours, symbol, 'stock_data'):
+ desc = self.cache_config.get(f"{market_type}_stock_data", {}).get('description', '数据')
+ print(f"📋 找到部分匹配的{desc}: {symbol} -> {cache_key}")
+ return cache_key
+ except Exception:
+ continue
+
+ desc = self.cache_config.get(f"{market_type}_stock_data", {}).get('description', '数据')
+ print(f"❌ 未找到有效的{desc}缓存: {symbol}")
+ return None
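The intended call pattern is cache-first: look up, load on a hit, fetch and save on a miss. A sketch (fetch_from_api is a hypothetical stand-in for the real data-source call):

    from tradingagents.dataflows.cache_manager import get_cache

    def fetch_from_api(symbol: str) -> str:
        return f"placeholder quotes for {symbol}"  # hypothetical fetch

    cache = get_cache()
    key = cache.find_cached_stock_data('600036', start_date='2025-07-01',
                                       end_date='2025-07-02', data_source='tdx')
    if key:
        data = cache.load_stock_data(key)  # hit: exact or partial match
    else:
        data = fetch_from_api('600036')    # miss: fetch, then cache for next time
        cache.save_stock_data('600036', data, '2025-07-01', '2025-07-02', 'tdx')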
+
+ def save_news_data(self, symbol: str, news_data: str,
+ start_date: str = None, end_date: str = None,
+ data_source: str = "unknown") -> str:
+ """保存新闻数据到缓存"""
+ cache_key = self._generate_cache_key("news", symbol,
+ start_date=start_date,
+ end_date=end_date,
+ source=data_source)
+
+ cache_path = self._get_cache_path("news", cache_key, "txt")
+ with open(cache_path, 'w', encoding='utf-8') as f:
+ f.write(news_data)
+
+ metadata = {
+ 'symbol': symbol,
+ 'data_type': 'news',
+ 'start_date': start_date,
+ 'end_date': end_date,
+ 'data_source': data_source,
+ 'file_path': str(cache_path),
+ 'file_format': 'txt'
+ }
+ self._save_metadata(cache_key, metadata)
+
+ print(f"📰 新闻数据已缓存: {symbol} ({data_source}) -> {cache_key}")
+ return cache_key
+
+ def save_fundamentals_data(self, symbol: str, fundamentals_data: str,
+ data_source: str = "unknown") -> str:
+ """保存基本面数据到缓存"""
+ market_type = self._determine_market_type(symbol)
+ cache_key = self._generate_cache_key("fundamentals", symbol,
+ source=data_source,
+ market=market_type,
+ date=datetime.now().strftime("%Y-%m-%d"))
+
+ cache_path = self._get_cache_path("fundamentals", cache_key, "txt", symbol)
+ with open(cache_path, 'w', encoding='utf-8') as f:
+ f.write(fundamentals_data)
+
+ metadata = {
+ 'symbol': symbol,
+ 'data_type': 'fundamentals',
+ 'data_source': data_source,
+ 'market_type': market_type,
+ 'file_path': str(cache_path),
+ 'file_format': 'txt'
+ }
+ self._save_metadata(cache_key, metadata)
+
+ desc = self.cache_config.get(f"{market_type}_fundamentals", {}).get('description', '基本面数据')
+ print(f"💼 {desc}已缓存: {symbol} ({data_source}) -> {cache_key}")
+ return cache_key
+
+ def load_fundamentals_data(self, cache_key: str) -> Optional[str]:
+ """从缓存加载基本面数据"""
metadata = self._load_metadata(cache_key)
- if metadata:
- cache_path = Path(metadata["file_path"])
- if cache_path.exists():
- return cache_key
-
+ if not metadata:
+ return None
+
+ cache_path = Path(metadata['file_path'])
+ if not cache_path.exists():
+ return None
+
+ try:
+ with open(cache_path, 'r', encoding='utf-8') as f:
+ return f.read()
+ except Exception as e:
+ print(f"⚠️ 加载基本面缓存数据失败: {e}")
+ return None
+
+ def find_cached_fundamentals_data(self, symbol: str, data_source: str = None,
+ max_age_hours: int = None) -> Optional[str]:
+ """
+ Find matching cached fundamentals data
+
+ Args:
+ symbol: stock symbol
+ data_source: data source (e.g. "openai", "finnhub")
+ max_age_hours: maximum cache age in hours (None uses the smart configuration)
+
+ Returns:
+ cache_key: the cache key if a valid cache is found, otherwise None
+ """
+ market_type = self._determine_market_type(symbol)
+
+ # If no TTL is specified, use the smart configuration
+ if max_age_hours is None:
+ cache_type = f"{market_type}_fundamentals"
+ max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
+
+ # Look for a matching cache entry
+ for metadata_file in self.metadata_dir.glob(f"*_meta.json"):
+ try:
+ with open(metadata_file, 'r', encoding='utf-8') as f:
+ metadata = json.load(f)
+
+ if (metadata.get('symbol') == symbol and
+ metadata.get('data_type') == 'fundamentals' and
+ metadata.get('market_type') == market_type and
+ (data_source is None or metadata.get('data_source') == data_source)):
+
+ cache_key = metadata_file.stem.replace('_meta', '')
+ if self.is_cache_valid(cache_key, max_age_hours, symbol, 'fundamentals'):
+ desc = self.cache_config.get(f"{market_type}_fundamentals", {}).get('description', '基本面数据')
+ print(f"🎯 找到匹配的{desc}缓存: {symbol} ({data_source}) -> {cache_key}")
+ return cache_key
+ except Exception:
+ continue
+
+ desc = self.cache_config.get(f"{market_type}_fundamentals", {}).get('description', '基本面数据')
+ print(f"❌ 未找到有效的{desc}缓存: {symbol} ({data_source})")
return None
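Fundamentals follow the same cache-first pattern; since today's date is baked into the key, entries naturally roll over daily. A sketch with a made-up payload:

    from tradingagents.dataflows.cache_manager import get_cache

    cache = get_cache()
    key = cache.find_cached_fundamentals_data('AAPL', data_source='finnhub')
    if key is None:
        report = 'P/E 29.4, market cap ...'  # made-up payload
        key = cache.save_fundamentals_data('AAPL', report, data_source='finnhub')
    print(cache.load_fundamentals_data(key))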
-
- def is_cache_valid(self, cache_key: str, symbol: str = None, data_type: str = "stock_data") -> bool:
- """
- Check if cache is still valid
-
- Args:
- cache_key: Cache key
- symbol: Stock symbol (for market type determination)
- data_type: Data type
-
- Returns:
- True if cache is valid, False otherwise
- """
- try:
- # Load metadata
- metadata = self._load_metadata(cache_key)
- if not metadata:
- return False
-
- # Check if file exists
- cache_path = Path(metadata["file_path"])
- if not cache_path.exists():
- return False
-
- # Determine market type and get TTL
- if symbol:
- market_type = self._determine_market_type(symbol)
- else:
- market_type = metadata.get("market_type", "us")
-
- cache_type_key = f"{market_type}_{data_type}"
- if cache_type_key not in self.cache_config:
- cache_type_key = "us_stock_data" # Default fallback
-
- ttl_hours = self.cache_config[cache_type_key]["ttl_hours"]
-
- # Check if cache has expired
- cache_time = datetime.fromisoformat(metadata["cache_time"])
- expiry_time = cache_time + timedelta(hours=ttl_hours)
-
- is_valid = datetime.now() < expiry_time
- if not is_valid:
- print(f"⏰ Cache expired: {cache_key} (cached at {cache_time}, TTL: {ttl_hours}h)")
-
- return is_valid
-
- except Exception as e:
- print(f"❌ Failed to check cache validity: {e}")
- return False
-
+
+ def clear_old_cache(self, max_age_days: int = 7):
+ """清理过期缓存"""
+ cutoff_time = datetime.now() - timedelta(days=max_age_days)
+ cleared_count = 0
+
+ for metadata_file in self.metadata_dir.glob("*_meta.json"):
+ try:
+ with open(metadata_file, 'r', encoding='utf-8') as f:
+ metadata = json.load(f)
+
+ cached_at = datetime.fromisoformat(metadata['cached_at'])
+ if cached_at < cutoff_time:
+ # Delete the data file
+ data_file = Path(metadata['file_path'])
+ if data_file.exists():
+ data_file.unlink()
+
+ # Delete the metadata file
+ metadata_file.unlink()
+ cleared_count += 1
+
+ except Exception as e:
+ print(f"⚠️ 清理缓存时出错: {e}")
+
+ print(f"🧹 已清理 {cleared_count} 个过期缓存文件")
+
def get_cache_stats(self) -> Dict[str, Any]:
- """
- Get cache statistics
-
- Returns:
- Dictionary containing cache statistics
- """
- try:
- stats = {
- "cache_dir": str(self.cache_dir),
- "total_files": 0,
- "total_size_mb": 0,
- "stock_data_count": 0,
- "news_count": 0,
- "fundamentals_count": 0,
- "us_data_count": 0,
- "china_data_count": 0
- }
-
- # Count files in each directory
- for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
- self.china_news_dir, self.us_fundamentals_dir,
- self.china_fundamentals_dir, self.metadata_dir]:
- if dir_path.exists():
- files = list(dir_path.glob("*"))
- stats["total_files"] += len(files)
-
- # Calculate total size
- for file_path in files:
- if file_path.is_file():
- stats["total_size_mb"] += file_path.stat().st_size / (1024 * 1024)
-
- # Count by data type
- if self.us_stock_dir.exists():
- stats["stock_data_count"] += len(list(self.us_stock_dir.glob("*")))
- stats["us_data_count"] += len(list(self.us_stock_dir.glob("*")))
-
- if self.china_stock_dir.exists():
- stats["stock_data_count"] += len(list(self.china_stock_dir.glob("*")))
- stats["china_data_count"] += len(list(self.china_stock_dir.glob("*")))
-
- if self.us_news_dir.exists():
- stats["news_count"] += len(list(self.us_news_dir.glob("*")))
- stats["us_data_count"] += len(list(self.us_news_dir.glob("*")))
-
- if self.china_news_dir.exists():
- stats["news_count"] += len(list(self.china_news_dir.glob("*")))
- stats["china_data_count"] += len(list(self.china_news_dir.glob("*")))
-
- if self.us_fundamentals_dir.exists():
- stats["fundamentals_count"] += len(list(self.us_fundamentals_dir.glob("*")))
- stats["us_data_count"] += len(list(self.us_fundamentals_dir.glob("*")))
-
- if self.china_fundamentals_dir.exists():
- stats["fundamentals_count"] += len(list(self.china_fundamentals_dir.glob("*")))
- stats["china_data_count"] += len(list(self.china_fundamentals_dir.glob("*")))
-
- # Round size to 2 decimal places
- stats["total_size_mb"] = round(stats["total_size_mb"], 2)
-
- return stats
-
- except Exception as e:
- print(f"❌ Failed to get cache statistics: {e}")
- return {"error": str(e)}
-
- def cleanup_expired_cache(self):
- """Clean up expired cache files"""
- try:
- cleaned_count = 0
-
- # Check all metadata files
- if self.metadata_dir.exists():
- for metadata_file in self.metadata_dir.glob("*_meta.json"):
- try:
- cache_key = metadata_file.stem.replace("_meta", "")
-
- # Load metadata
- with open(metadata_file, 'r', encoding='utf-8') as f:
- metadata = json.load(f)
-
- # Check if cache is expired
- if not self.is_cache_valid(cache_key, metadata.get("symbol"), "stock_data"):
- # Remove cache file
- cache_path = Path(metadata["file_path"])
- if cache_path.exists():
- cache_path.unlink()
-
- # Remove metadata file
- metadata_file.unlink()
- cleaned_count += 1
- print(f"🗑️ Cleaned expired cache: {cache_key}")
-
- except Exception as e:
- print(f"⚠️ Failed to clean cache file {metadata_file}: {e}")
-
- print(f"✅ Cache cleanup completed, removed {cleaned_count} expired files")
-
- except Exception as e:
- print(f"❌ Failed to cleanup cache: {e}")
-
-
-# Global cache instance
-_global_cache = None
-
-def get_cache(cache_dir: str = None) -> StockDataCache:
- """
- Get global cache instance
-
- Args:
- cache_dir: Cache directory path
-
- Returns:
- StockDataCache instance
- """
- global _global_cache
- if _global_cache is None:
- _global_cache = StockDataCache(cache_dir)
- return _global_cache
-
-
-# Convenience functions
-def save_stock_data(symbol: str, data: Union[str, pd.DataFrame],
- start_date: str, end_date: str, data_source: str = "unknown") -> str:
- """Save stock data to cache (convenience function)"""
- cache = get_cache()
- return cache.save_stock_data(symbol, data, start_date, end_date, data_source)
-
-
-def load_stock_data(cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
- """Load stock data from cache (convenience function)"""
- cache = get_cache()
- return cache.load_stock_data(cache_key)
-
-
-def find_cached_stock_data(symbol: str, start_date: str, end_date: str,
- data_source: str = "unknown") -> Optional[str]:
- """Find cached stock data (convenience function)"""
- cache = get_cache()
- return cache.find_cached_stock_data(symbol, start_date, end_date, data_source)
-
-
-if __name__ == "__main__":
- # Test the cache manager
- print("🧪 Testing Stock Data Cache Manager...")
-
- # Initialize cache
- cache = StockDataCache()
-
- # Test data
- test_data = "Sample stock data for AAPL"
- cache_key = cache.save_stock_data("AAPL", test_data, "2024-01-01", "2024-01-31", "test")
-
- # Load data
- loaded_data = cache.load_stock_data(cache_key)
- print(f"Loaded data: {loaded_data}")
-
- # Check cache validity
- is_valid = cache.is_cache_valid(cache_key, "AAPL")
- print(f"Cache valid: {is_valid}")
-
- # Get statistics
- stats = cache.get_cache_stats()
- print(f"Cache stats: {stats}")
-
- print("✅ Cache manager test completed!")
+ """获取缓存统计信息"""
+ stats = {
+ 'total_files': 0,
+ 'stock_data_count': 0,
+ 'news_count': 0,
+ 'fundamentals_count': 0,
+ 'total_size_mb': 0
+ }
+
+ for metadata_file in self.metadata_dir.glob("*_meta.json"):
+ try:
+ with open(metadata_file, 'r', encoding='utf-8') as f:
+ metadata = json.load(f)
+
+ data_type = metadata.get('data_type', 'unknown')
+ if data_type == 'stock_data':
+ stats['stock_data_count'] += 1
+ elif data_type == 'news':
+ stats['news_count'] += 1
+ elif data_type == 'fundamentals':
+ stats['fundamentals_count'] += 1
+
+ # Accumulate file sizes
+ data_file = Path(metadata['file_path'])
+ if data_file.exists():
+ stats['total_size_mb'] += data_file.stat().st_size / (1024 * 1024)
+
+ stats['total_files'] += 1
+
+ except Exception:
+ continue
+
+ stats['total_size_mb'] = round(stats['total_size_mb'], 2)
+ return stats
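A quick way to eyeball cache growth using the stats keys defined above:

    from tradingagents.dataflows.cache_manager import get_cache

    stats = get_cache().get_cache_stats()
    print(f"{stats['total_files']} files, {stats['total_size_mb']} MB "
          f"({stats['stock_data_count']} stock / {stats['news_count']} news / "
          f"{stats['fundamentals_count']} fundamentals)")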
+
+
+# Global cache instance
+_cache_instance = None
+
+def get_cache() -> StockDataCache:
+ """获取全局缓存实例"""
+ global _cache_instance
+ if _cache_instance is None:
+ _cache_instance = StockDataCache()
+ return _cache_instance
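The Chinese version drops the cache_dir parameter from get_cache(), so every caller shares one StockDataCache rooted at the default data_cache directory:

    from tradingagents.dataflows.cache_manager import get_cache

    a = get_cache()
    b = get_cache()
    assert a is b  # one process-wide instance; repeated calls return the same object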