feat: merge requirements.txt and add optional dependencies

- Add 5 new dependencies from the Chinese version (TradingAgentsCN):
  * dashscope - Alibaba Cloud LLM support
  * streamlit - Web app framework
  * plotly - Interactive plotting
  * pytdx - TongDaXin API for Chinese stock data
  * pymongo - MongoDB database support

- Create requirements-optional.txt for optional dependencies
- Update pyproject.toml with optional dependency groups:
  * chinese: pytdx, dashscope
  * database: pymongo
  * visualization: streamlit, plotly
  * development: pytest, black, flake8
  * all: chinese + database + visualization extras (development tools not included)

Installation options:
- Basic: pip install -r requirements.txt
- With Chinese support: pip install ".[chinese]"
- With all features: pip install ".[all]"
liuping 2025-07-06 00:50:20 +08:00
parent 1cbdd4c76b
commit 1f292d5ab1
18 changed files with 642 additions and 3570 deletions

.gitignore

@@ -17,3 +17,4 @@ TradingAgentsCN/
# 虚拟环境目录(不纳入版本控制)
test_env/
.venv/

config/models.json

@@ -0,0 +1,47 @@
[
{
"provider": "dashscope",
"model_name": "qwen-turbo",
"api_key": "",
"base_url": null,
"max_tokens": 4000,
"temperature": 0.7,
"enabled": true
},
{
"provider": "dashscope",
"model_name": "qwen-plus-latest",
"api_key": "",
"base_url": null,
"max_tokens": 8000,
"temperature": 0.7,
"enabled": true
},
{
"provider": "openai",
"model_name": "gpt-3.5-turbo",
"api_key": "",
"base_url": null,
"max_tokens": 4000,
"temperature": 0.7,
"enabled": false
},
{
"provider": "openai",
"model_name": "gpt-4",
"api_key": "",
"base_url": null,
"max_tokens": 8000,
"temperature": 0.7,
"enabled": false
},
{
"provider": "google",
"model_name": "gemini-pro",
"api_key": "",
"base_url": null,
"max_tokens": 4000,
"temperature": 0.7,
"enabled": false
}
]
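For illustration, a minimal sketch of how a model list in this shape could be loaded and filtered on its "enabled" flag; the path and helper name below are assumptions, not the project's actual loader:

import json
from pathlib import Path
from typing import Optional

def pick_enabled_model(config_path: str = "config/models.json") -> Optional[dict]:
    # Hypothetical helper: return the first entry whose "enabled" flag is true.
    models = json.loads(Path(config_path).read_text(encoding="utf-8"))
    for entry in models:
        if entry.get("enabled"):
            return entry
    return None

model = pick_enabled_model()
if model:
    print(f"{model['provider']}/{model['model_name']} (max_tokens={model['max_tokens']})")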

config/pricing.json

@@ -0,0 +1,58 @@
[
{
"provider": "dashscope",
"model_name": "qwen-turbo",
"input_price_per_1k": 0.002,
"output_price_per_1k": 0.006,
"currency": "CNY"
},
{
"provider": "dashscope",
"model_name": "qwen-plus-latest",
"input_price_per_1k": 0.004,
"output_price_per_1k": 0.012,
"currency": "CNY"
},
{
"provider": "dashscope",
"model_name": "qwen-max",
"input_price_per_1k": 0.02,
"output_price_per_1k": 0.06,
"currency": "CNY"
},
{
"provider": "openai",
"model_name": "gpt-3.5-turbo",
"input_price_per_1k": 0.0015,
"output_price_per_1k": 0.002,
"currency": "USD"
},
{
"provider": "openai",
"model_name": "gpt-4",
"input_price_per_1k": 0.03,
"output_price_per_1k": 0.06,
"currency": "USD"
},
{
"provider": "openai",
"model_name": "gpt-4-turbo",
"input_price_per_1k": 0.01,
"output_price_per_1k": 0.03,
"currency": "USD"
},
{
"provider": "google",
"model_name": "gemini-pro",
"input_price_per_1k": 0.00025,
"output_price_per_1k": 0.0005,
"currency": "USD"
},
{
"provider": "google",
"model_name": "gemini-pro-vision",
"input_price_per_1k": 0.00025,
"output_price_per_1k": 0.0005,
"currency": "USD"
}
]
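A rough worked example of how these per-1K-token prices translate into a request cost (a sketch; the repository's actual cost-tracking code is not shown in this commit):

def estimate_cost(pricing: dict, input_tokens: int, output_tokens: int) -> float:
    # Prices are quoted per 1,000 tokens, so scale the token counts accordingly.
    cost = (input_tokens / 1000.0) * pricing["input_price_per_1k"]
    cost += (output_tokens / 1000.0) * pricing["output_price_per_1k"]
    return round(cost, 6)

qwen_turbo = {"input_price_per_1k": 0.002, "output_price_per_1k": 0.006, "currency": "CNY"}
# 3,000 input tokens + 1,000 output tokens: 3 * 0.002 + 1 * 0.006 = 0.012 CNY
print(estimate_cost(qwen_turbo, 3000, 1000), qwen_turbo["currency"])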

config/settings.json

@@ -0,0 +1,13 @@
{
"default_provider": "dashscope",
"default_model": "qwen-turbo",
"enable_cost_tracking": true,
"cost_alert_threshold": 100.0,
"currency_preference": "CNY",
"auto_save_usage": true,
"max_usage_records": 10000,
"data_dir": "C:\\Users\\PC\\Documents\\TradingAgents\\data",
"cache_dir": "C:\\Users\\PC\\Documents\\TradingAgents\\data\\cache",
"results_dir": "C:\\Users\\PC\\Documents\\TradingAgents\\results",
"auto_create_dirs": true
}
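A small sketch of how "auto_create_dirs" and the directory settings above could be applied on startup; whether the project consumes the file exactly this way is an assumption:

import json
from pathlib import Path

settings = json.loads(Path("config/settings.json").read_text(encoding="utf-8"))
if settings.get("auto_create_dirs"):
    # Create the configured data, cache and results directories if they are missing.
    for key in ("data_dir", "cache_dir", "results_dir"):
        Path(settings[key]).mkdir(parents=True, exist_ok=True)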

pyproject.toml

@@ -32,3 +32,29 @@ dependencies = [
"typing-extensions>=4.14.0",
"yfinance>=0.2.63",
]
[project.optional-dependencies]
chinese = [
"pytdx",
"dashscope",
]
database = [
"pymongo>=4.0.0",
]
visualization = [
"streamlit",
"plotly",
]
development = [
"pytest",
"black",
"flake8",
]
all = [
"pytdx",
"dashscope",
"pymongo>=4.0.0",
"streamlit",
"plotly",
]
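Because these extras are optional, modules that need them generally guard their imports, as the interface changes further down do for chinese_finance_utils. A minimal sketch of that pattern for the "database" extra (the fallback message is illustrative, not the project's actual wording):

try:
    import pymongo  # provided by the optional "database" extra
except ImportError:
    pymongo = None

def get_mongo_client(uri: str = "mongodb://localhost:27017"):
    # Fail with a clear hint when the optional dependency is not installed.
    if pymongo is None:
        raise RuntimeError('MongoDB support requires: pip install ".[database]"')
    return pymongo.MongoClient(uri)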

requirements-optional.txt

@@ -0,0 +1,20 @@
# Optional Dependencies for TradingAgents
# Install everything in this file with:
#   pip install -r requirements-optional.txt
# Or install individual groups via the pyproject extras, e.g. pip install ".[chinese]"
# Chinese dependencies
pytdx # TongDaXin API for Chinese stock real-time data
dashscope # Alibaba Cloud LLM support
# Database dependencies
pymongo # MongoDB database support for token usage storage
# Visualization dependencies
streamlit # Web app framework
plotly # Interactive plotting
# Development dependencies
pytest # Testing framework
black # Code formatter
flake8 # Code linter

requirements.txt

@@ -24,3 +24,10 @@ rich
questionary
langchain_anthropic
langchain-google-genai
# New dependencies from Chinese version
dashscope
streamlit
plotly
pytdx # TongDaXin API for Chinese stock real-time data
pymongo # MongoDB database support for token usage storage

test_merged_features.py

@@ -0,0 +1,321 @@
#!/usr/bin/env python3
"""
合并后功能测试验证脚本
测试所有新增功能和原有功能的兼容性
"""
import sys
import os
import traceback
import tempfile
import shutil
from pathlib import Path
# 添加项目路径
sys.path.insert(0, os.path.abspath('.'))
class MergedFeaturesTest:
"""合并功能测试类"""
def __init__(self):
self.test_results = {
"passed": [],
"failed": [],
"warnings": []
}
self.temp_dir = None
def setup(self):
"""测试环境设置"""
print("🔧 设置测试环境...")
self.temp_dir = tempfile.mkdtemp()
print(f" 临时目录: {self.temp_dir}")
def cleanup(self):
"""清理测试环境"""
if self.temp_dir and os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir)
print(f"🧹 清理临时目录: {self.temp_dir}")
def test_basic_imports(self):
"""测试基础模块导入"""
print("\n📦 测试基础模块导入...")
basic_modules = [
"tradingagents.default_config",
"tradingagents.dataflows.interface",
"tradingagents.dataflows.config",
]
for module in basic_modules:
try:
__import__(module)
self.test_results["passed"].append(f"基础导入: {module}")
print(f"{module}")
except Exception as e:
self.test_results["failed"].append(f"基础导入: {module} - {e}")
print(f"{module}: {e}")
def test_cache_system(self):
"""测试缓存系统"""
print("\n💾 测试缓存系统...")
try:
# 测试原有缓存管理器
from tradingagents.dataflows.cache_manager import StockDataCache, get_cache
# 创建缓存实例
cache = StockDataCache(cache_dir=self.temp_dir)
# 测试基本功能
test_data = "Test stock data for AAPL"
cache_key = cache.save_stock_data("AAPL", test_data, "2024-01-01", "2024-01-31", "test")
if cache_key:
loaded_data = cache.load_stock_data(cache_key)
if loaded_data == test_data:
self.test_results["passed"].append("缓存系统: 基本功能正常")
print(" ✅ 基本缓存功能正常")
else:
self.test_results["failed"].append("缓存系统: 数据不匹配")
print(" ❌ 缓存数据不匹配")
else:
self.test_results["failed"].append("缓存系统: 保存失败")
print(" ❌ 缓存保存失败")
# 测试市场类型检测
us_market = cache._determine_market_type("AAPL")
china_market = cache._determine_market_type("000001")
if us_market == "us" and china_market == "china":
self.test_results["passed"].append("缓存系统: 市场类型检测正常")
print(" ✅ 市场类型检测正常")
else:
self.test_results["failed"].append("缓存系统: 市场类型检测异常")
print(" ❌ 市场类型检测异常")
except Exception as e:
self.test_results["failed"].append(f"缓存系统: {e}")
print(f" ❌ 缓存系统测试失败: {e}")
def test_new_features_import(self):
"""测试新功能模块导入"""
print("\n🆕 测试新功能模块导入...")
new_modules = [
# 中国市场数据
("tradingagents.dataflows.chinese_finance_utils", "中国财经数据工具"),
("tradingagents.dataflows.tdx_utils", "通达信API工具"),
("tradingagents.dataflows.optimized_china_data", "优化A股数据提供器"),
# 高级缓存
("tradingagents.dataflows.adaptive_cache", "自适应缓存"),
("tradingagents.dataflows.integrated_cache", "集成缓存"),
("tradingagents.dataflows.db_cache_manager", "数据库缓存管理"),
# 配置管理
("tradingagents.config.database_config", "数据库配置"),
("tradingagents.config.database_manager", "数据库管理器"),
("tradingagents.config.mongodb_storage", "MongoDB存储"),
# LLM适配器
("tradingagents.llm_adapters.dashscope_adapter", "DashScope适配器"),
# API服务
("tradingagents.api.stock_api", "股票API"),
("tradingagents.dataflows.stock_data_service", "股票数据服务"),
("tradingagents.dataflows.realtime_news_utils", "实时新闻工具"),
]
for module_name, description in new_modules:
try:
__import__(module_name)
self.test_results["passed"].append(f"新功能导入: {description}")
print(f"{description}")
except ImportError as e:
if "No module named" in str(e):
self.test_results["warnings"].append(f"新功能导入: {description} - 可能缺少依赖")
print(f" ⚠️ {description}: 可能缺少依赖 ({e})")
else:
self.test_results["failed"].append(f"新功能导入: {description} - {e}")
print(f"{description}: {e}")
except Exception as e:
self.test_results["failed"].append(f"新功能导入: {description} - {e}")
print(f"{description}: {e}")
def test_optimized_data_providers(self):
"""测试优化的数据提供器"""
print("\n📊 测试优化数据提供器...")
try:
# 测试美股数据提供器
from tradingagents.dataflows.optimized_us_data import OptimizedUSDataProvider
provider = OptimizedUSDataProvider()
self.test_results["passed"].append("数据提供器: 美股提供器初始化成功")
print(" ✅ 美股数据提供器初始化成功")
# 测试基本方法存在
required_methods = ['get_stock_data', '_wait_for_rate_limit', '_format_stock_data']
for method in required_methods:
if hasattr(provider, method):
self.test_results["passed"].append(f"数据提供器: {method} 方法存在")
print(f"{method} 方法存在")
else:
self.test_results["failed"].append(f"数据提供器: {method} 方法缺失")
print(f"{method} 方法缺失")
except Exception as e:
self.test_results["failed"].append(f"数据提供器: {e}")
print(f" ❌ 数据提供器测试失败: {e}")
def test_config_system(self):
"""测试配置系统"""
print("\n⚙️ 测试配置系统...")
try:
# 测试默认配置
from tradingagents.default_config import DEFAULT_CONFIG
# 检查基本配置项
required_configs = [
"project_dir", "results_dir", "data_dir",
"llm_provider", "deep_think_llm", "quick_think_llm"
]
for config_key in required_configs:
if config_key in DEFAULT_CONFIG:
self.test_results["passed"].append(f"配置系统: {config_key} 存在")
print(f"{config_key} 配置存在")
else:
self.test_results["failed"].append(f"配置系统: {config_key} 缺失")
print(f"{config_key} 配置缺失")
# 测试动态配置
from tradingagents.dataflows.config import get_config, set_config
current_config = get_config()
if current_config:
self.test_results["passed"].append("配置系统: 动态配置获取正常")
print(" ✅ 动态配置获取正常")
else:
self.test_results["failed"].append("配置系统: 动态配置获取失败")
print(" ❌ 动态配置获取失败")
except Exception as e:
self.test_results["failed"].append(f"配置系统: {e}")
print(f" ❌ 配置系统测试失败: {e}")
def test_main_functionality(self):
"""测试主要功能"""
print("\n🚀 测试主要功能...")
try:
# 测试主程序导入
import main
self.test_results["passed"].append("主功能: main.py 导入成功")
print(" ✅ main.py 导入成功")
# 测试交易图形导入
from tradingagents.graph.trading_graph import TradingAgentsGraph
self.test_results["passed"].append("主功能: TradingAgentsGraph 导入成功")
print(" ✅ TradingAgentsGraph 导入成功")
except Exception as e:
self.test_results["failed"].append(f"主功能: {e}")
print(f" ❌ 主功能测试失败: {e}")
def test_documentation(self):
"""测试文档完整性"""
print("\n📚 测试文档完整性...")
doc_files = [
"docs/README.md",
"docs/en-US/configuration_guide.md",
"docs/en-US/quick_reference.md",
"docs/en-US/prompt_templates.md",
"MERGE_SUMMARY.md"
]
for doc_file in doc_files:
if os.path.exists(doc_file):
self.test_results["passed"].append(f"文档: {doc_file} 存在")
print(f"{doc_file}")
else:
self.test_results["failed"].append(f"文档: {doc_file} 缺失")
print(f"{doc_file}")
def run_all_tests(self):
"""运行所有测试"""
print("🧪 开始合并后功能测试验证")
print("=" * 50)
self.setup()
try:
self.test_basic_imports()
self.test_cache_system()
self.test_new_features_import()
self.test_optimized_data_providers()
self.test_config_system()
self.test_main_functionality()
self.test_documentation()
finally:
self.cleanup()
self.print_summary()
def print_summary(self):
"""打印测试摘要"""
print("\n" + "=" * 50)
print("📋 测试结果摘要")
print("=" * 50)
total_passed = len(self.test_results["passed"])
total_failed = len(self.test_results["failed"])
total_warnings = len(self.test_results["warnings"])
total_tests = total_passed + total_failed + total_warnings
print(f"\n📊 统计:")
print(f" 总测试项: {total_tests}")
print(f" ✅ 通过: {total_passed}")
print(f" ❌ 失败: {total_failed}")
print(f" ⚠️ 警告: {total_warnings}")
if total_failed == 0:
print(f"\n🎉 所有核心功能测试通过!")
if total_warnings > 0:
print(f"⚠️ 有 {total_warnings} 个警告,主要是可选依赖缺失")
else:
print(f"\n❌ 有 {total_failed} 个测试失败,需要修复")
# 详细结果
if self.test_results["failed"]:
print(f"\n❌ 失败的测试:")
for failure in self.test_results["failed"]:
print(f" - {failure}")
if self.test_results["warnings"]:
print(f"\n⚠️ 警告:")
for warning in self.test_results["warnings"]:
print(f" - {warning}")
# 建议
print(f"\n💡 建议:")
if total_failed == 0:
print(" 1. 核心功能正常,可以进行更深入的集成测试")
print(" 2. 考虑安装可选依赖以启用完整功能")
print(" 3. 运行实际的股票数据获取测试")
else:
print(" 1. 修复失败的测试项")
print(" 2. 检查依赖项安装")
print(" 3. 验证文件路径和导入")
def main():
"""主函数"""
tester = MergedFeaturesTest()
tester.run_all_tests()
if __name__ == "__main__":
main()


@@ -1,498 +0,0 @@
#!/usr/bin/env python3
"""
Stock Data Cache Manager
Supports local caching of stock data to reduce API calls and improve response speed
"""
import os
import json
import pickle
import pandas as pd
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, Any, Union
import hashlib
class StockDataCache:
"""Stock Data Cache Manager - Supports optimized caching for US and Chinese stock data"""
def __init__(self, cache_dir: str = None):
"""
Initialize cache manager
Args:
cache_dir: Cache directory path, defaults to tradingagents/dataflows/data_cache
"""
if cache_dir is None:
# Get current file directory
current_dir = Path(__file__).parent
cache_dir = current_dir / "data_cache"
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(exist_ok=True)
# Create subdirectories - categorized by market
self.us_stock_dir = self.cache_dir / "us_stocks"
self.china_stock_dir = self.cache_dir / "china_stocks"
self.us_news_dir = self.cache_dir / "us_news"
self.china_news_dir = self.cache_dir / "china_news"
self.us_fundamentals_dir = self.cache_dir / "us_fundamentals"
self.china_fundamentals_dir = self.cache_dir / "china_fundamentals"
self.metadata_dir = self.cache_dir / "metadata"
# Create all directories
for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
self.china_news_dir, self.us_fundamentals_dir,
self.china_fundamentals_dir, self.metadata_dir]:
dir_path.mkdir(exist_ok=True)
# Cache configuration - different TTL settings for different markets
self.cache_config = {
'us_stock_data': {
'ttl_hours': 2, # US stock data cached for 2 hours (considering API limits)
'max_files': 1000,
'description': 'US stock historical data'
},
'china_stock_data': {
'ttl_hours': 1, # A-share data cached for 1 hour (high real-time requirement)
'max_files': 1000,
'description': 'A-share historical data'
},
'us_news': {
'ttl_hours': 6, # US stock news cached for 6 hours
'max_files': 500,
'description': 'US stock news data'
},
'china_news': {
'ttl_hours': 4, # A-share news cached for 4 hours
'max_files': 500,
'description': 'A-share news data'
},
'us_fundamentals': {
'ttl_hours': 24, # US stock fundamentals cached for 24 hours
'max_files': 200,
'description': 'US stock fundamentals data'
},
'china_fundamentals': {
'ttl_hours': 12, # A-share fundamentals cached for 12 hours
'max_files': 200,
'description': 'A-share fundamentals data'
}
}
print(f"📁 Cache manager initialized, cache directory: {self.cache_dir}")
print(f"🗄️ Database cache manager initialized")
print(f" US stock data: ✅ Configured")
print(f" A-share data: ✅ Configured")
def _determine_market_type(self, symbol: str) -> str:
"""Determine market type based on stock symbol"""
import re
# Check if it's Chinese A-share (6-digit number)
if re.match(r'^\d{6}$', str(symbol)):
return 'china'
else:
return 'us'
def _generate_cache_key(self, data_type: str, symbol: str, **kwargs) -> str:
"""Generate cache key"""
# Create a string containing all parameters
params_str = f"{data_type}_{symbol}"
for key, value in sorted(kwargs.items()):
params_str += f"_{key}_{value}"
# Use MD5 to generate short unique identifier
cache_key = hashlib.md5(params_str.encode()).hexdigest()[:12]
return f"{symbol}_{data_type}_{cache_key}"
def _get_cache_path(self, data_type: str, cache_key: str, file_format: str = "json", symbol: str = None) -> Path:
"""Get cache file path - supports market classification"""
if symbol:
market_type = self._determine_market_type(symbol)
else:
# Try to extract market type from cache key
market_type = 'us' if not cache_key.startswith(('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')) else 'china'
# Select directory based on data type and market type
if data_type == "stock_data":
base_dir = self.china_stock_dir if market_type == 'china' else self.us_stock_dir
elif data_type == "news":
base_dir = self.china_news_dir if market_type == 'china' else self.us_news_dir
elif data_type == "fundamentals":
base_dir = self.china_fundamentals_dir if market_type == 'china' else self.us_fundamentals_dir
else:
base_dir = self.cache_dir
return base_dir / f"{cache_key}.{file_format}"
def _get_metadata_path(self, cache_key: str) -> Path:
"""Get metadata file path"""
return self.metadata_dir / f"{cache_key}_meta.json"
def _save_metadata(self, cache_key: str, metadata: Dict[str, Any]):
"""Save cache metadata"""
metadata_path = self._get_metadata_path(cache_key)
try:
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, ensure_ascii=False, indent=2, default=str)
except Exception as e:
print(f"⚠️ Failed to save metadata: {e}")
def _load_metadata(self, cache_key: str) -> Optional[Dict[str, Any]]:
"""Load cache metadata"""
metadata_path = self._get_metadata_path(cache_key)
if not metadata_path.exists():
return None
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
print(f"⚠️ Failed to load metadata: {e}")
return None
def save_stock_data(self, symbol: str, data: Union[str, pd.DataFrame],
start_date: str, end_date: str, data_source: str = "unknown") -> str:
"""
Save stock data to cache
Args:
symbol: Stock symbol
data: Stock data (string or DataFrame)
start_date: Start date
end_date: End date
data_source: Data source name
Returns:
Cache key
"""
try:
# Generate cache key
cache_key = self._generate_cache_key(
"stock_data", symbol,
start_date=start_date,
end_date=end_date,
data_source=data_source
)
# Determine file format and save data
if isinstance(data, pd.DataFrame):
# Save DataFrame as pickle for better performance
cache_path = self._get_cache_path("stock_data", cache_key, "pkl", symbol)
data.to_pickle(cache_path)
data_type = "dataframe"
else:
# Save string data as JSON
cache_path = self._get_cache_path("stock_data", cache_key, "json", symbol)
with open(cache_path, 'w', encoding='utf-8') as f:
json.dump({"data": data}, f, ensure_ascii=False, indent=2)
data_type = "string"
# Save metadata
market_type = self._determine_market_type(symbol)
metadata = {
"symbol": symbol,
"data_type": data_type,
"start_date": start_date,
"end_date": end_date,
"data_source": data_source,
"market_type": market_type,
"cache_time": datetime.now().isoformat(),
"file_path": str(cache_path),
"cache_key": cache_key
}
self._save_metadata(cache_key, metadata)
print(f"💾 Stock data cached: {symbol} ({market_type.upper()}) -> {cache_key}")
return cache_key
except Exception as e:
print(f"❌ Failed to save stock data cache: {e}")
return None
def load_stock_data(self, cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
"""
Load stock data from cache
Args:
cache_key: Cache key
Returns:
Stock data or None if not found
"""
try:
# Load metadata
metadata = self._load_metadata(cache_key)
if not metadata:
print(f"⚠️ Cache metadata not found: {cache_key}")
return None
# Get file path
cache_path = Path(metadata["file_path"])
if not cache_path.exists():
print(f"⚠️ Cache file not found: {cache_path}")
return None
# Load data based on type
if metadata["data_type"] == "dataframe":
data = pd.read_pickle(cache_path)
else:
with open(cache_path, 'r', encoding='utf-8') as f:
json_data = json.load(f)
data = json_data["data"]
print(f"📖 Stock data loaded from cache: {metadata['symbol']} -> {cache_key}")
return data
except Exception as e:
print(f"❌ Failed to load stock data from cache: {e}")
return None
def find_cached_stock_data(self, symbol: str, start_date: str, end_date: str,
data_source: str = "unknown") -> Optional[str]:
"""
Find cached stock data
Args:
symbol: Stock symbol
start_date: Start date
end_date: End date
data_source: Data source name
Returns:
Cache key if found, None otherwise
"""
# Generate expected cache key
cache_key = self._generate_cache_key(
"stock_data", symbol,
start_date=start_date,
end_date=end_date,
data_source=data_source
)
# Check if metadata exists
metadata = self._load_metadata(cache_key)
if metadata:
cache_path = Path(metadata["file_path"])
if cache_path.exists():
return cache_key
return None
def is_cache_valid(self, cache_key: str, symbol: str = None, data_type: str = "stock_data") -> bool:
"""
Check if cache is still valid
Args:
cache_key: Cache key
symbol: Stock symbol (for market type determination)
data_type: Data type
Returns:
True if cache is valid, False otherwise
"""
try:
# Load metadata
metadata = self._load_metadata(cache_key)
if not metadata:
return False
# Check if file exists
cache_path = Path(metadata["file_path"])
if not cache_path.exists():
return False
# Determine market type and get TTL
if symbol:
market_type = self._determine_market_type(symbol)
else:
market_type = metadata.get("market_type", "us")
cache_type_key = f"{market_type}_{data_type}"
if cache_type_key not in self.cache_config:
cache_type_key = "us_stock_data" # Default fallback
ttl_hours = self.cache_config[cache_type_key]["ttl_hours"]
# Check if cache has expired
cache_time = datetime.fromisoformat(metadata["cache_time"])
expiry_time = cache_time + timedelta(hours=ttl_hours)
is_valid = datetime.now() < expiry_time
if not is_valid:
print(f"⏰ Cache expired: {cache_key} (cached at {cache_time}, TTL: {ttl_hours}h)")
return is_valid
except Exception as e:
print(f"❌ Failed to check cache validity: {e}")
return False
def get_cache_stats(self) -> Dict[str, Any]:
"""
Get cache statistics
Returns:
Dictionary containing cache statistics
"""
try:
stats = {
"cache_dir": str(self.cache_dir),
"total_files": 0,
"total_size_mb": 0,
"stock_data_count": 0,
"news_count": 0,
"fundamentals_count": 0,
"us_data_count": 0,
"china_data_count": 0
}
# Count files in each directory
for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
self.china_news_dir, self.us_fundamentals_dir,
self.china_fundamentals_dir, self.metadata_dir]:
if dir_path.exists():
files = list(dir_path.glob("*"))
stats["total_files"] += len(files)
# Calculate total size
for file_path in files:
if file_path.is_file():
stats["total_size_mb"] += file_path.stat().st_size / (1024 * 1024)
# Count by data type
if self.us_stock_dir.exists():
stats["stock_data_count"] += len(list(self.us_stock_dir.glob("*")))
stats["us_data_count"] += len(list(self.us_stock_dir.glob("*")))
if self.china_stock_dir.exists():
stats["stock_data_count"] += len(list(self.china_stock_dir.glob("*")))
stats["china_data_count"] += len(list(self.china_stock_dir.glob("*")))
if self.us_news_dir.exists():
stats["news_count"] += len(list(self.us_news_dir.glob("*")))
stats["us_data_count"] += len(list(self.us_news_dir.glob("*")))
if self.china_news_dir.exists():
stats["news_count"] += len(list(self.china_news_dir.glob("*")))
stats["china_data_count"] += len(list(self.china_news_dir.glob("*")))
if self.us_fundamentals_dir.exists():
stats["fundamentals_count"] += len(list(self.us_fundamentals_dir.glob("*")))
stats["us_data_count"] += len(list(self.us_fundamentals_dir.glob("*")))
if self.china_fundamentals_dir.exists():
stats["fundamentals_count"] += len(list(self.china_fundamentals_dir.glob("*")))
stats["china_data_count"] += len(list(self.china_fundamentals_dir.glob("*")))
# Round size to 2 decimal places
stats["total_size_mb"] = round(stats["total_size_mb"], 2)
return stats
except Exception as e:
print(f"❌ Failed to get cache statistics: {e}")
return {"error": str(e)}
def cleanup_expired_cache(self):
"""Clean up expired cache files"""
try:
cleaned_count = 0
# Check all metadata files
if self.metadata_dir.exists():
for metadata_file in self.metadata_dir.glob("*_meta.json"):
try:
cache_key = metadata_file.stem.replace("_meta", "")
# Load metadata
with open(metadata_file, 'r', encoding='utf-8') as f:
metadata = json.load(f)
# Check if cache is expired
if not self.is_cache_valid(cache_key, metadata.get("symbol"), "stock_data"):
# Remove cache file
cache_path = Path(metadata["file_path"])
if cache_path.exists():
cache_path.unlink()
# Remove metadata file
metadata_file.unlink()
cleaned_count += 1
print(f"🗑️ Cleaned expired cache: {cache_key}")
except Exception as e:
print(f"⚠️ Failed to clean cache file {metadata_file}: {e}")
print(f"✅ Cache cleanup completed, removed {cleaned_count} expired files")
except Exception as e:
print(f"❌ Failed to cleanup cache: {e}")
# Global cache instance
_global_cache = None
def get_cache(cache_dir: str = None) -> StockDataCache:
"""
Get global cache instance
Args:
cache_dir: Cache directory path
Returns:
StockDataCache instance
"""
global _global_cache
if _global_cache is None:
_global_cache = StockDataCache(cache_dir)
return _global_cache
# Convenience functions
def save_stock_data(symbol: str, data: Union[str, pd.DataFrame],
start_date: str, end_date: str, data_source: str = "unknown") -> str:
"""Save stock data to cache (convenience function)"""
cache = get_cache()
return cache.save_stock_data(symbol, data, start_date, end_date, data_source)
def load_stock_data(cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
"""Load stock data from cache (convenience function)"""
cache = get_cache()
return cache.load_stock_data(cache_key)
def find_cached_stock_data(symbol: str, start_date: str, end_date: str,
data_source: str = "unknown") -> Optional[str]:
"""Find cached stock data (convenience function)"""
cache = get_cache()
return cache.find_cached_stock_data(symbol, start_date, end_date, data_source)
if __name__ == "__main__":
# Test the cache manager
print("🧪 Testing Stock Data Cache Manager...")
# Initialize cache
cache = StockDataCache()
# Test data
test_data = "Sample stock data for AAPL"
cache_key = cache.save_stock_data("AAPL", test_data, "2024-01-01", "2024-01-31", "test")
# Load data
loaded_data = cache.load_stock_data(cache_key)
print(f"Loaded data: {loaded_data}")
# Check cache validity
is_valid = cache.is_cache_valid(cache_key, "AAPL")
print(f"Cache valid: {is_valid}")
# Get statistics
stats = cache.get_cache_stats()
print(f"Cache stats: {stats}")
print("✅ Cache manager test completed!")


@@ -1,863 +0,0 @@
# File diff report
# Current file: tradingagents\dataflows\cache_manager.py
# Chinese-version file: TradingAgentsCN\tradingagents\dataflows\cache_manager.py
# Generated: Sunday 2025/07/06
--- current/cache_manager.py
+++ chinese_version/cache_manager.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
-Stock Data Cache Manager
-Supports local caching of stock data to reduce API calls and improve response speed
+股票数据缓存管理器
+支持本地缓存股票数据减少API调用提高响应速度
"""
import os
@@ -15,24 +15,24 @@
class StockDataCache:
- """Stock Data Cache Manager - Supports optimized caching for US and Chinese stock data"""
+ """股票数据缓存管理器 - 支持美股和A股数据缓存优化"""
def __init__(self, cache_dir: str = None):
"""
- Initialize cache manager
+ 初始化缓存管理器
Args:
- cache_dir: Cache directory path, defaults to tradingagents/dataflows/data_cache
+ cache_dir: 缓存目录路径,默认为 tradingagents/dataflows/data_cache
"""
if cache_dir is None:
- # Get current file directory
+ # 获取当前文件所在目录
current_dir = Path(__file__).parent
cache_dir = current_dir / "data_cache"
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(exist_ok=True)
- # Create subdirectories - categorized by market
+ # 创建子目录 - 按市场分类
self.us_stock_dir = self.cache_dir / "us_stocks"
self.china_stock_dir = self.cache_dir / "china_stocks"
self.us_news_dir = self.cache_dir / "us_news"
@@ -41,81 +41,81 @@ self.china_fundamentals_dir = self.cache_dir / "china_fundamentals"
self.metadata_dir = self.cache_dir / "metadata"
- # Create all directories
+ # 创建所有目录
for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
self.china_news_dir, self.us_fundamentals_dir,
self.china_fundamentals_dir, self.metadata_dir]:
dir_path.mkdir(exist_ok=True)
- # Cache configuration - different TTL settings for different markets
+ # 缓存配置 - 针对不同市场设置不同的TTL
self.cache_config = {
'us_stock_data': {
- 'ttl_hours': 2, # US stock data cached for 2 hours (considering API limits)
+ 'ttl_hours': 2, # 美股数据缓存2小时考虑到API限制
'max_files': 1000,
- 'description': 'US stock historical data'
+ 'description': '美股历史数据'
},
'china_stock_data': {
- 'ttl_hours': 1, # A-share data cached for 1 hour (high real-time requirement)
+ 'ttl_hours': 1, # A股数据缓存1小时实时性要求高
'max_files': 1000,
- 'description': 'A-share historical data'
+ 'description': 'A股历史数据'
},
'us_news': {
- 'ttl_hours': 6, # US stock news cached for 6 hours
+ 'ttl_hours': 6, # 美股新闻缓存6小时
'max_files': 500,
- 'description': 'US stock news data'
+ 'description': '美股新闻数据'
},
'china_news': {
- 'ttl_hours': 4, # A-share news cached for 4 hours
+ 'ttl_hours': 4, # A股新闻缓存4小时
'max_files': 500,
- 'description': 'A-share news data'
+ 'description': 'A股新闻数据'
},
'us_fundamentals': {
- 'ttl_hours': 24, # US stock fundamentals cached for 24 hours
+ 'ttl_hours': 24, # 美股基本面数据缓存24小时
'max_files': 200,
- 'description': 'US stock fundamentals data'
+ 'description': '美股基本面数据'
},
'china_fundamentals': {
- 'ttl_hours': 12, # A-share fundamentals cached for 12 hours
+ 'ttl_hours': 12, # A股基本面数据缓存12小时
'max_files': 200,
- 'description': 'A-share fundamentals data'
+ 'description': 'A股基本面数据'
}
}
- print(f"📁 Cache manager initialized, cache directory: {self.cache_dir}")
- print(f"🗄️ Database cache manager initialized")
- print(f" US stock data: ✅ Configured")
- print(f" A-share data: ✅ Configured")
+ print(f"📁 缓存管理器初始化完成,缓存目录: {self.cache_dir}")
+ print(f"🗄️ 数据库缓存管理器初始化完成")
+ print(f" 美股数据: ✅ 已配置")
+ print(f" A股数据: ✅ 已配置")
def _determine_market_type(self, symbol: str) -> str:
- """Determine market type based on stock symbol"""
+ """根据股票代码确定市场类型"""
import re
- # Check if it's Chinese A-share (6-digit number)
+ # 判断是否为中国A股6位数字
if re.match(r'^\d{6}$', str(symbol)):
return 'china'
else:
return 'us'
def _generate_cache_key(self, data_type: str, symbol: str, **kwargs) -> str:
- """Generate cache key"""
- # Create a string containing all parameters
+ """生成缓存键"""
+ # 创建一个包含所有参数的字符串
params_str = f"{data_type}_{symbol}"
for key, value in sorted(kwargs.items()):
params_str += f"_{key}_{value}"
- # Use MD5 to generate short unique identifier
+ # 使用MD5生成短的唯一标识
cache_key = hashlib.md5(params_str.encode()).hexdigest()[:12]
return f"{symbol}_{data_type}_{cache_key}"
def _get_cache_path(self, data_type: str, cache_key: str, file_format: str = "json", symbol: str = None) -> Path:
- """Get cache file path - supports market classification"""
+ """获取缓存文件路径 - 支持市场分类"""
if symbol:
market_type = self._determine_market_type(symbol)
else:
- # Try to extract market type from cache key
+ # 从缓存键中尝试提取市场类型
market_type = 'us' if not cache_key.startswith(('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')) else 'china'
- # Select directory based on data type and market type
+ # 根据数据类型和市场类型选择目录
if data_type == "stock_data":
base_dir = self.china_stock_dir if market_type == 'china' else self.us_stock_dir
elif data_type == "news":
@@ -128,20 +128,19 @@ return base_dir / f"{cache_key}.{file_format}"
def _get_metadata_path(self, cache_key: str) -> Path:
- """Get metadata file path"""
+ """获取元数据文件路径"""
return self.metadata_dir / f"{cache_key}_meta.json"
def _save_metadata(self, cache_key: str, metadata: Dict[str, Any]):
- """Save cache metadata"""
+ """保存元数据"""
metadata_path = self._get_metadata_path(cache_key)
- try:
- with open(metadata_path, 'w', encoding='utf-8') as f:
- json.dump(metadata, f, ensure_ascii=False, indent=2, default=str)
- except Exception as e:
- print(f"⚠️ Failed to save metadata: {e}")
+ metadata['cached_at'] = datetime.now().isoformat()
+
+ with open(metadata_path, 'w', encoding='utf-8') as f:
+ json.dump(metadata, f, ensure_ascii=False, indent=2)
def _load_metadata(self, cache_key: str) -> Optional[Dict[str, Any]]:
- """Load cache metadata"""
+ """加载元数据"""
metadata_path = self._get_metadata_path(cache_key)
if not metadata_path.exists():
return None
@@ -150,349 +149,355 @@ with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
- print(f"⚠️ Failed to load metadata: {e}")
- return None
-
- def save_stock_data(self, symbol: str, data: Union[str, pd.DataFrame],
- start_date: str, end_date: str, data_source: str = "unknown") -> str:
- """
- Save stock data to cache
+ print(f"⚠️ 加载元数据失败: {e}")
+ return None
+
+ def is_cache_valid(self, cache_key: str, max_age_hours: int = None, symbol: str = None, data_type: str = None) -> bool:
+ """检查缓存是否有效 - 支持智能TTL配置"""
+ metadata = self._load_metadata(cache_key)
+ if not metadata:
+ return False
+
+ # 如果没有指定TTL根据数据类型和市场自动确定
+ if max_age_hours is None:
+ if symbol and data_type:
+ market_type = self._determine_market_type(symbol)
+ cache_type = f"{market_type}_{data_type}"
+ max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
+ else:
+ # 从元数据中获取信息
+ symbol = metadata.get('symbol', '')
+ data_type = metadata.get('data_type', 'stock_data')
+ market_type = self._determine_market_type(symbol)
+ cache_type = f"{market_type}_{data_type}"
+ max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
+
+ cached_at = datetime.fromisoformat(metadata['cached_at'])
+ age = datetime.now() - cached_at
+
+ is_valid = age.total_seconds() < max_age_hours * 3600
+
+ if is_valid:
+ market_type = self._determine_market_type(metadata.get('symbol', ''))
+ cache_type = f"{market_type}_{metadata.get('data_type', 'stock_data')}"
+ desc = self.cache_config.get(cache_type, {}).get('description', '数据')
+ print(f"✅ 缓存有效: {desc} - {metadata.get('symbol')} (剩余 {max_age_hours - age.total_seconds()/3600:.1f}h)")
+
+ return is_valid
+
+ def save_stock_data(self, symbol: str, data: Union[pd.DataFrame, str],
+ start_date: str = None, end_date: str = None,
+ data_source: str = "unknown") -> str:
+ """
+ 保存股票数据到缓存 - 支持美股和A股分类存储
Args:
- symbol: Stock symbol
- data: Stock data (string or DataFrame)
- start_date: Start date
- end_date: End date
- data_source: Data source name
+ symbol: 股票代码
+ data: 股票数据DataFrame或字符串
+ start_date: 开始日期
+ end_date: 结束日期
+ data_source: 数据源(如 "tdx", "yfinance", "finnhub"
Returns:
- Cache key
- """
+ cache_key: 缓存键
+ """
+ market_type = self._determine_market_type(symbol)
+ cache_key = self._generate_cache_key("stock_data", symbol,
+ start_date=start_date,
+ end_date=end_date,
+ source=data_source,
+ market=market_type)
+
+ # 保存数据
+ if isinstance(data, pd.DataFrame):
+ cache_path = self._get_cache_path("stock_data", cache_key, "csv", symbol)
+ data.to_csv(cache_path, index=True)
+ else:
+ cache_path = self._get_cache_path("stock_data", cache_key, "txt", symbol)
+ with open(cache_path, 'w', encoding='utf-8') as f:
+ f.write(str(data))
+
+ # 保存元数据
+ metadata = {
+ 'symbol': symbol,
+ 'data_type': 'stock_data',
+ 'market_type': market_type,
+ 'start_date': start_date,
+ 'end_date': end_date,
+ 'data_source': data_source,
+ 'file_path': str(cache_path),
+ 'file_format': 'csv' if isinstance(data, pd.DataFrame) else 'txt'
+ }
+ self._save_metadata(cache_key, metadata)
+
+ # 获取描述信息
+ cache_type = f"{market_type}_stock_data"
+ desc = self.cache_config.get(cache_type, {}).get('description', '股票数据')
+ print(f"💾 {desc}已缓存: {symbol} ({data_source}) -> {cache_key}")
+ return cache_key
+
+ def load_stock_data(self, cache_key: str) -> Optional[Union[pd.DataFrame, str]]:
+ """从缓存加载股票数据"""
+ metadata = self._load_metadata(cache_key)
+ if not metadata:
+ return None
+
+ cache_path = Path(metadata['file_path'])
+ if not cache_path.exists():
+ return None
+
try:
- # Generate cache key
- cache_key = self._generate_cache_key(
- "stock_data", symbol,
- start_date=start_date,
- end_date=end_date,
- data_source=data_source
- )
-
- # Determine file format and save data
- if isinstance(data, pd.DataFrame):
- # Save DataFrame as pickle for better performance
- cache_path = self._get_cache_path("stock_data", cache_key, "pkl", symbol)
- data.to_pickle(cache_path)
- data_type = "dataframe"
- else:
- # Save string data as JSON
- cache_path = self._get_cache_path("stock_data", cache_key, "json", symbol)
- with open(cache_path, 'w', encoding='utf-8') as f:
- json.dump({"data": data}, f, ensure_ascii=False, indent=2)
- data_type = "string"
-
- # Save metadata
- market_type = self._determine_market_type(symbol)
- metadata = {
- "symbol": symbol,
- "data_type": data_type,
- "start_date": start_date,
- "end_date": end_date,
- "data_source": data_source,
- "market_type": market_type,
- "cache_time": datetime.now().isoformat(),
- "file_path": str(cache_path),
- "cache_key": cache_key
- }
- self._save_metadata(cache_key, metadata)
-
- print(f"💾 Stock data cached: {symbol} ({market_type.upper()}) -> {cache_key}")
- return cache_key
-
- except Exception as e:
- print(f"❌ Failed to save stock data cache: {e}")
- return None
-
- def load_stock_data(self, cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
- """
- Load stock data from cache
-
- Args:
- cache_key: Cache key
-
- Returns:
- Stock data or None if not found
- """
- try:
- # Load metadata
- metadata = self._load_metadata(cache_key)
- if not metadata:
- print(f"⚠️ Cache metadata not found: {cache_key}")
- return None
-
- # Get file path
- cache_path = Path(metadata["file_path"])
- if not cache_path.exists():
- print(f"⚠️ Cache file not found: {cache_path}")
- return None
-
- # Load data based on type
- if metadata["data_type"] == "dataframe":
- data = pd.read_pickle(cache_path)
+ if metadata['file_format'] == 'csv':
+ return pd.read_csv(cache_path, index_col=0)
else:
with open(cache_path, 'r', encoding='utf-8') as f:
- json_data = json.load(f)
- data = json_data["data"]
-
- print(f"📖 Stock data loaded from cache: {metadata['symbol']} -> {cache_key}")
- return data
-
+ return f.read()
except Exception as e:
- print(f"❌ Failed to load stock data from cache: {e}")
- return None
-
- def find_cached_stock_data(self, symbol: str, start_date: str, end_date: str,
- data_source: str = "unknown") -> Optional[str]:
- """
- Find cached stock data
+ print(f"⚠️ 加载缓存数据失败: {e}")
+ return None
+
+ def find_cached_stock_data(self, symbol: str, start_date: str = None,
+ end_date: str = None, data_source: str = None,
+ max_age_hours: int = None) -> Optional[str]:
+ """
+ 查找匹配的缓存数据 - 支持智能市场分类查找
Args:
- symbol: Stock symbol
- start_date: Start date
- end_date: End date
- data_source: Data source name
+ symbol: 股票代码
+ start_date: 开始日期
+ end_date: 结束日期
+ data_source: 数据源
+ max_age_hours: 最大缓存时间小时None时使用智能配置
Returns:
- Cache key if found, None otherwise
- """
- # Generate expected cache key
- cache_key = self._generate_cache_key(
- "stock_data", symbol,
- start_date=start_date,
- end_date=end_date,
- data_source=data_source
- )
-
- # Check if metadata exists
+ cache_key: 如果找到有效缓存则返回缓存键否则返回None
+ """
+ market_type = self._determine_market_type(symbol)
+
+ # 如果没有指定TTL使用智能配置
+ if max_age_hours is None:
+ cache_type = f"{market_type}_stock_data"
+ max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
+
+ # 生成查找键
+ search_key = self._generate_cache_key("stock_data", symbol,
+ start_date=start_date,
+ end_date=end_date,
+ source=data_source,
+ market=market_type)
+
+ # 检查精确匹配
+ if self.is_cache_valid(search_key, max_age_hours, symbol, 'stock_data'):
+ desc = self.cache_config.get(f"{market_type}_stock_data", {}).get('description', '数据')
+ print(f"🎯 找到精确匹配的{desc}: {symbol} -> {search_key}")
+ return search_key
+
+ # 如果没有精确匹配,查找部分匹配(相同股票代码的其他缓存)
+ for metadata_file in self.metadata_dir.glob(f"*_meta.json"):
+ try:
+ with open(metadata_file, 'r', encoding='utf-8') as f:
+ metadata = json.load(f)
+
+ if (metadata.get('symbol') == symbol and
+ metadata.get('data_type') == 'stock_data' and
+ metadata.get('market_type') == market_type and
+ (data_source is None or metadata.get('data_source') == data_source)):
+
+ cache_key = metadata_file.stem.replace('_meta', '')
+ if self.is_cache_valid(cache_key, max_age_hours, symbol, 'stock_data'):
+ desc = self.cache_config.get(f"{market_type}_stock_data", {}).get('description', '数据')
+ print(f"📋 找到部分匹配的{desc}: {symbol} -> {cache_key}")
+ return cache_key
+ except Exception:
+ continue
+
+ desc = self.cache_config.get(f"{market_type}_stock_data", {}).get('description', '数据')
+ print(f"❌ 未找到有效的{desc}缓存: {symbol}")
+ return None
+
+ def save_news_data(self, symbol: str, news_data: str,
+ start_date: str = None, end_date: str = None,
+ data_source: str = "unknown") -> str:
+ """保存新闻数据到缓存"""
+ cache_key = self._generate_cache_key("news", symbol,
+ start_date=start_date,
+ end_date=end_date,
+ source=data_source)
+
+ cache_path = self._get_cache_path("news", cache_key, "txt")
+ with open(cache_path, 'w', encoding='utf-8') as f:
+ f.write(news_data)
+
+ metadata = {
+ 'symbol': symbol,
+ 'data_type': 'news',
+ 'start_date': start_date,
+ 'end_date': end_date,
+ 'data_source': data_source,
+ 'file_path': str(cache_path),
+ 'file_format': 'txt'
+ }
+ self._save_metadata(cache_key, metadata)
+
+ print(f"📰 新闻数据已缓存: {symbol} ({data_source}) -> {cache_key}")
+ return cache_key
+
+ def save_fundamentals_data(self, symbol: str, fundamentals_data: str,
+ data_source: str = "unknown") -> str:
+ """保存基本面数据到缓存"""
+ market_type = self._determine_market_type(symbol)
+ cache_key = self._generate_cache_key("fundamentals", symbol,
+ source=data_source,
+ market=market_type,
+ date=datetime.now().strftime("%Y-%m-%d"))
+
+ cache_path = self._get_cache_path("fundamentals", cache_key, "txt", symbol)
+ with open(cache_path, 'w', encoding='utf-8') as f:
+ f.write(fundamentals_data)
+
+ metadata = {
+ 'symbol': symbol,
+ 'data_type': 'fundamentals',
+ 'data_source': data_source,
+ 'market_type': market_type,
+ 'file_path': str(cache_path),
+ 'file_format': 'txt'
+ }
+ self._save_metadata(cache_key, metadata)
+
+ desc = self.cache_config.get(f"{market_type}_fundamentals", {}).get('description', '基本面数据')
+ print(f"💼 {desc}已缓存: {symbol} ({data_source}) -> {cache_key}")
+ return cache_key
+
+ def load_fundamentals_data(self, cache_key: str) -> Optional[str]:
+ """从缓存加载基本面数据"""
metadata = self._load_metadata(cache_key)
- if metadata:
- cache_path = Path(metadata["file_path"])
- if cache_path.exists():
- return cache_key
-
+ if not metadata:
+ return None
+
+ cache_path = Path(metadata['file_path'])
+ if not cache_path.exists():
+ return None
+
+ try:
+ with open(cache_path, 'r', encoding='utf-8') as f:
+ return f.read()
+ except Exception as e:
+ print(f"⚠️ 加载基本面缓存数据失败: {e}")
+ return None
+
+ def find_cached_fundamentals_data(self, symbol: str, data_source: str = None,
+ max_age_hours: int = None) -> Optional[str]:
+ """
+ 查找匹配的基本面缓存数据
+
+ Args:
+ symbol: 股票代码
+ data_source: 数据源(如 "openai", "finnhub"
+ max_age_hours: 最大缓存时间小时None时使用智能配置
+
+ Returns:
+ cache_key: 如果找到有效缓存则返回缓存键否则返回None
+ """
+ market_type = self._determine_market_type(symbol)
+
+ # 如果没有指定TTL使用智能配置
+ if max_age_hours is None:
+ cache_type = f"{market_type}_fundamentals"
+ max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
+
+ # 查找匹配的缓存
+ for metadata_file in self.metadata_dir.glob(f"*_meta.json"):
+ try:
+ with open(metadata_file, 'r', encoding='utf-8') as f:
+ metadata = json.load(f)
+
+ if (metadata.get('symbol') == symbol and
+ metadata.get('data_type') == 'fundamentals' and
+ metadata.get('market_type') == market_type and
+ (data_source is None or metadata.get('data_source') == data_source)):
+
+ cache_key = metadata_file.stem.replace('_meta', '')
+ if self.is_cache_valid(cache_key, max_age_hours, symbol, 'fundamentals'):
+ desc = self.cache_config.get(f"{market_type}_fundamentals", {}).get('description', '基本面数据')
+ print(f"🎯 找到匹配的{desc}缓存: {symbol} ({data_source}) -> {cache_key}")
+ return cache_key
+ except Exception:
+ continue
+
+ desc = self.cache_config.get(f"{market_type}_fundamentals", {}).get('description', '基本面数据')
+ print(f"❌ 未找到有效的{desc}缓存: {symbol} ({data_source})")
return None
-
- def is_cache_valid(self, cache_key: str, symbol: str = None, data_type: str = "stock_data") -> bool:
- """
- Check if cache is still valid
-
- Args:
- cache_key: Cache key
- symbol: Stock symbol (for market type determination)
- data_type: Data type
-
- Returns:
- True if cache is valid, False otherwise
- """
- try:
- # Load metadata
- metadata = self._load_metadata(cache_key)
- if not metadata:
- return False
-
- # Check if file exists
- cache_path = Path(metadata["file_path"])
- if not cache_path.exists():
- return False
-
- # Determine market type and get TTL
- if symbol:
- market_type = self._determine_market_type(symbol)
- else:
- market_type = metadata.get("market_type", "us")
-
- cache_type_key = f"{market_type}_{data_type}"
- if cache_type_key not in self.cache_config:
- cache_type_key = "us_stock_data" # Default fallback
-
- ttl_hours = self.cache_config[cache_type_key]["ttl_hours"]
-
- # Check if cache has expired
- cache_time = datetime.fromisoformat(metadata["cache_time"])
- expiry_time = cache_time + timedelta(hours=ttl_hours)
-
- is_valid = datetime.now() < expiry_time
- if not is_valid:
- print(f"⏰ Cache expired: {cache_key} (cached at {cache_time}, TTL: {ttl_hours}h)")
-
- return is_valid
-
- except Exception as e:
- print(f"❌ Failed to check cache validity: {e}")
- return False
-
+
+ def clear_old_cache(self, max_age_days: int = 7):
+ """清理过期缓存"""
+ cutoff_time = datetime.now() - timedelta(days=max_age_days)
+ cleared_count = 0
+
+ for metadata_file in self.metadata_dir.glob("*_meta.json"):
+ try:
+ with open(metadata_file, 'r', encoding='utf-8') as f:
+ metadata = json.load(f)
+
+ cached_at = datetime.fromisoformat(metadata['cached_at'])
+ if cached_at < cutoff_time:
+ # 删除数据文件
+ data_file = Path(metadata['file_path'])
+ if data_file.exists():
+ data_file.unlink()
+
+ # 删除元数据文件
+ metadata_file.unlink()
+ cleared_count += 1
+
+ except Exception as e:
+ print(f"⚠️ 清理缓存时出错: {e}")
+
+ print(f"🧹 已清理 {cleared_count} 个过期缓存文件")
+
def get_cache_stats(self) -> Dict[str, Any]:
- """
- Get cache statistics
-
- Returns:
- Dictionary containing cache statistics
- """
- try:
- stats = {
- "cache_dir": str(self.cache_dir),
- "total_files": 0,
- "total_size_mb": 0,
- "stock_data_count": 0,
- "news_count": 0,
- "fundamentals_count": 0,
- "us_data_count": 0,
- "china_data_count": 0
- }
-
- # Count files in each directory
- for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
- self.china_news_dir, self.us_fundamentals_dir,
- self.china_fundamentals_dir, self.metadata_dir]:
- if dir_path.exists():
- files = list(dir_path.glob("*"))
- stats["total_files"] += len(files)
-
- # Calculate total size
- for file_path in files:
- if file_path.is_file():
- stats["total_size_mb"] += file_path.stat().st_size / (1024 * 1024)
-
- # Count by data type
- if self.us_stock_dir.exists():
- stats["stock_data_count"] += len(list(self.us_stock_dir.glob("*")))
- stats["us_data_count"] += len(list(self.us_stock_dir.glob("*")))
-
- if self.china_stock_dir.exists():
- stats["stock_data_count"] += len(list(self.china_stock_dir.glob("*")))
- stats["china_data_count"] += len(list(self.china_stock_dir.glob("*")))
-
- if self.us_news_dir.exists():
- stats["news_count"] += len(list(self.us_news_dir.glob("*")))
- stats["us_data_count"] += len(list(self.us_news_dir.glob("*")))
-
- if self.china_news_dir.exists():
- stats["news_count"] += len(list(self.china_news_dir.glob("*")))
- stats["china_data_count"] += len(list(self.china_news_dir.glob("*")))
-
- if self.us_fundamentals_dir.exists():
- stats["fundamentals_count"] += len(list(self.us_fundamentals_dir.glob("*")))
- stats["us_data_count"] += len(list(self.us_fundamentals_dir.glob("*")))
-
- if self.china_fundamentals_dir.exists():
- stats["fundamentals_count"] += len(list(self.china_fundamentals_dir.glob("*")))
- stats["china_data_count"] += len(list(self.china_fundamentals_dir.glob("*")))
-
- # Round size to 2 decimal places
- stats["total_size_mb"] = round(stats["total_size_mb"], 2)
-
- return stats
-
- except Exception as e:
- print(f"❌ Failed to get cache statistics: {e}")
- return {"error": str(e)}
-
- def cleanup_expired_cache(self):
- """Clean up expired cache files"""
- try:
- cleaned_count = 0
-
- # Check all metadata files
- if self.metadata_dir.exists():
- for metadata_file in self.metadata_dir.glob("*_meta.json"):
- try:
- cache_key = metadata_file.stem.replace("_meta", "")
-
- # Load metadata
- with open(metadata_file, 'r', encoding='utf-8') as f:
- metadata = json.load(f)
-
- # Check if cache is expired
- if not self.is_cache_valid(cache_key, metadata.get("symbol"), "stock_data"):
- # Remove cache file
- cache_path = Path(metadata["file_path"])
- if cache_path.exists():
- cache_path.unlink()
-
- # Remove metadata file
- metadata_file.unlink()
- cleaned_count += 1
- print(f"🗑️ Cleaned expired cache: {cache_key}")
-
- except Exception as e:
- print(f"⚠️ Failed to clean cache file {metadata_file}: {e}")
-
- print(f"✅ Cache cleanup completed, removed {cleaned_count} expired files")
-
- except Exception as e:
- print(f"❌ Failed to cleanup cache: {e}")
-
-
-# Global cache instance
-_global_cache = None
-
-def get_cache(cache_dir: str = None) -> StockDataCache:
- """
- Get global cache instance
-
- Args:
- cache_dir: Cache directory path
-
- Returns:
- StockDataCache instance
- """
- global _global_cache
- if _global_cache is None:
- _global_cache = StockDataCache(cache_dir)
- return _global_cache
-
-
-# Convenience functions
-def save_stock_data(symbol: str, data: Union[str, pd.DataFrame],
- start_date: str, end_date: str, data_source: str = "unknown") -> str:
- """Save stock data to cache (convenience function)"""
- cache = get_cache()
- return cache.save_stock_data(symbol, data, start_date, end_date, data_source)
-
-
-def load_stock_data(cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
- """Load stock data from cache (convenience function)"""
- cache = get_cache()
- return cache.load_stock_data(cache_key)
-
-
-def find_cached_stock_data(symbol: str, start_date: str, end_date: str,
- data_source: str = "unknown") -> Optional[str]:
- """Find cached stock data (convenience function)"""
- cache = get_cache()
- return cache.find_cached_stock_data(symbol, start_date, end_date, data_source)
-
-
-if __name__ == "__main__":
- # Test the cache manager
- print("🧪 Testing Stock Data Cache Manager...")
-
- # Initialize cache
- cache = StockDataCache()
-
- # Test data
- test_data = "Sample stock data for AAPL"
- cache_key = cache.save_stock_data("AAPL", test_data, "2024-01-01", "2024-01-31", "test")
-
- # Load data
- loaded_data = cache.load_stock_data(cache_key)
- print(f"Loaded data: {loaded_data}")
-
- # Check cache validity
- is_valid = cache.is_cache_valid(cache_key, "AAPL")
- print(f"Cache valid: {is_valid}")
-
- # Get statistics
- stats = cache.get_cache_stats()
- print(f"Cache stats: {stats}")
-
- print("✅ Cache manager test completed!")
+ """获取缓存统计信息"""
+ stats = {
+ 'total_files': 0,
+ 'stock_data_count': 0,
+ 'news_count': 0,
+ 'fundamentals_count': 0,
+ 'total_size_mb': 0
+ }
+
+ for metadata_file in self.metadata_dir.glob("*_meta.json"):
+ try:
+ with open(metadata_file, 'r', encoding='utf-8') as f:
+ metadata = json.load(f)
+
+ data_type = metadata.get('data_type', 'unknown')
+ if data_type == 'stock_data':
+ stats['stock_data_count'] += 1
+ elif data_type == 'news':
+ stats['news_count'] += 1
+ elif data_type == 'fundamentals':
+ stats['fundamentals_count'] += 1
+
+ # 计算文件大小
+ data_file = Path(metadata['file_path'])
+ if data_file.exists():
+ stats['total_size_mb'] += data_file.stat().st_size / (1024 * 1024)
+
+ stats['total_files'] += 1
+
+ except Exception:
+ continue
+
+ stats['total_size_mb'] = round(stats['total_size_mb'], 2)
+ return stats
+
+
+# 全局缓存实例
+_cache_instance = None
+
+def get_cache() -> StockDataCache:
+ """获取全局缓存实例"""
+ global _cache_instance
+ if _cache_instance is None:
+ _cache_instance = StockDataCache()
+ return _cache_instance


@@ -3,6 +3,13 @@ from .reddit_utils import fetch_top_from_category
from .yfin_utils import *
from .stockstats_utils import *
from .googlenews_utils import *
# Import Chinese finance utilities if available
try:
from .chinese_finance_utils import get_chinese_social_sentiment
except ImportError:
def get_chinese_social_sentiment(*args, **kwargs):
return "Chinese finance utilities not available"
from .finnhub_utils import get_data_in_range
from dateutil.relativedelta import relativedelta
from concurrent.futures import ThreadPoolExecutor
@@ -43,7 +50,14 @@ def get_finnhub_news(
result = get_data_in_range(ticker, before, curr_date, "news_data", DATA_DIR)
if len(result) == 0:
return ""
error_msg = f"⚠️ Unable to retrieve news data for {ticker} ({before} to {curr_date})\n"
error_msg += f"Possible reasons:\n"
error_msg += f"1. Data files do not exist or path configuration is incorrect\n"
error_msg += f"2. No news data available for the specified date range\n"
error_msg += f"3. Need to download or update Finnhub news data first\n"
error_msg += f"Suggestion: Check data directory configuration or re-fetch news data"
print(f"📰 [DEBUG] {error_msg}")
return error_msg
combined_result = ""
for day, data in result.items():
@@ -805,3 +819,133 @@ def get_fundamentals_openai(ticker, curr_date):
)
return response.output[1].content[0].text
def get_fundamentals_finnhub(ticker, curr_date):
"""
Use Finnhub API to get stock fundamental data as an alternative to OpenAI
Args:
ticker (str): Stock symbol
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: Formatted fundamental data report
"""
try:
import finnhub
import os
# Try to import cache manager
try:
from .cache_manager import get_cache
cache = get_cache()
# Check cache first
cached_key = cache.find_cached_stock_data(ticker, curr_date, curr_date, "finnhub_fundamentals")
if cached_key and cache.is_cache_valid(cached_key, ticker):
cached_data = cache.load_stock_data(cached_key)
if cached_data:
print(f"💾 [DEBUG] Loading Finnhub fundamental data from cache: {ticker}")
return cached_data
except ImportError:
cache = None
print("⚠️ Cache manager not available, proceeding without cache")
# Get Finnhub API key
api_key = os.getenv('FINNHUB_API_KEY')
if not api_key:
return "Error: FINNHUB_API_KEY environment variable not configured"
# Initialize Finnhub client
finnhub_client = finnhub.Client(api_key=api_key)
print(f"📊 [DEBUG] Using Finnhub API to get fundamental data for {ticker}...")
# Get basic financial data
try:
basic_financials = finnhub_client.company_basic_financials(ticker, 'all')
except Exception as e:
print(f"❌ [DEBUG] Failed to get Finnhub basic financials: {str(e)}")
basic_financials = None
# Get company profile
try:
company_profile = finnhub_client.company_profile2(symbol=ticker)
except Exception as e:
print(f"❌ [DEBUG] Failed to get Finnhub company profile: {str(e)}")
company_profile = None
# Get earnings data
try:
earnings = finnhub_client.company_earnings(ticker, limit=4)
except Exception as e:
print(f"❌ [DEBUG] Failed to get Finnhub earnings data: {str(e)}")
earnings = None
# Format report
report = f"# {ticker} Fundamental Analysis Report (Finnhub Data Source)\n\n"
report += f"**Data Retrieved**: {curr_date}\n"
report += f"**Data Source**: Finnhub API\n\n"
# Company profile section
if company_profile:
report += "## Company Profile\n"
report += f"- **Company Name**: {company_profile.get('name', 'N/A')}\n"
report += f"- **Industry**: {company_profile.get('finnhubIndustry', 'N/A')}\n"
report += f"- **Country**: {company_profile.get('country', 'N/A')}\n"
report += f"- **Currency**: {company_profile.get('currency', 'N/A')}\n"
report += f"- **Market Cap**: {company_profile.get('marketCapitalization', 'N/A')} million USD\n"
report += f"- **Shares Outstanding**: {company_profile.get('shareOutstanding', 'N/A')} million shares\n\n"
# Basic financial metrics
if basic_financials and 'metric' in basic_financials:
metrics = basic_financials['metric']
report += "## Key Financial Metrics\n"
# Valuation metrics
report += "### Valuation Metrics\n"
report += f"- **P/E Ratio (TTM)**: {metrics.get('peBasicExclExtraTTM', 'N/A')}\n"
report += f"- **P/B Ratio**: {metrics.get('pbAnnual', 'N/A')}\n"
report += f"- **P/S Ratio (TTM)**: {metrics.get('psAnnual', 'N/A')}\n"
report += f"- **EV/EBITDA (TTM)**: {metrics.get('evEbitdaTTM', 'N/A')}\n\n"
# Profitability metrics
report += "### Profitability Metrics\n"
report += f"- **ROE (TTM)**: {metrics.get('roeTTM', 'N/A')}%\n"
report += f"- **ROA (TTM)**: {metrics.get('roaTTM', 'N/A')}%\n"
report += f"- **Gross Margin (TTM)**: {metrics.get('grossMarginTTM', 'N/A')}%\n"
report += f"- **Net Margin (TTM)**: {metrics.get('netProfitMarginTTM', 'N/A')}%\n\n"
# Growth metrics
report += "### Growth Metrics\n"
report += f"- **Revenue Growth (5Y)**: {metrics.get('revenueGrowthTTMYoy', 'N/A')}%\n"
report += f"- **EPS Growth (5Y)**: {metrics.get('epsGrowthTTMYoy', 'N/A')}%\n\n"
# Earnings data
if earnings and len(earnings) > 0:
report += "## Recent Earnings\n"
for i, earning in enumerate(earnings[:4]): # Show last 4 quarters
report += f"### Q{i+1} (Period: {earning.get('period', 'N/A')})\n"
report += f"- **Actual EPS**: ${earning.get('actual', 'N/A')}\n"
report += f"- **Estimated EPS**: ${earning.get('estimate', 'N/A')}\n"
if earning.get('actual') and earning.get('estimate'):
surprise = earning['actual'] - earning['estimate']
report += f"- **Surprise**: ${surprise:.2f}\n"
report += "\n"
# Cache the result if cache is available
if cache:
try:
cache.save_stock_data(ticker, report, curr_date, curr_date, "finnhub_fundamentals")
print(f"💾 [DEBUG] Cached Finnhub fundamental data for {ticker}")
except Exception as e:
print(f"⚠️ [DEBUG] Failed to cache data: {e}")
print(f"✅ [DEBUG] Successfully retrieved Finnhub fundamental data for {ticker}")
return report
except ImportError:
return "Error: finnhub-python package not installed. Please install with: pip install finnhub-python"
except Exception as e:
error_msg = f"Error retrieving Finnhub fundamental data for {ticker}: {str(e)}"
print(f"❌ [DEBUG] {error_msg}")
return error_msg
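For reference, a minimal usage sketch of the fallback above, assuming it is exposed as get_fundamentals_finnhub in tradingagents.dataflows.interface (the name used in the diff further down) and that a Finnhub key is available; the ticker, date, and key are illustrative only:

import os
from tradingagents.dataflows.interface import get_fundamentals_finnhub  # assumed export

os.environ.setdefault("FINNHUB_API_KEY", "<your-key>")  # without a key the function returns an error string
report = get_fundamentals_finnhub("AAPL", "2025-07-06")
print(report[:500])  # Markdown report: profile, valuation, profitability, growth, recent earnings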

View File

@@ -1,807 +0,0 @@
from typing import Annotated, Dict
from .reddit_utils import fetch_top_from_category
from .yfin_utils import *
from .stockstats_utils import *
from .googlenews_utils import *
from .finnhub_utils import get_data_in_range
from dateutil.relativedelta import relativedelta
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import json
import os
import pandas as pd
from tqdm import tqdm
import yfinance as yf
from openai import OpenAI
from .config import get_config, set_config, DATA_DIR
def get_finnhub_news(
ticker: Annotated[
str,
"Search query of a company's, e.g. 'AAPL, TSM, etc.",
],
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
):
"""
Retrieve news about a company within a time frame
Args
ticker (str): ticker for the company you are interested in
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns
str: dataframe containing the news of the company in the time frame
"""
start_date = datetime.strptime(curr_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
result = get_data_in_range(ticker, before, curr_date, "news_data", DATA_DIR)
if len(result) == 0:
return ""
combined_result = ""
for day, data in result.items():
if len(data) == 0:
continue
for entry in data:
current_news = (
"### " + entry["headline"] + f" ({day})" + "\n" + entry["summary"]
)
combined_result += current_news + "\n\n"
return f"## {ticker} News, from {before} to {curr_date}:\n" + str(combined_result)
def get_finnhub_company_insider_sentiment(
ticker: Annotated[str, "ticker symbol for the company"],
curr_date: Annotated[
str,
"current date of you are trading at, yyyy-mm-dd",
],
look_back_days: Annotated[int, "number of days to look back"],
):
"""
Retrieve insider sentiment about a company (retrieved from public SEC information) for the past 15 days
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading on, yyyy-mm-dd
Returns:
str: a report of the sentiment in the past 15 days starting at curr_date
"""
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
data = get_data_in_range(ticker, before, curr_date, "insider_senti", DATA_DIR)
if len(data) == 0:
return ""
result_str = ""
seen_dicts = []
for date, senti_list in data.items():
for entry in senti_list:
if entry not in seen_dicts:
result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n"
seen_dicts.append(entry)
return (
f"## {ticker} Insider Sentiment Data for {before} to {curr_date}:\n"
+ result_str
+ "The change field refers to the net buying/selling from all insiders' transactions. The mspr field refers to monthly share purchase ratio."
)
def get_finnhub_company_insider_transactions(
ticker: Annotated[str, "ticker symbol"],
curr_date: Annotated[
str,
"current date you are trading at, yyyy-mm-dd",
],
look_back_days: Annotated[int, "how many days to look back"],
):
"""
Retrieve insider transcaction information about a company (retrieved from public SEC information) for the past 15 days
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the company's insider transaction/trading informtaion in the past 15 days
"""
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
data = get_data_in_range(ticker, before, curr_date, "insider_trans", DATA_DIR)
if len(data) == 0:
return ""
result_str = ""
seen_dicts = []
for date, senti_list in data.items():
for entry in senti_list:
if entry not in seen_dicts:
result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n"
seen_dicts.append(entry)
return (
f"## {ticker} insider transactions from {before} to {curr_date}:\n"
+ result_str
+ "The change field reflects the variation in share count—here a negative number indicates a reduction in holdings—while share specifies the total number of shares involved. The transactionPrice denotes the per-share price at which the trade was executed, and transactionDate marks when the transaction occurred. The name field identifies the insider making the trade, and transactionCode (e.g., S for sale) clarifies the nature of the transaction. FilingDate records when the transaction was officially reported, and the unique id links to the specific SEC filing, as indicated by the source. Additionally, the symbol ties the transaction to a particular company, isDerivative flags whether the trade involves derivative securities, and currency notes the currency context of the transaction."
)
def get_simfin_balance_sheet(
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[
str,
"reporting frequency of the company's financial history: annual / quarterly",
],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
data_path = os.path.join(
DATA_DIR,
"fundamental_data",
"simfin_data_all",
"balance_sheet",
"companies",
"us",
f"us-balance-{freq}.csv",
)
df = pd.read_csv(data_path, sep=";")
# Convert date strings to datetime objects and remove any time components
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
# Convert the current date to datetime and normalize
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No balance sheet available before the given current date.")
return ""
# Get the most recent balance sheet by selecting the row with the latest Publish Date
latest_balance_sheet = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
# drop the SimFinID column
latest_balance_sheet = latest_balance_sheet.drop("SimFinId")
return (
f"## {freq} balance sheet for {ticker} released on {str(latest_balance_sheet['Publish Date'])[0:10]}: \n"
+ str(latest_balance_sheet)
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of assets, liabilities, and equity. Assets are grouped as current (liquid items like cash and receivables) and noncurrent (long-term investments and property). Liabilities are split between short-term obligations and long-term debts, while equity reflects shareholder funds such as paid-in capital and retained earnings. Together, these components ensure that total assets equal the sum of liabilities and equity."
)
def get_simfin_cashflow(
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[
str,
"reporting frequency of the company's financial history: annual / quarterly",
],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
data_path = os.path.join(
DATA_DIR,
"fundamental_data",
"simfin_data_all",
"cash_flow",
"companies",
"us",
f"us-cashflow-{freq}.csv",
)
df = pd.read_csv(data_path, sep=";")
# Convert date strings to datetime objects and remove any time components
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
# Convert the current date to datetime and normalize
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No cash flow statement available before the given current date.")
return ""
# Get the most recent cash flow statement by selecting the row with the latest Publish Date
latest_cash_flow = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
# drop the SimFinID column
latest_cash_flow = latest_cash_flow.drop("SimFinId")
return (
f"## {freq} cash flow statement for {ticker} released on {str(latest_cash_flow['Publish Date'])[0:10]}: \n"
+ str(latest_cash_flow)
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of cash movements. Operating activities show cash generated from core business operations, including net income adjustments for non-cash items and working capital changes. Investing activities cover asset acquisitions/disposals and investments. Financing activities include debt transactions, equity issuances/repurchases, and dividend payments. The net change in cash represents the overall increase or decrease in the company's cash position during the reporting period."
)
def get_simfin_income_statements(
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[
str,
"reporting frequency of the company's financial history: annual / quarterly",
],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
data_path = os.path.join(
DATA_DIR,
"fundamental_data",
"simfin_data_all",
"income_statements",
"companies",
"us",
f"us-income-{freq}.csv",
)
df = pd.read_csv(data_path, sep=";")
# Convert date strings to datetime objects and remove any time components
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
# Convert the current date to datetime and normalize
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No income statement available before the given current date.")
return ""
# Get the most recent income statement by selecting the row with the latest Publish Date
latest_income = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
# drop the SimFinID column
latest_income = latest_income.drop("SimFinId")
return (
f"## {freq} income statement for {ticker} released on {str(latest_income['Publish Date'])[0:10]}: \n"
+ str(latest_income)
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a comprehensive breakdown of the company's financial performance. Starting with Revenue, it shows Cost of Revenue and resulting Gross Profit. Operating Expenses are detailed, including SG&A, R&D, and Depreciation. The statement then shows Operating Income, followed by non-operating items and Interest Expense, leading to Pretax Income. After accounting for Income Tax and any Extraordinary items, it concludes with Net Income, representing the company's bottom-line profit or loss for the period."
)
def get_google_news(
query: Annotated[str, "Query to search with"],
curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
) -> str:
query = query.replace(" ", "+")
start_date = datetime.strptime(curr_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
news_results = getNewsData(query, before, curr_date)
news_str = ""
for news in news_results:
news_str += (
f"### {news['title']} (source: {news['source']}) \n\n{news['snippet']}\n\n"
)
if len(news_results) == 0:
return ""
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
def get_reddit_global_news(
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
) -> str:
"""
Retrieve the latest top reddit news
Args:
start_date: Start date in yyyy-mm-dd format
end_date: End date in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
"""
start_date = datetime.strptime(start_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
posts = []
# iterate from start_date to end_date
curr_date = datetime.strptime(before, "%Y-%m-%d")
total_iterations = (start_date - curr_date).days + 1
pbar = tqdm(desc=f"Getting Global News on {start_date}", total=total_iterations)
while curr_date <= start_date:
curr_date_str = curr_date.strftime("%Y-%m-%d")
fetch_result = fetch_top_from_category(
"global_news",
curr_date_str,
max_limit_per_day,
data_path=os.path.join(DATA_DIR, "reddit_data"),
)
posts.extend(fetch_result)
curr_date += relativedelta(days=1)
pbar.update(1)
pbar.close()
if len(posts) == 0:
return ""
news_str = ""
for post in posts:
if post["content"] == "":
news_str += f"### {post['title']}\n\n"
else:
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
return f"## Global News Reddit, from {before} to {curr_date}:\n{news_str}"
def get_reddit_company_news(
ticker: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
) -> str:
"""
Retrieve the latest top reddit news
Args:
ticker: ticker symbol of the company
start_date: Start date in yyyy-mm-dd format
end_date: End date in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
"""
start_date = datetime.strptime(start_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
posts = []
# iterate from start_date to end_date
curr_date = datetime.strptime(before, "%Y-%m-%d")
total_iterations = (start_date - curr_date).days + 1
pbar = tqdm(
desc=f"Getting Company News for {ticker} on {start_date}",
total=total_iterations,
)
while curr_date <= start_date:
curr_date_str = curr_date.strftime("%Y-%m-%d")
fetch_result = fetch_top_from_category(
"company_news",
curr_date_str,
max_limit_per_day,
ticker,
data_path=os.path.join(DATA_DIR, "reddit_data"),
)
posts.extend(fetch_result)
curr_date += relativedelta(days=1)
pbar.update(1)
pbar.close()
if len(posts) == 0:
return ""
news_str = ""
for post in posts:
if post["content"] == "":
news_str += f"### {post['title']}\n\n"
else:
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
return f"##{ticker} News Reddit, from {before} to {curr_date}:\n\n{news_str}"
def get_stock_stats_indicators_window(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd"
],
look_back_days: Annotated[int, "how many days to look back"],
online: Annotated[bool, "to fetch data online or offline"],
) -> str:
best_ind_params = {
# Moving Averages
"close_50_sma": (
"50 SMA: A medium-term trend indicator. "
"Usage: Identify trend direction and serve as dynamic support/resistance. "
"Tips: It lags price; combine with faster indicators for timely signals."
),
"close_200_sma": (
"200 SMA: A long-term trend benchmark. "
"Usage: Confirm overall market trend and identify golden/death cross setups. "
"Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries."
),
"close_10_ema": (
"10 EMA: A responsive short-term average. "
"Usage: Capture quick shifts in momentum and potential entry points. "
"Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals."
),
# MACD Related
"macd": (
"MACD: Computes momentum via differences of EMAs. "
"Usage: Look for crossovers and divergence as signals of trend changes. "
"Tips: Confirm with other indicators in low-volatility or sideways markets."
),
"macds": (
"MACD Signal: An EMA smoothing of the MACD line. "
"Usage: Use crossovers with the MACD line to trigger trades. "
"Tips: Should be part of a broader strategy to avoid false positives."
),
"macdh": (
"MACD Histogram: Shows the gap between the MACD line and its signal. "
"Usage: Visualize momentum strength and spot divergence early. "
"Tips: Can be volatile; complement with additional filters in fast-moving markets."
),
# Momentum Indicators
"rsi": (
"RSI: Measures momentum to flag overbought/oversold conditions. "
"Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. "
"Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis."
),
# Volatility Indicators
"boll": (
"Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. "
"Usage: Acts as a dynamic benchmark for price movement. "
"Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals."
),
"boll_ub": (
"Bollinger Upper Band: Typically 2 standard deviations above the middle line. "
"Usage: Signals potential overbought conditions and breakout zones. "
"Tips: Confirm signals with other tools; prices may ride the band in strong trends."
),
"boll_lb": (
"Bollinger Lower Band: Typically 2 standard deviations below the middle line. "
"Usage: Indicates potential oversold conditions. "
"Tips: Use additional analysis to avoid false reversal signals."
),
"atr": (
"ATR: Averages true range to measure volatility. "
"Usage: Set stop-loss levels and adjust position sizes based on current market volatility. "
"Tips: It's a reactive measure, so use it as part of a broader risk management strategy."
),
# Volume-Based Indicators
"vwma": (
"VWMA: A moving average weighted by volume. "
"Usage: Confirm trends by integrating price action with volume data. "
"Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses."
),
"mfi": (
"MFI: The Money Flow Index is a momentum indicator that uses both price and volume to measure buying and selling pressure. "
"Usage: Identify overbought (>80) or oversold (<20) conditions and confirm the strength of trends or reversals. "
"Tips: Use alongside RSI or MACD to confirm signals; divergence between price and MFI can indicate potential reversals."
),
}
if indicator not in best_ind_params:
raise ValueError(
f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}"
)
end_date = curr_date
curr_date = datetime.strptime(curr_date, "%Y-%m-%d")
before = curr_date - relativedelta(days=look_back_days)
if not online:
# read from YFin data
data = pd.read_csv(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
data["Date"] = pd.to_datetime(data["Date"], utc=True)
dates_in_df = data["Date"].astype(str).str[:10]
ind_string = ""
while curr_date >= before:
# only do the trading dates
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
indicator_value = get_stockstats_indicator(
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
)
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
curr_date = curr_date - relativedelta(days=1)
else:
# online gathering
ind_string = ""
while curr_date >= before:
indicator_value = get_stockstats_indicator(
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
)
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
curr_date = curr_date - relativedelta(days=1)
result_str = (
f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n"
+ ind_string
+ "\n\n"
+ best_ind_params.get(indicator, "No description available.")
)
return result_str
def get_stockstats_indicator(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd"
],
online: Annotated[bool, "to fetch data online or offline"],
) -> str:
curr_date = datetime.strptime(curr_date, "%Y-%m-%d")
curr_date = curr_date.strftime("%Y-%m-%d")
try:
indicator_value = StockstatsUtils.get_stock_stats(
symbol,
indicator,
curr_date,
os.path.join(DATA_DIR, "market_data", "price_data"),
online=online,
)
except Exception as e:
print(
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
)
return ""
return str(indicator_value)
def get_YFin_data_window(
symbol: Annotated[str, "ticker symbol of the company"],
curr_date: Annotated[str, "Start date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
) -> str:
# calculate past days
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=look_back_days)
start_date = before.strftime("%Y-%m-%d")
# read in data
data = pd.read_csv(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
# Extract just the date part for comparison
data["DateOnly"] = data["Date"].str[:10]
# Filter data between the start and end dates (inclusive)
filtered_data = data[
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date)
]
# Drop the temporary column we created
filtered_data = filtered_data.drop("DateOnly", axis=1)
# Set pandas display options to show the full DataFrame
with pd.option_context(
"display.max_rows", None, "display.max_columns", None, "display.width", None
):
df_string = filtered_data.to_string()
return (
f"## Raw Market Data for {symbol} from {start_date} to {curr_date}:\n\n"
+ df_string
)
def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
):
datetime.strptime(start_date, "%Y-%m-%d")
datetime.strptime(end_date, "%Y-%m-%d")
# Create ticker object
ticker = yf.Ticker(symbol.upper())
# Fetch historical data for the specified date range
data = ticker.history(start=start_date, end=end_date)
# Check if data is empty
if data.empty:
return (
f"No data found for symbol '{symbol}' between {start_date} and {end_date}"
)
# Remove timezone info from index for cleaner output
if data.index.tz is not None:
data.index = data.index.tz_localize(None)
# Round numerical values to 2 decimal places for cleaner display
numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"]
for col in numeric_columns:
if col in data.columns:
data[col] = data[col].round(2)
# Convert DataFrame to CSV string
csv_string = data.to_csv()
# Add header information
header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
header += f"# Total records: {len(data)}\n"
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
return header + csv_string
def get_YFin_data(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
# read in data
data = pd.read_csv(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
if end_date > "2025-03-25":
raise Exception(
f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25"
)
# Extract just the date part for comparison
data["DateOnly"] = data["Date"].str[:10]
# Filter data between the start and end dates (inclusive)
filtered_data = data[
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date)
]
# Drop the temporary column we created
filtered_data = filtered_data.drop("DateOnly", axis=1)
# remove the index from the dataframe
filtered_data = filtered_data.reset_index(drop=True)
return filtered_data
def get_stock_news_openai(ticker, curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.",
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text
def get_global_news_openai(curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.",
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text
def get_fundamentals_openai(ticker, curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text
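As a rough usage sketch of the indicator helper defined above (a hypothetical call; it assumes either online=True with network access, or the bundled YFin CSV under DATA_DIR for offline mode):

report = get_stock_stats_indicators_window(
    symbol="AAPL",
    indicator="rsi",        # must be one of the keys in best_ind_params
    curr_date="2024-06-28",
    look_back_days=30,
    online=True,            # avoids the offline CSV dependency
)
print(report)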

View File

@@ -1,275 +0,0 @@
# File Difference Report
# Current file: tradingagents\dataflows\interface.py
# Chinese-version file: TradingAgentsCN\tradingagents\dataflows\interface.py
# Generated: Sunday, 2025/07/06
--- current/interface.py
+++ chinese_version/interface.py
@@ -1,5 +1,6 @@
from typing import Annotated, Dict
from .reddit_utils import fetch_top_from_category
+from .chinese_finance_utils import get_chinese_social_sentiment
from .yfin_utils import *
from .stockstats_utils import *
from .googlenews_utils import *
@@ -43,7 +44,14 @@ result = get_data_in_range(ticker, before, curr_date, "news_data", DATA_DIR)
if len(result) == 0:
- return ""
+ error_msg = f"⚠️ 无法获取{ticker}的新闻数据 ({before} 到 {curr_date})\n"
+ error_msg += f"可能的原因:\n"
+ error_msg += f"1. 数据文件不存在或路径配置错误\n"
+ error_msg += f"2. 指定日期范围内没有新闻数据\n"
+ error_msg += f"3. 需要先下载或更新Finnhub新闻数据\n"
+ error_msg += f"建议:检查数据目录配置或重新获取新闻数据"
+ print(f"📰 [DEBUG] {error_msg}")
+ return error_msg
combined_result = ""
for day, data in result.items():
@@ -772,36 +780,217 @@ return response.output[1].content[0].text
+def get_fundamentals_finnhub(ticker, curr_date):
+ """
+ 使用Finnhub API获取股票基本面数据作为OpenAI的备选方案
+ Args:
+ ticker (str): 股票代码
+ curr_date (str): 当前日期格式为yyyy-mm-dd
+ Returns:
+ str: 格式化的基本面数据报告
+ """
+ try:
+ import finnhub
+ import os
+ from .cache_manager import get_cache
+
+ # 检查缓存
+ cache = get_cache()
+ cached_key = cache.find_cached_fundamentals_data(ticker, data_source="finnhub")
+ if cached_key:
+ cached_data = cache.load_fundamentals_data(cached_key)
+ if cached_data:
+ print(f"💾 [DEBUG] 从缓存加载Finnhub基本面数据: {ticker}")
+ return cached_data
+
+ # 获取Finnhub API密钥
+ api_key = os.getenv('FINNHUB_API_KEY')
+ if not api_key:
+ return "错误未配置FINNHUB_API_KEY环境变量"
+
+ # 初始化Finnhub客户端
+ finnhub_client = finnhub.Client(api_key=api_key)
+
+ print(f"📊 [DEBUG] 使用Finnhub API获取 {ticker} 的基本面数据...")
+
+ # 获取基本财务数据
+ try:
+ basic_financials = finnhub_client.company_basic_financials(ticker, 'all')
+ except Exception as e:
+ print(f"❌ [DEBUG] Finnhub基本财务数据获取失败: {str(e)}")
+ basic_financials = None
+
+ # 获取公司概况
+ try:
+ company_profile = finnhub_client.company_profile2(symbol=ticker)
+ except Exception as e:
+ print(f"❌ [DEBUG] Finnhub公司概况获取失败: {str(e)}")
+ company_profile = None
+
+ # 获取收益数据
+ try:
+ earnings = finnhub_client.company_earnings(ticker, limit=4)
+ except Exception as e:
+ print(f"❌ [DEBUG] Finnhub收益数据获取失败: {str(e)}")
+ earnings = None
+
+ # 格式化报告
+ report = f"# {ticker} 基本面分析报告Finnhub数据源\n\n"
+ report += f"**数据获取时间**: {curr_date}\n"
+ report += f"**数据来源**: Finnhub API\n\n"
+
+ # 公司概况部分
+ if company_profile:
+ report += "## 公司概况\n"
+ report += f"- **公司名称**: {company_profile.get('name', 'N/A')}\n"
+ report += f"- **行业**: {company_profile.get('finnhubIndustry', 'N/A')}\n"
+ report += f"- **国家**: {company_profile.get('country', 'N/A')}\n"
+ report += f"- **货币**: {company_profile.get('currency', 'N/A')}\n"
+ report += f"- **市值**: {company_profile.get('marketCapitalization', 'N/A')} 百万美元\n"
+ report += f"- **流通股数**: {company_profile.get('shareOutstanding', 'N/A')} 百万股\n\n"
+
+ # 基本财务指标
+ if basic_financials and 'metric' in basic_financials:
+ metrics = basic_financials['metric']
+ report += "## 关键财务指标\n"
+ report += "| 指标 | 数值 |\n"
+ report += "|------|------|\n"
+
+ # 估值指标
+ if 'peBasicExclExtraTTM' in metrics:
+ report += f"| 市盈率 (PE) | {metrics['peBasicExclExtraTTM']:.2f} |\n"
+ if 'psAnnual' in metrics:
+ report += f"| 市销率 (PS) | {metrics['psAnnual']:.2f} |\n"
+ if 'pbAnnual' in metrics:
+ report += f"| 市净率 (PB) | {metrics['pbAnnual']:.2f} |\n"
+
+ # 盈利能力指标
+ if 'roeTTM' in metrics:
+ report += f"| 净资产收益率 (ROE) | {metrics['roeTTM']:.2f}% |\n"
+ if 'roaTTM' in metrics:
+ report += f"| 总资产收益率 (ROA) | {metrics['roaTTM']:.2f}% |\n"
+ if 'netProfitMarginTTM' in metrics:
+ report += f"| 净利润率 | {metrics['netProfitMarginTTM']:.2f}% |\n"
+
+ # 财务健康指标
+ if 'currentRatioAnnual' in metrics:
+ report += f"| 流动比率 | {metrics['currentRatioAnnual']:.2f} |\n"
+ if 'totalDebt/totalEquityAnnual' in metrics:
+ report += f"| 负债权益比 | {metrics['totalDebt/totalEquityAnnual']:.2f} |\n"
+
+ report += "\n"
+
+ # 收益历史
+ if earnings:
+ report += "## 收益历史\n"
+ report += "| 季度 | 实际EPS | 预期EPS | 差异 |\n"
+ report += "|------|---------|---------|------|\n"
+ for earning in earnings[:4]: # 显示最近4个季度
+ actual = earning.get('actual', 'N/A')
+ estimate = earning.get('estimate', 'N/A')
+ period = earning.get('period', 'N/A')
+ surprise = earning.get('surprise', 'N/A')
+ report += f"| {period} | {actual} | {estimate} | {surprise} |\n"
+ report += "\n"
+
+ # 数据可用性说明
+ report += "## 数据说明\n"
+ report += "- 本报告使用Finnhub API提供的官方财务数据\n"
+ report += "- 数据来源于公司财报和SEC文件\n"
+ report += "- TTM表示过去12个月数据\n"
+ report += "- Annual表示年度数据\n\n"
+
+ if not basic_financials and not company_profile and not earnings:
+ report += "⚠️ **警告**: 无法获取该股票的基本面数据,可能原因:\n"
+ report += "- 股票代码不正确\n"
+ report += "- Finnhub API限制\n"
+ report += "- 该股票暂无基本面数据\n"
+
+ # 保存到缓存
+ if report and len(report) > 100: # 只有当报告有实际内容时才缓存
+ cache.save_fundamentals_data(ticker, report, data_source="finnhub")
+
+ print(f"📊 [DEBUG] Finnhub基本面数据获取完成报告长度: {len(report)}")
+ return report
+
+ except ImportError:
+ return "错误未安装finnhub-python库请运行: pip install finnhub-python"
+ except Exception as e:
+ print(f"❌ [DEBUG] Finnhub基本面数据获取失败: {str(e)}")
+ return f"Finnhub基本面数据获取失败: {str(e)}"
+
+
def get_fundamentals_openai(ticker, curr_date):
- config = get_config()
- client = OpenAI(base_url=config["backend_url"])
-
- response = client.responses.create(
- model=config["quick_think_llm"],
- input=[
- {
- "role": "system",
- "content": [
- {
- "type": "input_text",
- "text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
- }
- ],
- }
- ],
- text={"format": {"type": "text"}},
- reasoning={},
- tools=[
- {
- "type": "web_search_preview",
- "user_location": {"type": "approximate"},
- "search_context_size": "low",
- }
- ],
- temperature=1,
- max_output_tokens=4096,
- top_p=1,
- store=True,
- )
-
- return response.output[1].content[0].text
+ """
+ 获取股票基本面数据优先使用OpenAI失败时回退到Finnhub API
+ 支持缓存机制以提高性能
+ Args:
+ ticker (str): 股票代码
+ curr_date (str): 当前日期格式为yyyy-mm-dd
+ Returns:
+ str: 基本面数据报告
+ """
+ try:
+ from .cache_manager import get_cache
+
+ # 检查缓存 - 优先检查OpenAI缓存
+ cache = get_cache()
+ cached_key = cache.find_cached_fundamentals_data(ticker, data_source="openai")
+ if cached_key:
+ cached_data = cache.load_fundamentals_data(cached_key)
+ if cached_data:
+ print(f"💾 [DEBUG] 从缓存加载OpenAI基本面数据: {ticker}")
+ return cached_data
+
+ config = get_config()
+
+ # 检查是否配置了OpenAI相关设置
+ if not config.get("backend_url") or not config.get("quick_think_llm"):
+ print(f"📊 [DEBUG] OpenAI配置不完整直接使用Finnhub API")
+ return get_fundamentals_finnhub(ticker, curr_date)
+
+ print(f"📊 [DEBUG] 尝试使用OpenAI获取 {ticker} 的基本面数据...")
+
+ client = OpenAI(base_url=config["backend_url"])
+
+ response = client.responses.create(
+ model=config["quick_think_llm"],
+ input=[
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "input_text",
+ "text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
+ }
+ ],
+ }
+ ],
+ text={"format": {"type": "text"}},
+ reasoning={},
+ tools=[
+ {
+ "type": "web_search_preview",
+ "user_location": {"type": "approximate"},
+ "search_context_size": "low",
+ }
+ ],
+ temperature=1,
+ max_output_tokens=4096,
+ top_p=1,
+ store=True,
+ )
+
+ result = response.output[1].content[0].text
+
+ # 保存到缓存
+ if result and len(result) > 100: # 只有当结果有实际内容时才缓存
+ cache.save_fundamentals_data(ticker, result, data_source="openai")
+
+ print(f"📊 [DEBUG] OpenAI基本面数据获取成功长度: {len(result)}")
+ return result
+
+ except Exception as e:
+ print(f"❌ [DEBUG] OpenAI基本面数据获取失败: {str(e)}")
+ print(f"📊 [DEBUG] 回退到Finnhub API...")
+ return get_fundamentals_finnhub(ticker, curr_date)
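In short, the revised get_fundamentals_openai above checks the cache, tries OpenAI only when backend_url and quick_think_llm are configured, and otherwise (or on any failure) falls back to get_fundamentals_finnhub. A minimal illustrative call of that entry point (ticker, date, and key are examples; FINNHUB_API_KEY is assumed so the fallback path can succeed):

import os
os.environ.setdefault("FINNHUB_API_KEY", "<your-key>")  # needed for the Finnhub fallback path
print(get_fundamentals_openai("MSFT", "2025-07-06")[:300])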

View File

@@ -1,404 +0,0 @@
#!/usr/bin/env python3
"""
Optimized US Stock Data Fetcher
Integrates caching strategy to reduce API calls and improve response speed
"""
import os
import time
import random
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
import yfinance as yf
import pandas as pd
from .cache_manager import get_cache
from .config import get_config
class OptimizedUSDataProvider:
"""Optimized US Stock Data Provider - Integrates caching and API rate limiting"""
def __init__(self):
self.cache = get_cache()
self.config = get_config()
self.last_api_call = 0
self.min_api_interval = 1.0 # Minimum API call interval (seconds)
print("📊 Optimized US stock data provider initialized")
def _wait_for_rate_limit(self):
"""Wait for API rate limit"""
current_time = time.time()
time_since_last_call = current_time - self.last_api_call
if time_since_last_call < self.min_api_interval:
wait_time = self.min_api_interval - time_since_last_call
print(f"⏳ API rate limit wait {wait_time:.1f}s...")
time.sleep(wait_time)
self.last_api_call = time.time()
def get_stock_data(self, symbol: str, start_date: str, end_date: str,
force_refresh: bool = False) -> str:
"""
Get US stock data - prioritize cache usage
Args:
symbol: Stock symbol
start_date: Start date (YYYY-MM-DD)
end_date: End date (YYYY-MM-DD)
force_refresh: Whether to force refresh cache
Returns:
Formatted stock data string
"""
try:
# Check cache first (unless force refresh)
if not force_refresh:
cache_key = self.cache.find_cached_stock_data(
symbol, start_date, end_date, "optimized_yfinance"
)
if cache_key and self.cache.is_cache_valid(cache_key, symbol):
cached_data = self.cache.load_stock_data(cache_key)
if cached_data:
print(f"📖 Using cached data for {symbol}")
if isinstance(cached_data, pd.DataFrame):
return self._format_stock_data(cached_data, symbol)
else:
return cached_data
# Fetch new data from API
print(f"🌐 Fetching new data for {symbol} from {start_date} to {end_date}")
# Wait for rate limit
self._wait_for_rate_limit()
# Try Yahoo Finance first
try:
data = self._fetch_from_yfinance(symbol, start_date, end_date)
if data is not None and not data.empty:
# Cache the DataFrame
cache_key = self.cache.save_stock_data(
symbol, data, start_date, end_date, "optimized_yfinance"
)
# Format and return
formatted_data = self._format_stock_data(data, symbol)
print(f"✅ Successfully fetched and cached data for {symbol}")
return formatted_data
else:
print(f"⚠️ No data returned from Yahoo Finance for {symbol}")
except Exception as e:
print(f"❌ Yahoo Finance error for {symbol}: {e}")
# Fallback: Try FINNHUB (if API key available)
try:
finnhub_data = self._fetch_from_finnhub(symbol, start_date, end_date)
if finnhub_data:
# Cache the string data
cache_key = self.cache.save_stock_data(
symbol, finnhub_data, start_date, end_date, "optimized_finnhub"
)
print(f"✅ Successfully fetched data from FINNHUB for {symbol}")
return finnhub_data
except Exception as e:
print(f"❌ FINNHUB error for {symbol}: {e}")
# If all fails, return error message
error_msg = f"❌ Failed to fetch data for {symbol} from {start_date} to {end_date}"
print(error_msg)
return error_msg
except Exception as e:
error_msg = f"❌ Unexpected error fetching data for {symbol}: {e}"
print(error_msg)
return error_msg
def _fetch_from_yfinance(self, symbol: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""Fetch data from Yahoo Finance"""
try:
ticker = yf.Ticker(symbol)
data = ticker.history(start=start_date, end=end_date)
if data.empty:
print(f"⚠️ No data available for {symbol} in the specified date range")
return None
# Reset index to make Date a column
data = data.reset_index()
# Ensure we have the required columns
required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
missing_columns = [col for col in required_columns if col not in data.columns]
if missing_columns:
print(f"⚠️ Missing columns for {symbol}: {missing_columns}")
return None
return data[required_columns]
except Exception as e:
print(f"❌ Yahoo Finance fetch error for {symbol}: {e}")
return None
def _fetch_from_finnhub(self, symbol: str, start_date: str, end_date: str) -> Optional[str]:
"""Fetch data from FINNHUB API"""
try:
# Check if FINNHUB API key is available
finnhub_api_key = os.getenv('FINNHUB_API_KEY')
if not finnhub_api_key:
print("⚠️ FINNHUB API key not found, skipping FINNHUB data fetch")
return None
import finnhub
# Initialize FINNHUB client
finnhub_client = finnhub.Client(api_key=finnhub_api_key)
# Convert dates to timestamps
start_timestamp = int(datetime.strptime(start_date, '%Y-%m-%d').timestamp())
end_timestamp = int(datetime.strptime(end_date, '%Y-%m-%d').timestamp())
# Fetch candle data
candle_data = finnhub_client.stock_candles(symbol, 'D', start_timestamp, end_timestamp)
if candle_data['s'] != 'ok':
print(f"⚠️ FINNHUB returned status: {candle_data['s']} for {symbol}")
return None
# Format data
formatted_data = self._format_finnhub_data(candle_data, symbol)
return formatted_data
except ImportError:
print("⚠️ finnhub-python package not installed, skipping FINNHUB data fetch")
return None
except Exception as e:
print(f"❌ FINNHUB fetch error for {symbol}: {e}")
return None
def _format_stock_data(self, data: pd.DataFrame, symbol: str) -> str:
"""Format DataFrame stock data into string"""
try:
# Ensure Date column is properly formatted
if 'Date' in data.columns:
data['Date'] = pd.to_datetime(data['Date']).dt.strftime('%Y-%m-%d')
# Round numerical columns to 2 decimal places
numeric_columns = ['Open', 'High', 'Low', 'Close']
for col in numeric_columns:
if col in data.columns:
data[col] = data[col].round(2)
# Format volume as integer
if 'Volume' in data.columns:
data['Volume'] = data['Volume'].astype(int)
# Create formatted string
formatted_lines = [f"Stock Data for {symbol}:"]
formatted_lines.append("Date,Open,High,Low,Close,Volume")
for _, row in data.iterrows():
line = f"{row['Date']},{row['Open']},{row['High']},{row['Low']},{row['Close']},{row['Volume']}"
formatted_lines.append(line)
# Add summary statistics
if len(data) > 0:
formatted_lines.append(f"\nSummary for {symbol}:")
formatted_lines.append(f"Period: {data['Date'].iloc[0]} to {data['Date'].iloc[-1]}")
formatted_lines.append(f"Total trading days: {len(data)}")
formatted_lines.append(f"Average volume: {data['Volume'].mean():,.0f}")
formatted_lines.append(f"Price range: ${data['Low'].min():.2f} - ${data['High'].max():.2f}")
# Calculate basic statistics
start_price = data['Open'].iloc[0]
end_price = data['Close'].iloc[-1]
price_change = end_price - start_price
price_change_pct = (price_change / start_price) * 100
formatted_lines.append(f"Period return: {price_change_pct:+.2f}% (${price_change:+.2f})")
return "\n".join(formatted_lines)
except Exception as e:
print(f"❌ Error formatting stock data for {symbol}: {e}")
return f"Error formatting data for {symbol}: {str(e)}"
def _format_finnhub_data(self, candle_data: Dict, symbol: str) -> str:
"""Format FINNHUB candle data into string"""
try:
# Extract data arrays
timestamps = candle_data['t']
opens = candle_data['o']
highs = candle_data['h']
lows = candle_data['l']
closes = candle_data['c']
volumes = candle_data['v']
# Create formatted string
formatted_lines = [f"Stock Data for {symbol} (FINNHUB):"]
formatted_lines.append("Date,Open,High,Low,Close,Volume")
for i in range(len(timestamps)):
date = datetime.fromtimestamp(timestamps[i]).strftime('%Y-%m-%d')
line = f"{date},{opens[i]:.2f},{highs[i]:.2f},{lows[i]:.2f},{closes[i]:.2f},{int(volumes[i])}"
formatted_lines.append(line)
# Add summary
if len(timestamps) > 0:
start_date = datetime.fromtimestamp(timestamps[0]).strftime('%Y-%m-%d')
end_date = datetime.fromtimestamp(timestamps[-1]).strftime('%Y-%m-%d')
formatted_lines.append(f"\nSummary for {symbol}:")
formatted_lines.append(f"Period: {start_date} to {end_date}")
formatted_lines.append(f"Total trading days: {len(timestamps)}")
formatted_lines.append(f"Average volume: {sum(volumes)/len(volumes):,.0f}")
formatted_lines.append(f"Price range: ${min(lows):.2f} - ${max(highs):.2f}")
# Calculate return
price_change = closes[-1] - opens[0]
price_change_pct = (price_change / opens[0]) * 100
formatted_lines.append(f"Period return: {price_change_pct:+.2f}% (${price_change:+.2f})")
return "\n".join(formatted_lines)
except Exception as e:
print(f"❌ Error formatting FINNHUB data for {symbol}: {e}")
return f"Error formatting FINNHUB data for {symbol}: {str(e)}"
def get_stock_with_indicators(self, symbol: str, start_date: str, end_date: str,
indicators: list = None) -> str:
"""
Get stock data with technical indicators
Args:
symbol: Stock symbol
start_date: Start date (YYYY-MM-DD)
end_date: End date (YYYY-MM-DD)
indicators: List of indicators to calculate ['sma_20', 'rsi', 'macd']
Returns:
Formatted stock data with indicators
"""
try:
# Get basic stock data
basic_data = self.get_stock_data(symbol, start_date, end_date)
if basic_data.startswith("❌"):
return basic_data
# If no indicators requested, return basic data
if not indicators:
return basic_data
# Fetch DataFrame for indicator calculation
data_df = self._fetch_from_yfinance(symbol, start_date, end_date)
if data_df is None or data_df.empty:
return basic_data
# Calculate indicators
indicator_data = self._calculate_indicators(data_df, indicators)
# Combine basic data with indicators
combined_data = basic_data + "\n\nTechnical Indicators:\n" + indicator_data
return combined_data
except Exception as e:
error_msg = f"❌ Error getting stock data with indicators for {symbol}: {e}"
print(error_msg)
return error_msg
def _calculate_indicators(self, data: pd.DataFrame, indicators: list) -> str:
"""Calculate technical indicators"""
try:
indicator_lines = []
for indicator in indicators:
if indicator == 'sma_20':
data['SMA_20'] = data['Close'].rolling(window=20).mean()
latest_sma = data['SMA_20'].iloc[-1]
indicator_lines.append(f"SMA(20): ${latest_sma:.2f}")
elif indicator == 'rsi':
# Simple RSI calculation
delta = data['Close'].diff()
gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
rs = gain / loss
rsi = 100 - (100 / (1 + rs))
latest_rsi = rsi.iloc[-1]
indicator_lines.append(f"RSI(14): {latest_rsi:.2f}")
elif indicator == 'macd':
# Simple MACD calculation
ema_12 = data['Close'].ewm(span=12).mean()
ema_26 = data['Close'].ewm(span=26).mean()
macd_line = ema_12 - ema_26
signal_line = macd_line.ewm(span=9).mean()
latest_macd = macd_line.iloc[-1]
latest_signal = signal_line.iloc[-1]
indicator_lines.append(f"MACD: {latest_macd:.4f}, Signal: {latest_signal:.4f}")
return "\n".join(indicator_lines)
except Exception as e:
print(f"❌ Error calculating indicators: {e}")
return f"Error calculating indicators: {str(e)}"
# Global provider instance
_global_provider = None
def get_optimized_us_data_provider() -> OptimizedUSDataProvider:
"""
Get global optimized US data provider instance
Returns:
OptimizedUSDataProvider instance
"""
global _global_provider
if _global_provider is None:
_global_provider = OptimizedUSDataProvider()
return _global_provider
# Convenience functions
def get_optimized_stock_data(symbol: str, start_date: str, end_date: str,
force_refresh: bool = False) -> str:
"""Get optimized stock data (convenience function)"""
provider = get_optimized_us_data_provider()
return provider.get_stock_data(symbol, start_date, end_date, force_refresh)
def get_stock_with_indicators(symbol: str, start_date: str, end_date: str,
indicators: list = None) -> str:
"""Get stock data with technical indicators (convenience function)"""
provider = get_optimized_us_data_provider()
return provider.get_stock_with_indicators(symbol, start_date, end_date, indicators)
if __name__ == "__main__":
# Test the optimized data provider
print("🧪 Testing Optimized US Data Provider...")
# Initialize provider
provider = OptimizedUSDataProvider()
# Test data fetch
data = provider.get_stock_data("AAPL", "2024-01-01", "2024-01-31")
print("Sample data:")
print(data[:500] + "..." if len(data) > 500 else data)
# Test with indicators
data_with_indicators = provider.get_stock_with_indicators(
"AAPL", "2024-01-01", "2024-01-31",
indicators=['sma_20', 'rsi', 'macd']
)
print("\nData with indicators:")
print(data_with_indicators[-500:] if len(data_with_indicators) > 500 else data_with_indicators)
print("✅ Optimized data provider test completed!")

View File

@@ -1,679 +0,0 @@
# File Difference Report
# Current file: tradingagents\dataflows\optimized_us_data.py
# Chinese-version file: TradingAgentsCN\tradingagents\dataflows\optimized_us_data.py
# Generated: Sunday, 2025/07/06
--- current/optimized_us_data.py
+++ chinese_version/optimized_us_data.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
-Optimized US Stock Data Fetcher
-Integrates caching strategy to reduce API calls and improve response speed
+优化的美股数据获取工具
+集成缓存策略减少API调用提高响应速度
"""
import os
@@ -16,24 +16,24 @@
class OptimizedUSDataProvider:
- """Optimized US Stock Data Provider - Integrates caching and API rate limiting"""
+ """优化的美股数据提供器 - 集成缓存和API限制处理"""
def __init__(self):
self.cache = get_cache()
self.config = get_config()
self.last_api_call = 0
- self.min_api_interval = 1.0 # Minimum API call interval (seconds)
-
- print("📊 Optimized US stock data provider initialized")
+ self.min_api_interval = 1.0 # 最小API调用间隔
+
+ print("📊 优化美股数据提供器初始化完成")
def _wait_for_rate_limit(self):
- """Wait for API rate limit"""
+ """等待API限制"""
current_time = time.time()
time_since_last_call = current_time - self.last_api_call
if time_since_last_call < self.min_api_interval:
wait_time = self.min_api_interval - time_since_last_call
- print(f"⏳ API rate limit wait {wait_time:.1f}s...")
+ print(f"⏳ API限制等待 {wait_time:.1f}s...")
time.sleep(wait_time)
self.last_api_call = time.time()
@@ -41,364 +41,292 @@ def get_stock_data(self, symbol: str, start_date: str, end_date: str,
force_refresh: bool = False) -> str:
"""
- Get US stock data - prioritize cache usage
+ 获取美股数据 - 优先使用缓存
Args:
- symbol: Stock symbol
- start_date: Start date (YYYY-MM-DD)
- end_date: End date (YYYY-MM-DD)
- force_refresh: Whether to force refresh cache
-
+ symbol: 股票代码
+ start_date: 开始日期 (YYYY-MM-DD)
+ end_date: 结束日期 (YYYY-MM-DD)
+ force_refresh: 是否强制刷新缓存
+
Returns:
- Formatted stock data string
+ 格式化的股票数据字符串
"""
+ print(f"📈 获取美股数据: {symbol} ({start_date} 到 {end_date})")
+
+ # 检查缓存(除非强制刷新)
+ if not force_refresh:
+ # 优先查找FINNHUB缓存
+ cache_key = self.cache.find_cached_stock_data(
+ symbol=symbol,
+ start_date=start_date,
+ end_date=end_date,
+ data_source="finnhub"
+ )
+
+ # 如果没有FINNHUB缓存查找Yahoo Finance缓存
+ if not cache_key:
+ cache_key = self.cache.find_cached_stock_data(
+ symbol=symbol,
+ start_date=start_date,
+ end_date=end_date,
+ data_source="yfinance"
+ )
+
+ if cache_key:
+ cached_data = self.cache.load_stock_data(cache_key)
+ if cached_data:
+ print(f"⚡ 从缓存加载美股数据: {symbol}")
+ return cached_data
+
+ # 缓存未命中从API获取 - 优先使用FINNHUB
+ formatted_data = None
+ data_source = None
+
+ # 尝试FINNHUB API优先
try:
- # Check cache first (unless force refresh)
- if not force_refresh:
- cache_key = self.cache.find_cached_stock_data(
- symbol, start_date, end_date, "optimized_yfinance"
- )
-
- if cache_key and self.cache.is_cache_valid(cache_key, symbol):
- cached_data = self.cache.load_stock_data(cache_key)
- if cached_data:
- print(f"📖 Using cached data for {symbol}")
- if isinstance(cached_data, pd.DataFrame):
- return self._format_stock_data(cached_data, symbol)
- else:
- return cached_data
-
- # Fetch new data from API
- print(f"🌐 Fetching new data for {symbol} from {start_date} to {end_date}")
-
- # Wait for rate limit
+ print(f"🌐 从FINNHUB API获取数据: {symbol}")
self._wait_for_rate_limit()
-
- # Try Yahoo Finance first
+
+ formatted_data = self._get_data_from_finnhub(symbol, start_date, end_date)
+ if formatted_data and "❌" not in formatted_data:
+ data_source = "finnhub"
+ print(f"✅ FINNHUB数据获取成功: {symbol}")
+ else:
+ print(f"⚠️ FINNHUB数据获取失败尝试备用方案")
+ formatted_data = None
+
+ except Exception as e:
+ print(f"❌ FINNHUB API调用失败: {e}")
+ formatted_data = None
+
+ # 备用方案Yahoo Finance API
+ if not formatted_data:
try:
- data = self._fetch_from_yfinance(symbol, start_date, end_date)
- if data is not None and not data.empty:
- # Cache the DataFrame
- cache_key = self.cache.save_stock_data(
- symbol, data, start_date, end_date, "optimized_yfinance"
- )
+ print(f"🌐 从Yahoo Finance API获取数据: {symbol}")
+ self._wait_for_rate_limit()
+
+ # 获取数据
+ ticker = yf.Ticker(symbol.upper())
+ data = ticker.history(start=start_date, end=end_date)
+
+ if data.empty:
+ error_msg = f"未找到股票 '{symbol}' 在 {start_date} 到 {end_date} 期间的数据"
+ print(f"❌ {error_msg}")
+ else:
+ # 格式化数据
+ formatted_data = self._format_stock_data(symbol, data, start_date, end_date)
+ data_source = "yfinance"
+ print(f"✅ Yahoo Finance数据获取成功: {symbol}")
+
+ except Exception as e:
+ print(f"❌ Yahoo Finance API调用失败: {e}")
+ formatted_data = None
+
+ # 如果所有API都失败生成备用数据
+ if not formatted_data:
+ error_msg = "所有美股数据源都不可用"
+ print(f"❌ {error_msg}")
+ return self._generate_fallback_data(symbol, start_date, end_date, error_msg)
+
+ # 保存到缓存
+ self.cache.save_stock_data(
+ symbol=symbol,
+ data=formatted_data,
+ start_date=start_date,
+ end_date=end_date,
+ data_source=data_source
+ )
+
+ return formatted_data
+
+ def _format_stock_data(self, symbol: str, data: pd.DataFrame,
+ start_date: str, end_date: str) -> str:
+ """格式化股票数据为字符串"""
+
+ # 移除时区信息
+ if data.index.tz is not None:
+ data.index = data.index.tz_localize(None)
+
+ # 四舍五入数值
+ numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"]
+ for col in numeric_columns:
+ if col in data.columns:
+ data[col] = data[col].round(2)
+
+ # 获取最新价格和统计信息
+ latest_price = data['Close'].iloc[-1]
+ price_change = data['Close'].iloc[-1] - data['Close'].iloc[0]
+ price_change_pct = (price_change / data['Close'].iloc[0]) * 100
+
+ # 计算技术指标
+ data['MA5'] = data['Close'].rolling(window=5).mean()
+ data['MA10'] = data['Close'].rolling(window=10).mean()
+ data['MA20'] = data['Close'].rolling(window=20).mean()
+
+ # 计算RSI
+ delta = data['Close'].diff()
+ gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
+ loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
+ rs = gain / loss
+ rsi = 100 - (100 / (1 + rs))
+
+ # 格式化输出
+ result = f"""# {symbol} 美股数据分析
+
+## 📊 基本信息
+- 股票代码: {symbol}
+- 数据期间: {start_date} 至 {end_date}
+- 数据条数: {len(data)}条
+- 最新价格: ${latest_price:.2f}
+- 期间涨跌: ${price_change:+.2f} ({price_change_pct:+.2f}%)
+
+## 📈 价格统计
+- 期间最高: ${data['High'].max():.2f}
+- 期间最低: ${data['Low'].min():.2f}
+- 平均成交量: {data['Volume'].mean():,.0f}
+
+## 🔍 技术指标
+- MA5: ${data['MA5'].iloc[-1]:.2f}
+- MA10: ${data['MA10'].iloc[-1]:.2f}
+- MA20: ${data['MA20'].iloc[-1]:.2f}
+- RSI: {rsi.iloc[-1]:.2f}
+
+## 📋 最近5日数据
+{data.tail().to_string()}
+
+数据来源: Yahoo Finance API
+更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
+"""
+
+ return result
+
+ def _try_get_old_cache(self, symbol: str, start_date: str, end_date: str) -> Optional[str]:
+ """尝试获取过期的缓存数据作为备用"""
+ try:
+ # 查找任何相关的缓存不考虑TTL
+ for metadata_file in self.cache.metadata_dir.glob(f"*_meta.json"):
+ try:
+ import json
+ with open(metadata_file, 'r', encoding='utf-8') as f:
+ metadata = json.load(f)
- # Format and return
- formatted_data = self._format_stock_data(data, symbol)
- print(f"✅ Successfully fetched and cached data for {symbol}")
- return formatted_data
- else:
- print(f"⚠️ No data returned from Yahoo Finance for {symbol}")
-
- except Exception as e:
- print(f"❌ Yahoo Finance error for {symbol}: {e}")
-
- # Fallback: Try FINNHUB (if API key available)
- try:
- finnhub_data = self._fetch_from_finnhub(symbol, start_date, end_date)
- if finnhub_data:
- # Cache the string data
- cache_key = self.cache.save_stock_data(
- symbol, finnhub_data, start_date, end_date, "optimized_finnhub"
- )
- print(f"✅ Successfully fetched data from FINNHUB for {symbol}")
- return finnhub_data
-
- except Exception as e:
- print(f"❌ FINNHUB error for {symbol}: {e}")
-
- # If all fails, return error message
- error_msg = f"❌ Failed to fetch data for {symbol} from {start_date} to {end_date}"
- print(error_msg)
- return error_msg
-
+ if (metadata.get('symbol') == symbol and
+ metadata.get('data_type') == 'stock_data' and
+ metadata.get('market_type') == 'us'):
+
+ cache_key = metadata_file.stem.replace('_meta', '')
+ cached_data = self.cache.load_stock_data(cache_key)
+ if cached_data:
+ return cached_data + "\n\n⚠ 注意: 使用的是过期缓存数据"
+ except Exception:
+ continue
+ except Exception:
+ pass
+
+ return None
+
+ def _get_data_from_finnhub(self, symbol: str, start_date: str, end_date: str) -> str:
+ """从FINNHUB API获取股票数据"""
+ try:
+ import finnhub
+ import os
+ from datetime import datetime, timedelta
+
+ # 获取API密钥
+ api_key = os.getenv('FINNHUB_API_KEY')
+ if not api_key:
+ return None
+
+ client = finnhub.Client(api_key=api_key)
+
+ # 获取实时报价
+ quote = client.quote(symbol.upper())
+ if not quote or 'c' not in quote:
+ return None
+
+ # 获取公司信息
+ profile = client.company_profile2(symbol=symbol.upper())
+ company_name = profile.get('name', symbol.upper()) if profile else symbol.upper()
+
+ # 格式化数据
+ current_price = quote.get('c', 0)
+ change = quote.get('d', 0)
+ change_percent = quote.get('dp', 0)
+
+ formatted_data = f"""# {symbol.upper()} 美股数据分析
+
+## 📊 实时行情
+- 股票名称: {company_name}
+- 当前价格: ${current_price:.2f}
+- 涨跌额: ${change:+.2f}
+- 涨跌幅: {change_percent:+.2f}%
+- 开盘价: ${quote.get('o', 0):.2f}
+- 最高价: ${quote.get('h', 0):.2f}
+- 最低价: ${quote.get('l', 0):.2f}
+- 前收盘: ${quote.get('pc', 0):.2f}
+- 更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
+
+## 📈 数据概览
+- 数据期间: {start_date} 至 {end_date}
+- 数据来源: FINNHUB API (实时数据)
+- 当前价位相对位置: {((current_price - quote.get('l', current_price)) / max(quote.get('h', current_price) - quote.get('l', current_price), 0.01) * 100):.1f}%
+- 日内振幅: {((quote.get('h', 0) - quote.get('l', 0)) / max(quote.get('pc', 1), 0.01) * 100):.2f}%
+
+生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
+"""
+
+ return formatted_data
+
except Exception as e:
- error_msg = f"❌ Unexpected error fetching data for {symbol}: {e}"
- print(error_msg)
- return error_msg
-
- def _fetch_from_yfinance(self, symbol: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
- """Fetch data from Yahoo Finance"""
- try:
- ticker = yf.Ticker(symbol)
- data = ticker.history(start=start_date, end=end_date)
-
- if data.empty:
- print(f"⚠️ No data available for {symbol} in the specified date range")
- return None
-
- # Reset index to make Date a column
- data = data.reset_index()
-
- # Ensure we have the required columns
- required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
- missing_columns = [col for col in required_columns if col not in data.columns]
-
- if missing_columns:
- print(f"⚠️ Missing columns for {symbol}: {missing_columns}")
- return None
-
- return data[required_columns]
-
- except Exception as e:
- print(f"❌ Yahoo Finance fetch error for {symbol}: {e}")
+ print(f"❌ FINNHUB数据获取失败: {e}")
return None
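For reference, the single-letter fields read from the FINNHUB quote above, and the two derived ratios, traced through a worked example with assumed (not real) values:

# quote = {"c": 148.0, "d": 2.0, "dp": 1.37, "o": 146.5, "h": 150.0, "l": 145.0, "pc": 146.0}
# current 148.00, change +2.00 (+1.37%), open 146.50, high 150.00, low 145.00, prev close 146.00
# relative position = (148.0 - 145.0) / max(150.0 - 145.0, 0.01) * 100 = 60.0%
# intraday amplitude = (150.0 - 145.0) / max(146.0, 0.01) * 100 ≈ 3.42%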
-
- def _fetch_from_finnhub(self, symbol: str, start_date: str, end_date: str) -> Optional[str]:
- """Fetch data from FINNHUB API"""
- try:
- # Check if FINNHUB API key is available
- finnhub_api_key = os.getenv('FINNHUB_API_KEY')
- if not finnhub_api_key:
- print("⚠️ FINNHUB API key not found, skipping FINNHUB data fetch")
- return None
-
- import finnhub
-
- # Initialize FINNHUB client
- finnhub_client = finnhub.Client(api_key=finnhub_api_key)
-
- # Convert dates to timestamps
- start_timestamp = int(datetime.strptime(start_date, '%Y-%m-%d').timestamp())
- end_timestamp = int(datetime.strptime(end_date, '%Y-%m-%d').timestamp())
-
- # Fetch candle data
- candle_data = finnhub_client.stock_candles(symbol, 'D', start_timestamp, end_timestamp)
-
- if candle_data['s'] != 'ok':
- print(f"⚠️ FINNHUB returned status: {candle_data['s']} for {symbol}")
- return None
-
- # Format data
- formatted_data = self._format_finnhub_data(candle_data, symbol)
- return formatted_data
-
- except ImportError:
- print("⚠️ finnhub-python package not installed, skipping FINNHUB data fetch")
- return None
- except Exception as e:
- print(f"❌ FINNHUB fetch error for {symbol}: {e}")
- return None
-
- def _format_stock_data(self, data: pd.DataFrame, symbol: str) -> str:
- """Format DataFrame stock data into string"""
- try:
- # Ensure Date column is properly formatted
- if 'Date' in data.columns:
- data['Date'] = pd.to_datetime(data['Date']).dt.strftime('%Y-%m-%d')
-
- # Round numerical columns to 2 decimal places
- numeric_columns = ['Open', 'High', 'Low', 'Close']
- for col in numeric_columns:
- if col in data.columns:
- data[col] = data[col].round(2)
-
- # Format volume as integer
- if 'Volume' in data.columns:
- data['Volume'] = data['Volume'].astype(int)
-
- # Create formatted string
- formatted_lines = [f"Stock Data for {symbol}:"]
- formatted_lines.append("Date,Open,High,Low,Close,Volume")
-
- for _, row in data.iterrows():
- line = f"{row['Date']},{row['Open']},{row['High']},{row['Low']},{row['Close']},{row['Volume']}"
- formatted_lines.append(line)
-
- # Add summary statistics
- if len(data) > 0:
- formatted_lines.append(f"\nSummary for {symbol}:")
- formatted_lines.append(f"Period: {data['Date'].iloc[0]} to {data['Date'].iloc[-1]}")
- formatted_lines.append(f"Total trading days: {len(data)}")
- formatted_lines.append(f"Average volume: {data['Volume'].mean():,.0f}")
- formatted_lines.append(f"Price range: ${data['Low'].min():.2f} - ${data['High'].max():.2f}")
-
- # Calculate basic statistics
- start_price = data['Open'].iloc[0]
- end_price = data['Close'].iloc[-1]
- price_change = end_price - start_price
- price_change_pct = (price_change / start_price) * 100
-
- formatted_lines.append(f"Period return: {price_change_pct:+.2f}% (${price_change:+.2f})")
-
- return "\n".join(formatted_lines)
-
- except Exception as e:
- print(f"❌ Error formatting stock data for {symbol}: {e}")
- return f"Error formatting data for {symbol}: {str(e)}"
-
- def _format_finnhub_data(self, candle_data: Dict, symbol: str) -> str:
- """Format FINNHUB candle data into string"""
- try:
- # Extract data arrays
- timestamps = candle_data['t']
- opens = candle_data['o']
- highs = candle_data['h']
- lows = candle_data['l']
- closes = candle_data['c']
- volumes = candle_data['v']
-
- # Create formatted string
- formatted_lines = [f"Stock Data for {symbol} (FINNHUB):"]
- formatted_lines.append("Date,Open,High,Low,Close,Volume")
-
- for i in range(len(timestamps)):
- date = datetime.fromtimestamp(timestamps[i]).strftime('%Y-%m-%d')
- line = f"{date},{opens[i]:.2f},{highs[i]:.2f},{lows[i]:.2f},{closes[i]:.2f},{int(volumes[i])}"
- formatted_lines.append(line)
-
- # Add summary
- if len(timestamps) > 0:
- start_date = datetime.fromtimestamp(timestamps[0]).strftime('%Y-%m-%d')
- end_date = datetime.fromtimestamp(timestamps[-1]).strftime('%Y-%m-%d')
-
- formatted_lines.append(f"\nSummary for {symbol}:")
- formatted_lines.append(f"Period: {start_date} to {end_date}")
- formatted_lines.append(f"Total trading days: {len(timestamps)}")
- formatted_lines.append(f"Average volume: {sum(volumes)/len(volumes):,.0f}")
- formatted_lines.append(f"Price range: ${min(lows):.2f} - ${max(highs):.2f}")
-
- # Calculate return
- price_change = closes[-1] - opens[0]
- price_change_pct = (price_change / opens[0]) * 100
- formatted_lines.append(f"Period return: {price_change_pct:+.2f}% (${price_change:+.2f})")
-
- return "\n".join(formatted_lines)
-
- except Exception as e:
- print(f"❌ Error formatting FINNHUB data for {symbol}: {e}")
- return f"Error formatting FINNHUB data for {symbol}: {str(e)}"
-
- def get_stock_with_indicators(self, symbol: str, start_date: str, end_date: str,
- indicators: list = None) -> str:
- """
- Get stock data with technical indicators
-
- Args:
- symbol: Stock symbol
- start_date: Start date (YYYY-MM-DD)
- end_date: End date (YYYY-MM-DD)
- indicators: List of indicators to calculate ['sma_20', 'rsi', 'macd']
-
- Returns:
- Formatted stock data with indicators
- """
- try:
- # Get basic stock data
- basic_data = self.get_stock_data(symbol, start_date, end_date)
-
- if basic_data.startswith("❌"):
- return basic_data
-
- # If no indicators requested, return basic data
- if not indicators:
- return basic_data
-
- # Fetch DataFrame for indicator calculation
- data_df = self._fetch_from_yfinance(symbol, start_date, end_date)
- if data_df is None or data_df.empty:
- return basic_data
-
- # Calculate indicators
- indicator_data = self._calculate_indicators(data_df, indicators)
-
- # Combine basic data with indicators
- combined_data = basic_data + "\n\nTechnical Indicators:\n" + indicator_data
-
- return combined_data
-
- except Exception as e:
- error_msg = f"❌ Error getting stock data with indicators for {symbol}: {e}"
- print(error_msg)
- return error_msg
-
- def _calculate_indicators(self, data: pd.DataFrame, indicators: list) -> str:
- """Calculate technical indicators"""
- try:
- indicator_lines = []
-
- for indicator in indicators:
- if indicator == 'sma_20':
- data['SMA_20'] = data['Close'].rolling(window=20).mean()
- latest_sma = data['SMA_20'].iloc[-1]
- indicator_lines.append(f"SMA(20): ${latest_sma:.2f}")
-
- elif indicator == 'rsi':
- # Simple RSI calculation
- delta = data['Close'].diff()
- gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
- loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
- rs = gain / loss
- rsi = 100 - (100 / (1 + rs))
- latest_rsi = rsi.iloc[-1]
- indicator_lines.append(f"RSI(14): {latest_rsi:.2f}")
-
- elif indicator == 'macd':
- # Simple MACD calculation
- ema_12 = data['Close'].ewm(span=12).mean()
- ema_26 = data['Close'].ewm(span=26).mean()
- macd_line = ema_12 - ema_26
- signal_line = macd_line.ewm(span=9).mean()
- latest_macd = macd_line.iloc[-1]
- latest_signal = signal_line.iloc[-1]
- indicator_lines.append(f"MACD: {latest_macd:.4f}, Signal: {latest_signal:.4f}")
-
- return "\n".join(indicator_lines)
-
- except Exception as e:
- print(f"❌ Error calculating indicators: {e}")
- return f"Error calculating indicators: {str(e)}"
-
-
-# Global provider instance
-_global_provider = None
+
+ def _generate_fallback_data(self, symbol: str, start_date: str, end_date: str, error_msg: str) -> str:
+ """生成备用数据"""
+ return f"""# {symbol} 美股数据获取失败
+
+## ❌ 错误信息
+{error_msg}
+
+## 📊 模拟数据(仅供演示)
+- 股票代码: {symbol}
+- 数据期间: {start_date} 至 {end_date}
+- 最新价格: ${random.uniform(100, 300):.2f}
+- 模拟涨跌: {random.uniform(-5, 5):+.2f}%
+
+## ⚠️ 重要提示
+由于API限制或网络问题无法获取实时数据。
+建议稍后重试或检查网络连接。
+
+生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
+"""
+
+
+# Global instance
+_us_data_provider = None
def get_optimized_us_data_provider() -> OptimizedUSDataProvider:
+ """获取全局美股数据提供器实例"""
+ global _us_data_provider
+ if _us_data_provider is None:
+ _us_data_provider = OptimizedUSDataProvider()
+ return _us_data_provider
+
+
+def get_us_stock_data_cached(symbol: str, start_date: str, end_date: str,
+ force_refresh: bool = False) -> str:
"""
- Get global optimized US data provider instance
+ Convenience function for fetching US stock data.
+
+ Args:
+ symbol: Stock ticker symbol
+ start_date: Start date (YYYY-MM-DD)
+ end_date: End date (YYYY-MM-DD)
+ force_refresh: Whether to force a cache refresh
Returns:
- OptimizedUSDataProvider instance
+ Formatted stock data string
"""
- global _global_provider
- if _global_provider is None:
- _global_provider = OptimizedUSDataProvider()
- return _global_provider
-
-
-# Convenience functions
-def get_optimized_stock_data(symbol: str, start_date: str, end_date: str,
- force_refresh: bool = False) -> str:
- """Get optimized stock data (convenience function)"""
provider = get_optimized_us_data_provider()
return provider.get_stock_data(symbol, start_date, end_date, force_refresh)
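A minimal usage sketch of the new convenience function; the import path is an assumption about where this module lives in the package:

# Hypothetical import path; adjust to this module's actual location.
from tradingagents.dataflows.optimized_us_data import get_us_stock_data_cached

report = get_us_stock_data_cached("AAPL", "2024-01-01", "2024-01-31")
print(report)  # served from cache on repeat calls within the TTL
fresh = get_us_stock_data_cached("AAPL", "2024-01-01", "2024-01-31", force_refresh=True)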
-
-
-def get_stock_with_indicators(symbol: str, start_date: str, end_date: str,
- indicators: list = None) -> str:
- """Get stock data with technical indicators (convenience function)"""
- provider = get_optimized_us_data_provider()
- return provider.get_stock_with_indicators(symbol, start_date, end_date, indicators)
-
-
-if __name__ == "__main__":
- # Test the optimized data provider
- print("🧪 Testing Optimized US Data Provider...")
-
- # Initialize provider
- provider = OptimizedUSDataProvider()
-
- # Test data fetch
- data = provider.get_stock_data("AAPL", "2024-01-01", "2024-01-31")
- print("Sample data:")
- print(data[:500] + "..." if len(data) > 500 else data)
-
- # Test with indicators
- data_with_indicators = provider.get_stock_with_indicators(
- "AAPL", "2024-01-01", "2024-01-31",
- indicators=['sma_20', 'rsi', 'macd']
- )
- print("\nData with indicators:")
- print(data_with_indicators[-500:] if len(data_with_indicators) > 500 else data_with_indicators)
-
- print("✅ Optimized data provider test completed!")

View File

@ -3,7 +3,7 @@ import os
DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
"data_dir": os.path.join(os.path.expanduser("~"), "Documents", "TradingAgents", "data"),
"data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
@ -19,4 +19,7 @@ DEFAULT_CONFIG = {
"max_recur_limit": 100,
# Tool settings
"online_tools": True,
+
+ # Note: Database and cache configuration is now managed by .env file and config.database_manager
+ # No database/cache settings in default config to avoid configuration conflicts
}
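With this change the default data_dir resolves under the user's home directory instead of a hard-coded developer path, for example:

import os
os.path.join(os.path.expanduser("~"), "Documents", "TradingAgents", "data")
# Windows:     C:\Users\<user>\Documents\TradingAgents\data
# Linux:       /home/<user>/Documents/TradingAgents/data
# macOS:       /Users/<user>/Documents/TradingAgents/data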

View File

@ -1,22 +0,0 @@
import os
DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
"data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "o4-mini",
"quick_think_llm": "gpt-4o-mini",
"backend_url": "https://api.openai.com/v1",
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,
"max_recur_limit": 100,
# Tool settings
"online_tools": True,
}

View File

@ -1,20 +0,0 @@
# File diff report
# Current file: tradingagents\default_config.py
# Chinese-version file: TradingAgentsCN\tradingagents\default_config.py
# Generated: Sunday, 2025/07/06
--- current/default_config.py
+++ chinese_version/default_config.py
@@ -3,7 +3,7 @@ DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
- "data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
+ "data_dir": os.path.join(os.path.expanduser("~"), "Documents", "TradingAgents", "data"),
"data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
@@ -19,4 +19,7 @@ "max_recur_limit": 100,
# Tool settings
"online_tools": True,
+
+ # Note: Database and cache configuration is now managed by .env file and config.database_manager
+ # No database/cache settings in default config to avoid configuration conflicts
}