# File Diff Report
# Current file: tradingagents\dataflows\cache_manager.py
# Chinese version file: TradingAgentsCN\tradingagents\dataflows\cache_manager.py
# Generated: Sunday 2025/07/06
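
Both versions share the cache-key scheme that appears unchanged in the diff below: all call parameters are concatenated into one string, which is then shortened with an MD5 digest. A standalone sketch of that scheme for reference (the symbol and keyword arguments here are illustrative, not taken from the report):

    import hashlib

    def generate_cache_key(data_type: str, symbol: str, **kwargs) -> str:
        # Concatenate all parameters, then append a short MD5 digest,
        # mirroring StockDataCache._generate_cache_key in both versions
        params_str = f"{data_type}_{symbol}"
        for key, value in sorted(kwargs.items()):
            params_str += f"_{key}_{value}"
        digest = hashlib.md5(params_str.encode()).hexdigest()[:12]
        return f"{symbol}_{data_type}_{digest}"

    # e.g. "AAPL_stock_data_<12-char digest>"
    print(generate_cache_key("stock_data", "AAPL",
                             start_date="2024-01-01", end_date="2024-01-31",
                             source="yfinance", market="us"))
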
--- current/cache_manager.py
+++ chinese_version/cache_manager.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 """
-Stock Data Cache Manager
-Supports local caching of stock data to reduce API calls and improve response speed
+股票数据缓存管理器
+支持本地缓存股票数据,减少API调用,提高响应速度
 """

 import os
@@ -15,24 +15,24 @@

 class StockDataCache:
-    """Stock Data Cache Manager - Supports optimized caching for US and Chinese stock data"""
+    """股票数据缓存管理器 - 支持美股和A股数据缓存优化"""

     def __init__(self, cache_dir: str = None):
         """
-        Initialize cache manager
+        初始化缓存管理器

         Args:
-            cache_dir: Cache directory path, defaults to tradingagents/dataflows/data_cache
+            cache_dir: 缓存目录路径,默认为 tradingagents/dataflows/data_cache
         """
         if cache_dir is None:
-            # Get current file directory
+            # 获取当前文件所在目录
             current_dir = Path(__file__).parent
             cache_dir = current_dir / "data_cache"

         self.cache_dir = Path(cache_dir)
         self.cache_dir.mkdir(exist_ok=True)

-        # Create subdirectories - categorized by market
+        # 创建子目录 - 按市场分类
         self.us_stock_dir = self.cache_dir / "us_stocks"
         self.china_stock_dir = self.cache_dir / "china_stocks"
         self.us_news_dir = self.cache_dir / "us_news"
@@ -41,81 +41,81 @@
         self.china_fundamentals_dir = self.cache_dir / "china_fundamentals"
         self.metadata_dir = self.cache_dir / "metadata"

-        # Create all directories
+        # 创建所有目录
         for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
                          self.china_news_dir, self.us_fundamentals_dir,
                          self.china_fundamentals_dir, self.metadata_dir]:
             dir_path.mkdir(exist_ok=True)

-        # Cache configuration - different TTL settings for different markets
+        # 缓存配置 - 针对不同市场设置不同的TTL
         self.cache_config = {
             'us_stock_data': {
-                'ttl_hours': 2,  # US stock data cached for 2 hours (considering API limits)
+                'ttl_hours': 2,  # 美股数据缓存2小时(考虑到API限制)
                 'max_files': 1000,
-                'description': 'US stock historical data'
+                'description': '美股历史数据'
             },
             'china_stock_data': {
-                'ttl_hours': 1,  # A-share data cached for 1 hour (high real-time requirement)
+                'ttl_hours': 1,  # A股数据缓存1小时(实时性要求高)
                 'max_files': 1000,
-                'description': 'A-share historical data'
+                'description': 'A股历史数据'
             },
             'us_news': {
-                'ttl_hours': 6,  # US stock news cached for 6 hours
+                'ttl_hours': 6,  # 美股新闻缓存6小时
                 'max_files': 500,
-                'description': 'US stock news data'
+                'description': '美股新闻数据'
             },
             'china_news': {
-                'ttl_hours': 4,  # A-share news cached for 4 hours
+                'ttl_hours': 4,  # A股新闻缓存4小时
                 'max_files': 500,
-                'description': 'A-share news data'
+                'description': 'A股新闻数据'
             },
             'us_fundamentals': {
-                'ttl_hours': 24,  # US stock fundamentals cached for 24 hours
+                'ttl_hours': 24,  # 美股基本面数据缓存24小时
                 'max_files': 200,
-                'description': 'US stock fundamentals data'
+                'description': '美股基本面数据'
             },
             'china_fundamentals': {
-                'ttl_hours': 12,  # A-share fundamentals cached for 12 hours
+                'ttl_hours': 12,  # A股基本面数据缓存12小时
                 'max_files': 200,
-                'description': 'A-share fundamentals data'
+                'description': 'A股基本面数据'
             }
         }

-        print(f"📁 Cache manager initialized, cache directory: {self.cache_dir}")
-        print(f"🗄️ Database cache manager initialized")
-        print(f"   US stock data: ✅ Configured")
-        print(f"   A-share data: ✅ Configured")
+        print(f"📁 缓存管理器初始化完成,缓存目录: {self.cache_dir}")
+        print(f"🗄️ 数据库缓存管理器初始化完成")
+        print(f"   美股数据: ✅ 已配置")
+        print(f"   A股数据: ✅ 已配置")

     def _determine_market_type(self, symbol: str) -> str:
-        """Determine market type based on stock symbol"""
+        """根据股票代码确定市场类型"""
         import re

-        # Check if it's Chinese A-share (6-digit number)
+        # 判断是否为中国A股(6位数字)
         if re.match(r'^\d{6}$', str(symbol)):
             return 'china'
         else:
             return 'us'

     def _generate_cache_key(self, data_type: str, symbol: str, **kwargs) -> str:
-        """Generate cache key"""
-        # Create a string containing all parameters
+        """生成缓存键"""
+        # 创建一个包含所有参数的字符串
         params_str = f"{data_type}_{symbol}"
         for key, value in sorted(kwargs.items()):
             params_str += f"_{key}_{value}"

-        # Use MD5 to generate short unique identifier
+        # 使用MD5生成短的唯一标识
         cache_key = hashlib.md5(params_str.encode()).hexdigest()[:12]
         return f"{symbol}_{data_type}_{cache_key}"

     def _get_cache_path(self, data_type: str, cache_key: str, file_format: str = "json", symbol: str = None) -> Path:
-        """Get cache file path - supports market classification"""
+        """获取缓存文件路径 - 支持市场分类"""
         if symbol:
             market_type = self._determine_market_type(symbol)
         else:
-            # Try to extract market type from cache key
+            # 从缓存键中尝试提取市场类型
             market_type = 'us' if not cache_key.startswith(('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')) else 'china'

-        # Select directory based on data type and market type
+        # 根据数据类型和市场类型选择目录
         if data_type == "stock_data":
             base_dir = self.china_stock_dir if market_type == 'china' else self.us_stock_dir
         elif data_type == "news":
@@ -128,20 +128,19 @@
         return base_dir / f"{cache_key}.{file_format}"

     def _get_metadata_path(self, cache_key: str) -> Path:
-        """Get metadata file path"""
+        """获取元数据文件路径"""
         return self.metadata_dir / f"{cache_key}_meta.json"

     def _save_metadata(self, cache_key: str, metadata: Dict[str, Any]):
-        """Save cache metadata"""
+        """保存元数据"""
         metadata_path = self._get_metadata_path(cache_key)
-        try:
-            with open(metadata_path, 'w', encoding='utf-8') as f:
-                json.dump(metadata, f, ensure_ascii=False, indent=2, default=str)
-        except Exception as e:
-            print(f"⚠️ Failed to save metadata: {e}")
+        metadata['cached_at'] = datetime.now().isoformat()
+
+        with open(metadata_path, 'w', encoding='utf-8') as f:
+            json.dump(metadata, f, ensure_ascii=False, indent=2)

     def _load_metadata(self, cache_key: str) -> Optional[Dict[str, Any]]:
-        """Load cache metadata"""
+        """加载元数据"""
         metadata_path = self._get_metadata_path(cache_key)
         if not metadata_path.exists():
             return None
@@ -150,349 +149,355 @@
             with open(metadata_path, 'r', encoding='utf-8') as f:
                 return json.load(f)
         except Exception as e:
-            print(f"⚠️ Failed to load metadata: {e}")
-            return None
-
-    def save_stock_data(self, symbol: str, data: Union[str, pd.DataFrame],
-                        start_date: str, end_date: str, data_source: str = "unknown") -> str:
-        """
-        Save stock data to cache
+            print(f"⚠️ 加载元数据失败: {e}")
+            return None
+
+    def is_cache_valid(self, cache_key: str, max_age_hours: int = None, symbol: str = None, data_type: str = None) -> bool:
+        """检查缓存是否有效 - 支持智能TTL配置"""
+        metadata = self._load_metadata(cache_key)
+        if not metadata:
+            return False
+
+        # 如果没有指定TTL,根据数据类型和市场自动确定
+        if max_age_hours is None:
+            if symbol and data_type:
+                market_type = self._determine_market_type(symbol)
+                cache_type = f"{market_type}_{data_type}"
+                max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
+            else:
+                # 从元数据中获取信息
+                symbol = metadata.get('symbol', '')
+                data_type = metadata.get('data_type', 'stock_data')
+                market_type = self._determine_market_type(symbol)
+                cache_type = f"{market_type}_{data_type}"
+                max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
+
+        cached_at = datetime.fromisoformat(metadata['cached_at'])
+        age = datetime.now() - cached_at
+
+        is_valid = age.total_seconds() < max_age_hours * 3600
+
+        if is_valid:
+            market_type = self._determine_market_type(metadata.get('symbol', ''))
+            cache_type = f"{market_type}_{metadata.get('data_type', 'stock_data')}"
+            desc = self.cache_config.get(cache_type, {}).get('description', '数据')
+            print(f"✅ 缓存有效: {desc} - {metadata.get('symbol')} (剩余 {max_age_hours - age.total_seconds()/3600:.1f}h)")
+
+        return is_valid
+
+    def save_stock_data(self, symbol: str, data: Union[pd.DataFrame, str],
+                        start_date: str = None, end_date: str = None,
+                        data_source: str = "unknown") -> str:
+        """
+        保存股票数据到缓存 - 支持美股和A股分类存储

         Args:
-            symbol: Stock symbol
-            data: Stock data (string or DataFrame)
-            start_date: Start date
-            end_date: End date
-            data_source: Data source name
+            symbol: 股票代码
+            data: 股票数据(DataFrame或字符串)
+            start_date: 开始日期
+            end_date: 结束日期
+            data_source: 数据源(如 "tdx", "yfinance", "finnhub")

         Returns:
-            Cache key
-        """
+            cache_key: 缓存键
+        """
+        market_type = self._determine_market_type(symbol)
+        cache_key = self._generate_cache_key("stock_data", symbol,
+                                             start_date=start_date,
+                                             end_date=end_date,
+                                             source=data_source,
+                                             market=market_type)
+
+        # 保存数据
+        if isinstance(data, pd.DataFrame):
+            cache_path = self._get_cache_path("stock_data", cache_key, "csv", symbol)
+            data.to_csv(cache_path, index=True)
+        else:
+            cache_path = self._get_cache_path("stock_data", cache_key, "txt", symbol)
+            with open(cache_path, 'w', encoding='utf-8') as f:
+                f.write(str(data))
+
+        # 保存元数据
+        metadata = {
+            'symbol': symbol,
+            'data_type': 'stock_data',
+            'market_type': market_type,
+            'start_date': start_date,
+            'end_date': end_date,
+            'data_source': data_source,
+            'file_path': str(cache_path),
+            'file_format': 'csv' if isinstance(data, pd.DataFrame) else 'txt'
+        }
+        self._save_metadata(cache_key, metadata)
+
+        # 获取描述信息
+        cache_type = f"{market_type}_stock_data"
+        desc = self.cache_config.get(cache_type, {}).get('description', '股票数据')
+        print(f"💾 {desc}已缓存: {symbol} ({data_source}) -> {cache_key}")
+        return cache_key
+
+    def load_stock_data(self, cache_key: str) -> Optional[Union[pd.DataFrame, str]]:
+        """从缓存加载股票数据"""
+        metadata = self._load_metadata(cache_key)
+        if not metadata:
+            return None
+
+        cache_path = Path(metadata['file_path'])
+        if not cache_path.exists():
+            return None
+
         try:
-            # Generate cache key
-            cache_key = self._generate_cache_key(
-                "stock_data", symbol,
-                start_date=start_date,
-                end_date=end_date,
-                data_source=data_source
-            )
-
-            # Determine file format and save data
-            if isinstance(data, pd.DataFrame):
-                # Save DataFrame as pickle for better performance
-                cache_path = self._get_cache_path("stock_data", cache_key, "pkl", symbol)
-                data.to_pickle(cache_path)
-                data_type = "dataframe"
-            else:
-                # Save string data as JSON
-                cache_path = self._get_cache_path("stock_data", cache_key, "json", symbol)
-                with open(cache_path, 'w', encoding='utf-8') as f:
-                    json.dump({"data": data}, f, ensure_ascii=False, indent=2)
-                data_type = "string"
-
-            # Save metadata
-            market_type = self._determine_market_type(symbol)
-            metadata = {
-                "symbol": symbol,
-                "data_type": data_type,
-                "start_date": start_date,
-                "end_date": end_date,
-                "data_source": data_source,
-                "market_type": market_type,
-                "cache_time": datetime.now().isoformat(),
-                "file_path": str(cache_path),
-                "cache_key": cache_key
-            }
-            self._save_metadata(cache_key, metadata)
-
-            print(f"💾 Stock data cached: {symbol} ({market_type.upper()}) -> {cache_key}")
-            return cache_key
-
-        except Exception as e:
-            print(f"❌ Failed to save stock data cache: {e}")
-            return None
-
-    def load_stock_data(self, cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
-        """
-        Load stock data from cache
-
-        Args:
-            cache_key: Cache key
-
-        Returns:
-            Stock data or None if not found
-        """
-        try:
-            # Load metadata
-            metadata = self._load_metadata(cache_key)
-            if not metadata:
-                print(f"⚠️ Cache metadata not found: {cache_key}")
-                return None
-
-            # Get file path
-            cache_path = Path(metadata["file_path"])
-            if not cache_path.exists():
-                print(f"⚠️ Cache file not found: {cache_path}")
-                return None
-
-            # Load data based on type
-            if metadata["data_type"] == "dataframe":
-                data = pd.read_pickle(cache_path)
+            if metadata['file_format'] == 'csv':
+                return pd.read_csv(cache_path, index_col=0)
             else:
                 with open(cache_path, 'r', encoding='utf-8') as f:
-                    json_data = json.load(f)
-                    data = json_data["data"]
-
-            print(f"📖 Stock data loaded from cache: {metadata['symbol']} -> {cache_key}")
-            return data
-
+                    return f.read()
         except Exception as e:
-            print(f"❌ Failed to load stock data from cache: {e}")
-            return None
-
-    def find_cached_stock_data(self, symbol: str, start_date: str, end_date: str,
-                               data_source: str = "unknown") -> Optional[str]:
-        """
-        Find cached stock data
+            print(f"⚠️ 加载缓存数据失败: {e}")
+            return None
+
+    def find_cached_stock_data(self, symbol: str, start_date: str = None,
+                               end_date: str = None, data_source: str = None,
+                               max_age_hours: int = None) -> Optional[str]:
+        """
+        查找匹配的缓存数据 - 支持智能市场分类查找

         Args:
-            symbol: Stock symbol
-            start_date: Start date
-            end_date: End date
-            data_source: Data source name
+            symbol: 股票代码
+            start_date: 开始日期
+            end_date: 结束日期
+            data_source: 数据源
+            max_age_hours: 最大缓存时间(小时),None时使用智能配置

         Returns:
-            Cache key if found, None otherwise
-        """
-        # Generate expected cache key
-        cache_key = self._generate_cache_key(
-            "stock_data", symbol,
-            start_date=start_date,
-            end_date=end_date,
-            data_source=data_source
-        )
-
-        # Check if metadata exists
+            cache_key: 如果找到有效缓存则返回缓存键,否则返回None
+        """
+        market_type = self._determine_market_type(symbol)
+
+        # 如果没有指定TTL,使用智能配置
+        if max_age_hours is None:
+            cache_type = f"{market_type}_stock_data"
+            max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
+
+        # 生成查找键
+        search_key = self._generate_cache_key("stock_data", symbol,
+                                              start_date=start_date,
+                                              end_date=end_date,
+                                              source=data_source,
+                                              market=market_type)
+
+        # 检查精确匹配
+        if self.is_cache_valid(search_key, max_age_hours, symbol, 'stock_data'):
+            desc = self.cache_config.get(f"{market_type}_stock_data", {}).get('description', '数据')
+            print(f"🎯 找到精确匹配的{desc}: {symbol} -> {search_key}")
+            return search_key
+
+        # 如果没有精确匹配,查找部分匹配(相同股票代码的其他缓存)
+        for metadata_file in self.metadata_dir.glob(f"*_meta.json"):
+            try:
+                with open(metadata_file, 'r', encoding='utf-8') as f:
+                    metadata = json.load(f)
+
+                if (metadata.get('symbol') == symbol and
+                    metadata.get('data_type') == 'stock_data' and
+                    metadata.get('market_type') == market_type and
+                    (data_source is None or metadata.get('data_source') == data_source)):
+
+                    cache_key = metadata_file.stem.replace('_meta', '')
+                    if self.is_cache_valid(cache_key, max_age_hours, symbol, 'stock_data'):
+                        desc = self.cache_config.get(f"{market_type}_stock_data", {}).get('description', '数据')
+                        print(f"📋 找到部分匹配的{desc}: {symbol} -> {cache_key}")
+                        return cache_key
+            except Exception:
+                continue
+
+        desc = self.cache_config.get(f"{market_type}_stock_data", {}).get('description', '数据')
+        print(f"❌ 未找到有效的{desc}缓存: {symbol}")
+        return None
+
+    def save_news_data(self, symbol: str, news_data: str,
+                       start_date: str = None, end_date: str = None,
+                       data_source: str = "unknown") -> str:
+        """保存新闻数据到缓存"""
+        cache_key = self._generate_cache_key("news", symbol,
+                                             start_date=start_date,
+                                             end_date=end_date,
+                                             source=data_source)
+
+        cache_path = self._get_cache_path("news", cache_key, "txt")
+        with open(cache_path, 'w', encoding='utf-8') as f:
+            f.write(news_data)
+
+        metadata = {
+            'symbol': symbol,
+            'data_type': 'news',
+            'start_date': start_date,
+            'end_date': end_date,
+            'data_source': data_source,
+            'file_path': str(cache_path),
+            'file_format': 'txt'
+        }
+        self._save_metadata(cache_key, metadata)
+
+        print(f"📰 新闻数据已缓存: {symbol} ({data_source}) -> {cache_key}")
+        return cache_key
+
+    def save_fundamentals_data(self, symbol: str, fundamentals_data: str,
+                               data_source: str = "unknown") -> str:
+        """保存基本面数据到缓存"""
+        market_type = self._determine_market_type(symbol)
+        cache_key = self._generate_cache_key("fundamentals", symbol,
+                                             source=data_source,
+                                             market=market_type,
+                                             date=datetime.now().strftime("%Y-%m-%d"))
+
+        cache_path = self._get_cache_path("fundamentals", cache_key, "txt", symbol)
+        with open(cache_path, 'w', encoding='utf-8') as f:
+            f.write(fundamentals_data)
+
+        metadata = {
+            'symbol': symbol,
+            'data_type': 'fundamentals',
+            'data_source': data_source,
+            'market_type': market_type,
+            'file_path': str(cache_path),
+            'file_format': 'txt'
+        }
+        self._save_metadata(cache_key, metadata)
+
+        desc = self.cache_config.get(f"{market_type}_fundamentals", {}).get('description', '基本面数据')
+        print(f"💼 {desc}已缓存: {symbol} ({data_source}) -> {cache_key}")
+        return cache_key
+
+    def load_fundamentals_data(self, cache_key: str) -> Optional[str]:
+        """从缓存加载基本面数据"""
         metadata = self._load_metadata(cache_key)
-        if metadata:
-            cache_path = Path(metadata["file_path"])
-            if cache_path.exists():
-                return cache_key
-
+        if not metadata:
+            return None
+
+        cache_path = Path(metadata['file_path'])
+        if not cache_path.exists():
+            return None
+
+        try:
+            with open(cache_path, 'r', encoding='utf-8') as f:
+                return f.read()
+        except Exception as e:
+            print(f"⚠️ 加载基本面缓存数据失败: {e}")
+            return None
+
+    def find_cached_fundamentals_data(self, symbol: str, data_source: str = None,
+                                      max_age_hours: int = None) -> Optional[str]:
+        """
+        查找匹配的基本面缓存数据
+
+        Args:
+            symbol: 股票代码
+            data_source: 数据源(如 "openai", "finnhub")
+            max_age_hours: 最大缓存时间(小时),None时使用智能配置
+
+        Returns:
+            cache_key: 如果找到有效缓存则返回缓存键,否则返回None
+        """
+        market_type = self._determine_market_type(symbol)
+
+        # 如果没有指定TTL,使用智能配置
+        if max_age_hours is None:
+            cache_type = f"{market_type}_fundamentals"
+            max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
+
+        # 查找匹配的缓存
+        for metadata_file in self.metadata_dir.glob(f"*_meta.json"):
+            try:
+                with open(metadata_file, 'r', encoding='utf-8') as f:
+                    metadata = json.load(f)
+
+                if (metadata.get('symbol') == symbol and
+                    metadata.get('data_type') == 'fundamentals' and
+                    metadata.get('market_type') == market_type and
+                    (data_source is None or metadata.get('data_source') == data_source)):
+
+                    cache_key = metadata_file.stem.replace('_meta', '')
+                    if self.is_cache_valid(cache_key, max_age_hours, symbol, 'fundamentals'):
+                        desc = self.cache_config.get(f"{market_type}_fundamentals", {}).get('description', '基本面数据')
+                        print(f"🎯 找到匹配的{desc}缓存: {symbol} ({data_source}) -> {cache_key}")
+                        return cache_key
+            except Exception:
+                continue
+
+        desc = self.cache_config.get(f"{market_type}_fundamentals", {}).get('description', '基本面数据')
+        print(f"❌ 未找到有效的{desc}缓存: {symbol} ({data_source})")
         return None
-
-    def is_cache_valid(self, cache_key: str, symbol: str = None, data_type: str = "stock_data") -> bool:
-        """
-        Check if cache is still valid
-
-        Args:
-            cache_key: Cache key
-            symbol: Stock symbol (for market type determination)
-            data_type: Data type
-
-        Returns:
-            True if cache is valid, False otherwise
-        """
-        try:
-            # Load metadata
-            metadata = self._load_metadata(cache_key)
-            if not metadata:
-                return False
-
-            # Check if file exists
-            cache_path = Path(metadata["file_path"])
-            if not cache_path.exists():
-                return False
-
-            # Determine market type and get TTL
-            if symbol:
-                market_type = self._determine_market_type(symbol)
-            else:
-                market_type = metadata.get("market_type", "us")
-
-            cache_type_key = f"{market_type}_{data_type}"
-            if cache_type_key not in self.cache_config:
-                cache_type_key = "us_stock_data"  # Default fallback
-
-            ttl_hours = self.cache_config[cache_type_key]["ttl_hours"]
-
-            # Check if cache has expired
-            cache_time = datetime.fromisoformat(metadata["cache_time"])
-            expiry_time = cache_time + timedelta(hours=ttl_hours)
-
-            is_valid = datetime.now() < expiry_time
-            if not is_valid:
-                print(f"⏰ Cache expired: {cache_key} (cached at {cache_time}, TTL: {ttl_hours}h)")
-
-            return is_valid
-
-        except Exception as e:
-            print(f"❌ Failed to check cache validity: {e}")
-            return False
-
+
+    def clear_old_cache(self, max_age_days: int = 7):
+        """清理过期缓存"""
+        cutoff_time = datetime.now() - timedelta(days=max_age_days)
+        cleared_count = 0
+
+        for metadata_file in self.metadata_dir.glob("*_meta.json"):
+            try:
+                with open(metadata_file, 'r', encoding='utf-8') as f:
+                    metadata = json.load(f)
+
+                cached_at = datetime.fromisoformat(metadata['cached_at'])
+                if cached_at < cutoff_time:
+                    # 删除数据文件
+                    data_file = Path(metadata['file_path'])
+                    if data_file.exists():
+                        data_file.unlink()
+
+                    # 删除元数据文件
+                    metadata_file.unlink()
+                    cleared_count += 1
+
+            except Exception as e:
+                print(f"⚠️ 清理缓存时出错: {e}")
+
+        print(f"🧹 已清理 {cleared_count} 个过期缓存文件")
+
     def get_cache_stats(self) -> Dict[str, Any]:
-        """
-        Get cache statistics
-
-        Returns:
-            Dictionary containing cache statistics
-        """
-        try:
-            stats = {
-                "cache_dir": str(self.cache_dir),
-                "total_files": 0,
-                "total_size_mb": 0,
-                "stock_data_count": 0,
-                "news_count": 0,
-                "fundamentals_count": 0,
-                "us_data_count": 0,
-                "china_data_count": 0
-            }
-
-            # Count files in each directory
-            for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
-                             self.china_news_dir, self.us_fundamentals_dir,
-                             self.china_fundamentals_dir, self.metadata_dir]:
-                if dir_path.exists():
-                    files = list(dir_path.glob("*"))
-                    stats["total_files"] += len(files)
-
-                    # Calculate total size
-                    for file_path in files:
-                        if file_path.is_file():
-                            stats["total_size_mb"] += file_path.stat().st_size / (1024 * 1024)
-
-            # Count by data type
-            if self.us_stock_dir.exists():
-                stats["stock_data_count"] += len(list(self.us_stock_dir.glob("*")))
-                stats["us_data_count"] += len(list(self.us_stock_dir.glob("*")))
-
-            if self.china_stock_dir.exists():
-                stats["stock_data_count"] += len(list(self.china_stock_dir.glob("*")))
-                stats["china_data_count"] += len(list(self.china_stock_dir.glob("*")))
-
-            if self.us_news_dir.exists():
-                stats["news_count"] += len(list(self.us_news_dir.glob("*")))
-                stats["us_data_count"] += len(list(self.us_news_dir.glob("*")))
-
-            if self.china_news_dir.exists():
-                stats["news_count"] += len(list(self.china_news_dir.glob("*")))
-                stats["china_data_count"] += len(list(self.china_news_dir.glob("*")))
-
-            if self.us_fundamentals_dir.exists():
-                stats["fundamentals_count"] += len(list(self.us_fundamentals_dir.glob("*")))
-                stats["us_data_count"] += len(list(self.us_fundamentals_dir.glob("*")))
-
-            if self.china_fundamentals_dir.exists():
-                stats["fundamentals_count"] += len(list(self.china_fundamentals_dir.glob("*")))
-                stats["china_data_count"] += len(list(self.china_fundamentals_dir.glob("*")))
-
-            # Round size to 2 decimal places
-            stats["total_size_mb"] = round(stats["total_size_mb"], 2)
-
-            return stats
-
-        except Exception as e:
-            print(f"❌ Failed to get cache statistics: {e}")
-            return {"error": str(e)}
-
-    def cleanup_expired_cache(self):
-        """Clean up expired cache files"""
-        try:
-            cleaned_count = 0
-
-            # Check all metadata files
-            if self.metadata_dir.exists():
-                for metadata_file in self.metadata_dir.glob("*_meta.json"):
-                    try:
-                        cache_key = metadata_file.stem.replace("_meta", "")
-
-                        # Load metadata
-                        with open(metadata_file, 'r', encoding='utf-8') as f:
-                            metadata = json.load(f)
-
-                        # Check if cache is expired
-                        if not self.is_cache_valid(cache_key, metadata.get("symbol"), "stock_data"):
-                            # Remove cache file
-                            cache_path = Path(metadata["file_path"])
-                            if cache_path.exists():
-                                cache_path.unlink()
-
-                            # Remove metadata file
-                            metadata_file.unlink()
-                            cleaned_count += 1
-                            print(f"🗑️ Cleaned expired cache: {cache_key}")
-
-                    except Exception as e:
-                        print(f"⚠️ Failed to clean cache file {metadata_file}: {e}")
-
-            print(f"✅ Cache cleanup completed, removed {cleaned_count} expired files")
-
-        except Exception as e:
-            print(f"❌ Failed to cleanup cache: {e}")
-
-
-# Global cache instance
-_global_cache = None
-
-def get_cache(cache_dir: str = None) -> StockDataCache:
-    """
-    Get global cache instance
-
-    Args:
-        cache_dir: Cache directory path
-
-    Returns:
-        StockDataCache instance
-    """
-    global _global_cache
-    if _global_cache is None:
-        _global_cache = StockDataCache(cache_dir)
-    return _global_cache
-
-
-# Convenience functions
-def save_stock_data(symbol: str, data: Union[str, pd.DataFrame],
-                    start_date: str, end_date: str, data_source: str = "unknown") -> str:
-    """Save stock data to cache (convenience function)"""
-    cache = get_cache()
-    return cache.save_stock_data(symbol, data, start_date, end_date, data_source)
-
-
-def load_stock_data(cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
-    """Load stock data from cache (convenience function)"""
-    cache = get_cache()
-    return cache.load_stock_data(cache_key)
-
-
-def find_cached_stock_data(symbol: str, start_date: str, end_date: str,
-                           data_source: str = "unknown") -> Optional[str]:
-    """Find cached stock data (convenience function)"""
-    cache = get_cache()
-    return cache.find_cached_stock_data(symbol, start_date, end_date, data_source)
-
-
-if __name__ == "__main__":
-    # Test the cache manager
-    print("🧪 Testing Stock Data Cache Manager...")
-
-    # Initialize cache
-    cache = StockDataCache()
-
-    # Test data
-    test_data = "Sample stock data for AAPL"
-    cache_key = cache.save_stock_data("AAPL", test_data, "2024-01-01", "2024-01-31", "test")
-
-    # Load data
-    loaded_data = cache.load_stock_data(cache_key)
-    print(f"Loaded data: {loaded_data}")
-
-    # Check cache validity
-    is_valid = cache.is_cache_valid(cache_key, "AAPL")
-    print(f"Cache valid: {is_valid}")
-
-    # Get statistics
-    stats = cache.get_cache_stats()
-    print(f"Cache stats: {stats}")
-
-    print("✅ Cache manager test completed!")
+ """获取缓存统计信息"""
|
||
+ stats = {
|
||
+ 'total_files': 0,
|
||
+ 'stock_data_count': 0,
|
||
+ 'news_count': 0,
|
||
+ 'fundamentals_count': 0,
|
||
+ 'total_size_mb': 0
|
||
+ }
|
||
+
|
||
+ for metadata_file in self.metadata_dir.glob("*_meta.json"):
|
||
+ try:
|
||
+ with open(metadata_file, 'r', encoding='utf-8') as f:
|
||
+ metadata = json.load(f)
|
||
+
|
||
+ data_type = metadata.get('data_type', 'unknown')
|
||
+ if data_type == 'stock_data':
|
||
+ stats['stock_data_count'] += 1
|
||
+ elif data_type == 'news':
|
||
+ stats['news_count'] += 1
|
||
+ elif data_type == 'fundamentals':
|
||
+ stats['fundamentals_count'] += 1
|
||
+
|
||
+ # 计算文件大小
|
||
+ data_file = Path(metadata['file_path'])
|
||
+ if data_file.exists():
|
||
+ stats['total_size_mb'] += data_file.stat().st_size / (1024 * 1024)
|
||
+
|
||
+ stats['total_files'] += 1
|
||
+
|
||
+ except Exception:
|
||
+ continue
|
||
+
|
||
+ stats['total_size_mb'] = round(stats['total_size_mb'], 2)
|
||
+ return stats
|
||
+
|
||
+
|
||
+# 全局缓存实例
|
||
+_cache_instance = None
|
||
+
|
||
+def get_cache() -> StockDataCache:
|
||
+ """获取全局缓存实例"""
|
||
+ global _cache_instance
|
||
+ if _cache_instance is None:
|
||
+ _cache_instance = StockDataCache()
|
||
+ return _cache_instance
|
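
Usage note: the Chinese version keeps the same public entry point (get_cache) but drops the module-level convenience wrappers and the cache_dir argument. A minimal usage sketch based on the signatures added in this diff (the symbol, dates, and data source below are illustrative, not taken from the report):

    from tradingagents.dataflows.cache_manager import get_cache

    cache = get_cache()

    # Six-digit symbols are classified as A-shares and stored under china_stocks/
    key = cache.save_stock_data("000001", "sample price text",
                                start_date="2024-01-01", end_date="2024-01-31",
                                data_source="tdx")

    # With max_age_hours=None, the lookup falls back to the per-market TTL
    # config (1h for A-share data, 2h for US stock data)
    hit = cache.find_cached_stock_data("000001", start_date="2024-01-01",
                                       end_date="2024-01-31", data_source="tdx")
    if hit:
        data = cache.load_stock_data(hit)  # str for txt caches, DataFrame for csv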