516 lines
19 KiB
Python
516 lines
19 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Stock Data Cache Manager
|
|
Supports local caching of stock data to reduce API calls and improve response speed
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import pickle
|
|
import pandas as pd
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import Optional, Dict, Any, Union
|
|
import hashlib
|
|
|
|
|
|
class StockDataCache:
|
|
"""Stock Data Cache Manager - Supports optimized caching for US and Chinese stock data"""
|
|
|
|
def __init__(self, cache_dir: str = None):
|
|
"""
|
|
Initialize cache manager
|
|
|
|
Args:
|
|
cache_dir: Cache directory path, defaults to tradingagents/dataflows/data_cache
|
|
"""
|
|
if cache_dir is None:
|
|
# Get current file directory
|
|
current_dir = Path(__file__).parent
|
|
cache_dir = current_dir / "data_cache"
|
|
|
|
self.cache_dir = Path(cache_dir)
|
|
self.cache_dir.mkdir(exist_ok=True)
|
|
|
|
# Create subdirectories - categorized by market
|
|
self.us_stock_dir = self.cache_dir / "us_stocks"
|
|
self.china_stock_dir = self.cache_dir / "china_stocks"
|
|
self.us_news_dir = self.cache_dir / "us_news"
|
|
self.china_news_dir = self.cache_dir / "china_news"
|
|
self.us_fundamentals_dir = self.cache_dir / "us_fundamentals"
|
|
self.china_fundamentals_dir = self.cache_dir / "china_fundamentals"
|
|
self.metadata_dir = self.cache_dir / "metadata"
|
|
|
|
# Create all directories
|
|
for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
|
|
self.china_news_dir, self.us_fundamentals_dir,
|
|
self.china_fundamentals_dir, self.metadata_dir]:
|
|
dir_path.mkdir(exist_ok=True)
|
|
|
|
# Cache configuration - different TTL settings for different markets
|
|
self.cache_config = {
|
|
'us_stock_data': {
|
|
'ttl_hours': 2, # US stock data cached for 2 hours (considering API limits)
|
|
'max_files': 1000,
|
|
'description': 'US stock historical data'
|
|
},
|
|
'china_stock_data': {
|
|
'ttl_hours': 1, # A-share data cached for 1 hour (high real-time requirement)
|
|
'max_files': 1000,
|
|
'description': 'A-share historical data'
|
|
},
|
|
'us_news': {
|
|
'ttl_hours': 6, # US stock news cached for 6 hours
|
|
'max_files': 500,
|
|
'description': 'US stock news data'
|
|
},
|
|
'china_news': {
|
|
'ttl_hours': 4, # A-share news cached for 4 hours
|
|
'max_files': 500,
|
|
'description': 'A-share news data'
|
|
},
|
|
'us_fundamentals': {
|
|
'ttl_hours': 24, # US stock fundamentals cached for 24 hours
|
|
'max_files': 200,
|
|
'description': 'US stock fundamentals data'
|
|
},
|
|
'china_fundamentals': {
|
|
'ttl_hours': 12, # A-share fundamentals cached for 12 hours
|
|
'max_files': 200,
|
|
'description': 'A-share fundamentals data'
|
|
}
|
|
}
|
|
|
|
print(f"📁 Cache manager initialized, cache directory: {self.cache_dir}")
|
|
print(f"🗄️ Database cache manager initialized")
|
|
print(f" US stock data: ✅ Configured")
|
|
print(f" A-share data: ✅ Configured")
|
|
|
|
def _determine_market_type(self, symbol: str) -> str:
|
|
"""Determine market type based on stock symbol"""
|
|
import re
|
|
|
|
# Check if it's Chinese A-share (6-digit number)
|
|
if re.match(r'^\d{6}$', str(symbol)):
|
|
return 'china'
|
|
else:
|
|
return 'us'
|
|
|
|
def _generate_cache_key(self, data_type: str, symbol: str, **kwargs) -> str:
|
|
"""Generate cache key"""
|
|
# Create a string containing all parameters
|
|
params_str = f"{data_type}_{symbol}"
|
|
for key, value in sorted(kwargs.items()):
|
|
params_str += f"_{key}_{value}"
|
|
|
|
# Use MD5 to generate short unique identifier
|
|
cache_key = hashlib.md5(params_str.encode()).hexdigest()[:12]
|
|
return f"{symbol}_{data_type}_{cache_key}"
|
|
|
|
def _get_cache_path(self, data_type: str, cache_key: str, file_format: str = "json", symbol: str = None) -> Path:
|
|
"""Get cache file path - supports market classification"""
|
|
if symbol:
|
|
market_type = self._determine_market_type(symbol)
|
|
else:
|
|
# Try to extract market type from cache key
|
|
market_type = 'us' if not cache_key.startswith(('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')) else 'china'
|
|
|
|
# Select directory based on data type and market type
|
|
if data_type == "stock_data":
|
|
base_dir = self.china_stock_dir if market_type == 'china' else self.us_stock_dir
|
|
elif data_type == "news":
|
|
base_dir = self.china_news_dir if market_type == 'china' else self.us_news_dir
|
|
elif data_type == "fundamentals":
|
|
base_dir = self.china_fundamentals_dir if market_type == 'china' else self.us_fundamentals_dir
|
|
else:
|
|
base_dir = self.cache_dir
|
|
|
|
return base_dir / f"{cache_key}.{file_format}"
|
|
|
|
def _get_metadata_path(self, cache_key: str) -> Path:
|
|
"""Get metadata file path"""
|
|
return self.metadata_dir / f"{cache_key}_meta.json"
|
|
|
|
def _save_metadata(self, cache_key: str, metadata: Dict[str, Any]):
|
|
"""Save cache metadata"""
|
|
metadata_path = self._get_metadata_path(cache_key)
|
|
try:
|
|
with open(metadata_path, 'w', encoding='utf-8') as f:
|
|
json.dump(metadata, f, ensure_ascii=False, indent=2, default=str)
|
|
except Exception as e:
|
|
print(f"⚠️ Failed to save metadata: {e}")
|
|
|
|
def _load_metadata(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
|
"""Load cache metadata"""
|
|
metadata_path = self._get_metadata_path(cache_key)
|
|
if not metadata_path.exists():
|
|
return None
|
|
|
|
try:
|
|
with open(metadata_path, 'r', encoding='utf-8') as f:
|
|
return json.load(f)
|
|
except Exception as e:
|
|
print(f"⚠️ Failed to load metadata: {e}")
|
|
return None
|
|
|
|
def save_stock_data(self, symbol: str, data: Union[str, pd.DataFrame],
|
|
start_date: str, end_date: str, data_source: str = "unknown") -> str:
|
|
"""
|
|
Save stock data to cache
|
|
|
|
Args:
|
|
symbol: Stock symbol
|
|
data: Stock data (string or DataFrame)
|
|
start_date: Start date
|
|
end_date: End date
|
|
data_source: Data source name
|
|
|
|
Returns:
|
|
Cache key
|
|
"""
|
|
try:
|
|
# Generate cache key
|
|
cache_key = self._generate_cache_key(
|
|
"stock_data", symbol,
|
|
start_date=start_date,
|
|
end_date=end_date,
|
|
data_source=data_source
|
|
)
|
|
|
|
# Determine file format and save data
|
|
if isinstance(data, pd.DataFrame):
|
|
# Save DataFrame as pickle for better performance
|
|
cache_path = self._get_cache_path("stock_data", cache_key, "pkl", symbol)
|
|
data.to_pickle(cache_path)
|
|
data_type = "dataframe"
|
|
else:
|
|
# Save string data as JSON
|
|
cache_path = self._get_cache_path("stock_data", cache_key, "json", symbol)
|
|
with open(cache_path, 'w', encoding='utf-8') as f:
|
|
json.dump({"data": data}, f, ensure_ascii=False, indent=2)
|
|
data_type = "string"
|
|
|
|
# Save metadata
|
|
market_type = self._determine_market_type(symbol)
|
|
metadata = {
|
|
"symbol": symbol,
|
|
"data_type": data_type,
|
|
"start_date": start_date,
|
|
"end_date": end_date,
|
|
"data_source": data_source,
|
|
"market_type": market_type,
|
|
"cache_time": datetime.now().isoformat(),
|
|
"file_path": str(cache_path),
|
|
"cache_key": cache_key
|
|
}
|
|
self._save_metadata(cache_key, metadata)
|
|
|
|
print(f"💾 Stock data cached: {symbol} ({market_type.upper()}) -> {cache_key}")
|
|
return cache_key
|
|
|
|
except Exception as e:
|
|
print(f"❌ Failed to save stock data cache: {e}")
|
|
return None
|
|
|
|
def load_stock_data(self, cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
|
|
"""
|
|
Load stock data from cache
|
|
|
|
Args:
|
|
cache_key: Cache key
|
|
|
|
Returns:
|
|
Stock data or None if not found
|
|
"""
|
|
try:
|
|
# Load metadata
|
|
metadata = self._load_metadata(cache_key)
|
|
if not metadata:
|
|
print(f"⚠️ Cache metadata not found: {cache_key}")
|
|
return None
|
|
|
|
# Get file path
|
|
cache_path = Path(metadata["file_path"])
|
|
if not cache_path.exists():
|
|
print(f"⚠️ Cache file not found: {cache_path}")
|
|
return None
|
|
|
|
# Load data based on type
|
|
if metadata["data_type"] == "dataframe":
|
|
data = pd.read_pickle(cache_path)
|
|
else:
|
|
with open(cache_path, 'r', encoding='utf-8') as f:
|
|
json_data = json.load(f)
|
|
data = json_data["data"]
|
|
|
|
print(f"📖 Stock data loaded from cache: {metadata['symbol']} -> {cache_key}")
|
|
return data
|
|
|
|
except Exception as e:
|
|
print(f"❌ Failed to load stock data from cache: {e}")
|
|
return None
|
|
|
|
def find_cached_stock_data(self, symbol: str, start_date: str, end_date: str,
|
|
data_source: str = "unknown") -> Optional[str]:
|
|
"""
|
|
Find cached stock data
|
|
|
|
Args:
|
|
symbol: Stock symbol
|
|
start_date: Start date
|
|
end_date: End date
|
|
data_source: Data source name
|
|
|
|
Returns:
|
|
Cache key if found, None otherwise
|
|
"""
|
|
# Generate expected cache key
|
|
cache_key = self._generate_cache_key(
|
|
"stock_data", symbol,
|
|
start_date=start_date,
|
|
end_date=end_date,
|
|
data_source=data_source
|
|
)
|
|
|
|
# Check if metadata exists
|
|
metadata = self._load_metadata(cache_key)
|
|
if metadata:
|
|
cache_path = Path(metadata["file_path"])
|
|
if cache_path.exists():
|
|
return cache_key
|
|
|
|
return None
|
|
|
|
def is_cache_valid(self, cache_key: str, symbol: str = None, data_type: str = "stock_data") -> bool:
|
|
"""
|
|
Check if cache is still valid
|
|
|
|
Args:
|
|
cache_key: Cache key
|
|
symbol: Stock symbol (for market type determination)
|
|
data_type: Data type
|
|
|
|
Returns:
|
|
True if cache is valid, False otherwise
|
|
"""
|
|
try:
|
|
# Load metadata
|
|
metadata = self._load_metadata(cache_key)
|
|
if not metadata:
|
|
return False
|
|
|
|
# Check if file exists
|
|
cache_path = Path(metadata["file_path"])
|
|
if not cache_path.exists():
|
|
return False
|
|
|
|
# Determine market type and get TTL
|
|
if symbol:
|
|
market_type = self._determine_market_type(symbol)
|
|
else:
|
|
market_type = metadata.get("market_type", "us")
|
|
|
|
cache_type_key = f"{market_type}_{data_type}"
|
|
if cache_type_key not in self.cache_config:
|
|
cache_type_key = "us_stock_data" # Default fallback
|
|
|
|
ttl_hours = self.cache_config[cache_type_key]["ttl_hours"]
|
|
|
|
# Check if cache has expired
|
|
cache_time = datetime.fromisoformat(metadata["cache_time"])
|
|
expiry_time = cache_time + timedelta(hours=ttl_hours)
|
|
|
|
is_valid = datetime.now() < expiry_time
|
|
if not is_valid:
|
|
print(f"⏰ Cache expired: {cache_key} (cached at {cache_time}, TTL: {ttl_hours}h)")
|
|
|
|
return is_valid
|
|
|
|
except Exception as e:
|
|
print(f"❌ Failed to check cache validity: {e}")
|
|
return False
|
|
|
|
def get_cache_stats(self) -> Dict[str, Any]:
|
|
"""
|
|
Get cache statistics
|
|
|
|
Returns:
|
|
Dictionary containing cache statistics
|
|
"""
|
|
try:
|
|
stats = {
|
|
"cache_dir": str(self.cache_dir),
|
|
"total_files": 0,
|
|
"total_size_mb": 0,
|
|
"stock_data_count": 0,
|
|
"news_count": 0,
|
|
"fundamentals_count": 0,
|
|
"us_data_count": 0,
|
|
"china_data_count": 0
|
|
}
|
|
|
|
# Count files in each directory
|
|
for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
|
|
self.china_news_dir, self.us_fundamentals_dir,
|
|
self.china_fundamentals_dir, self.metadata_dir]:
|
|
if dir_path.exists():
|
|
files = list(dir_path.glob("*"))
|
|
stats["total_files"] += len(files)
|
|
|
|
# Calculate total size
|
|
for file_path in files:
|
|
if file_path.is_file():
|
|
stats["total_size_mb"] += file_path.stat().st_size / (1024 * 1024)
|
|
|
|
# Count by data type
|
|
if self.us_stock_dir.exists():
|
|
stats["stock_data_count"] += len(list(self.us_stock_dir.glob("*")))
|
|
stats["us_data_count"] += len(list(self.us_stock_dir.glob("*")))
|
|
|
|
if self.china_stock_dir.exists():
|
|
stats["stock_data_count"] += len(list(self.china_stock_dir.glob("*")))
|
|
stats["china_data_count"] += len(list(self.china_stock_dir.glob("*")))
|
|
|
|
if self.us_news_dir.exists():
|
|
stats["news_count"] += len(list(self.us_news_dir.glob("*")))
|
|
stats["us_data_count"] += len(list(self.us_news_dir.glob("*")))
|
|
|
|
if self.china_news_dir.exists():
|
|
stats["news_count"] += len(list(self.china_news_dir.glob("*")))
|
|
stats["china_data_count"] += len(list(self.china_news_dir.glob("*")))
|
|
|
|
if self.us_fundamentals_dir.exists():
|
|
stats["fundamentals_count"] += len(list(self.us_fundamentals_dir.glob("*")))
|
|
stats["us_data_count"] += len(list(self.us_fundamentals_dir.glob("*")))
|
|
|
|
if self.china_fundamentals_dir.exists():
|
|
stats["fundamentals_count"] += len(list(self.china_fundamentals_dir.glob("*")))
|
|
stats["china_data_count"] += len(list(self.china_fundamentals_dir.glob("*")))
|
|
|
|
# Round size to 2 decimal places
|
|
stats["total_size_mb"] = round(stats["total_size_mb"], 2)
|
|
|
|
return stats
|
|
|
|
except Exception as e:
|
|
print(f"❌ Failed to get cache statistics: {e}")
|
|
return {"error": str(e)}
|
|
|
|
def cleanup_expired_cache(self):
|
|
"""Clean up expired cache files"""
|
|
try:
|
|
cleaned_count = 0
|
|
|
|
# Check all metadata files
|
|
if self.metadata_dir.exists():
|
|
for metadata_file in self.metadata_dir.glob("*_meta.json"):
|
|
try:
|
|
cache_key = metadata_file.stem.replace("_meta", "")
|
|
|
|
# Load metadata
|
|
with open(metadata_file, 'r', encoding='utf-8') as f:
|
|
metadata = json.load(f)
|
|
|
|
# Check if cache is expired
|
|
if not self.is_cache_valid(cache_key, metadata.get("symbol"), "stock_data"):
|
|
# Remove cache file
|
|
cache_path = Path(metadata["file_path"])
|
|
if cache_path.exists():
|
|
cache_path.unlink()
|
|
|
|
# Remove metadata file
|
|
metadata_file.unlink()
|
|
cleaned_count += 1
|
|
print(f"🗑️ Cleaned expired cache: {cache_key}")
|
|
|
|
except Exception as e:
|
|
print(f"⚠️ Failed to clean cache file {metadata_file}: {e}")
|
|
|
|
print(f"✅ Cache cleanup completed, removed {cleaned_count} expired files")
|
|
|
|
except Exception as e:
|
|
print(f"❌ Failed to cleanup cache: {e}")
|
|
|
|
|
|
# Global cache instance
|
|
_global_cache = None
|
|
|
|
def get_cache(cache_dir: str = None):
|
|
"""
|
|
Get global cache instance with intelligent cache selection
|
|
|
|
This function will automatically choose between:
|
|
1. Integrated cache (with database support) if available
|
|
2. Traditional file cache as fallback
|
|
|
|
Args:
|
|
cache_dir: Cache directory path
|
|
|
|
Returns:
|
|
Cache instance (IntegratedCacheManager or StockDataCache)
|
|
"""
|
|
global _global_cache
|
|
if _global_cache is None:
|
|
# Try to use integrated cache manager first
|
|
try:
|
|
from .integrated_cache import IntegratedCacheManager
|
|
_global_cache = IntegratedCacheManager(cache_dir)
|
|
print("🚀 Using integrated cache manager with database support")
|
|
except ImportError:
|
|
# Fallback to traditional cache
|
|
_global_cache = StockDataCache(cache_dir)
|
|
print("📁 Using traditional file cache")
|
|
except Exception as e:
|
|
# If integrated cache fails, fallback to traditional cache
|
|
print(f"⚠️ Integrated cache initialization failed: {e}")
|
|
print("📁 Falling back to traditional file cache")
|
|
_global_cache = StockDataCache(cache_dir)
|
|
return _global_cache
|
|
|
|
|
|
# Convenience functions
|
|
def save_stock_data(symbol: str, data: Union[str, pd.DataFrame],
|
|
start_date: str, end_date: str, data_source: str = "unknown") -> str:
|
|
"""Save stock data to cache (convenience function)"""
|
|
cache = get_cache()
|
|
return cache.save_stock_data(symbol, data, start_date, end_date, data_source)
|
|
|
|
|
|
def load_stock_data(cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
|
|
"""Load stock data from cache (convenience function)"""
|
|
cache = get_cache()
|
|
return cache.load_stock_data(cache_key)
|
|
|
|
|
|
def find_cached_stock_data(symbol: str, start_date: str, end_date: str,
|
|
data_source: str = "unknown") -> Optional[str]:
|
|
"""Find cached stock data (convenience function)"""
|
|
cache = get_cache()
|
|
return cache.find_cached_stock_data(symbol, start_date, end_date, data_source)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# Test the cache manager
|
|
print("🧪 Testing Stock Data Cache Manager...")
|
|
|
|
# Initialize cache
|
|
cache = StockDataCache()
|
|
|
|
# Test data
|
|
test_data = "Sample stock data for AAPL"
|
|
cache_key = cache.save_stock_data("AAPL", test_data, "2024-01-01", "2024-01-31", "test")
|
|
|
|
# Load data
|
|
loaded_data = cache.load_stock_data(cache_key)
|
|
print(f"Loaded data: {loaded_data}")
|
|
|
|
# Check cache validity
|
|
is_valid = cache.is_cache_valid(cache_key, "AAPL")
|
|
print(f"Cache valid: {is_valid}")
|
|
|
|
# Get statistics
|
|
stats = cache.get_cache_stats()
|
|
print(f"Cache stats: {stats}")
|
|
|
|
print("✅ Cache manager test completed!")
|