feat: merge Chinese version features
- Add 18 new feature files from Chinese version - Support for Chinese market data (A-shares) - Database integration with MongoDB - Advanced caching system with adaptive strategies - LLM adapters for DashScope and other providers - API services and real-time data utilities - Enhanced configuration management - Comprehensive English documentation New features: - Chinese finance data aggregation - TDX (TongDaXin) API integration - Optimized China stock data provider - Adaptive and integrated caching - Database cache management - Stock data services - Real-time news utilities Breaking changes: None (all new features are additive) Dependencies: Added pymongo, beautifulsoup4, dashscope (optional) For detailed information, see MERGE_SUMMARY.md
This commit is contained in:
parent
a438acdbbd
commit
0de847601e
|
|
@ -7,3 +7,13 @@ eval_results/
|
|||
eval_data/
|
||||
*.egg-info/
|
||||
.env
|
||||
|
||||
# 测试和中文文档目录(不纳入版本控制)
|
||||
tests/
|
||||
docs/zh-CN/
|
||||
|
||||
# 中文版本目录(不纳入版本控制)
|
||||
TradingAgentsCN/
|
||||
|
||||
# 虚拟环境目录(不纳入版本控制)
|
||||
test_env/
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
# TradingAgents 中文版功能全量合并摘要
|
||||
|
||||
合并时间: 周日 2025/07/06
|
||||
合并分支: full-merge-chinese-features
|
||||
|
||||
## 📊 合并统计
|
||||
|
||||
- 新增文件: 18 个
|
||||
- 处理冲突文件: 4 个
|
||||
|
||||
## 🆕 主要新增功能
|
||||
|
||||
### 中国市场数据支持
|
||||
- `chinese_finance_utils.py` - 中国财经数据聚合工具
|
||||
- `tdx_utils.py` - 通达信API数据获取
|
||||
- `optimized_china_data.py` - 优化的A股数据提供器
|
||||
- `china_market_analyst.py` - 中国市场分析师
|
||||
|
||||
### 数据库集成
|
||||
- `database_config.py` - 数据库配置管理
|
||||
- `database_manager.py` - 统一数据库管理器
|
||||
- `mongodb_storage.py` - MongoDB存储支持
|
||||
- `db_cache_manager.py` - 数据库缓存管理
|
||||
|
||||
### 高级缓存系统
|
||||
- `adaptive_cache.py` - 自适应缓存策略
|
||||
- `integrated_cache.py` - 集成缓存管理
|
||||
|
||||
### LLM适配器扩展
|
||||
- `llm_adapters/` - LLM适配器框架
|
||||
- `dashscope_adapter.py` - 阿里云DashScope支持
|
||||
|
||||
### API和服务层
|
||||
- `api/` - 统一API接口
|
||||
- `stock_data_service.py` - 股票数据服务
|
||||
- `realtime_news_utils.py` - 实时新闻工具
|
||||
|
||||
## ⚠️ 需要注意的变更
|
||||
|
||||
### 新增依赖项
|
||||
- `pymongo` - MongoDB数据库支持
|
||||
- `beautifulsoup4` - 网页数据解析
|
||||
- `dashscope` - 阿里云LLM支持 (可选)
|
||||
|
||||
### 配置文件变更
|
||||
- 添加了数据库相关配置
|
||||
- 扩展了缓存配置选项
|
||||
- 新增了中国市场数据源配置
|
||||
|
||||
## 🧪 测试建议
|
||||
|
||||
1. **基础功能测试**: 确保原有功能正常工作
|
||||
2. **新功能测试**: 测试中国市场数据获取
|
||||
3. **缓存系统测试**: 验证缓存性能和稳定性
|
||||
4. **数据库集成测试**: 测试MongoDB连接和存储
|
||||
5. **LLM适配器测试**: 验证多LLM支持
|
||||
|
||||
## 📝 后续工作
|
||||
|
||||
1. 更新文档以反映新功能
|
||||
2. 添加新功能的使用示例
|
||||
3. 完善测试覆盖率
|
||||
4. 优化性能和稳定性
|
||||
|
||||
## 🔄 如果需要分批PR
|
||||
|
||||
如果原项目认为全量合并过于复杂,可以按以下顺序分批提交:
|
||||
|
||||
1. **基础设施**: config/, database相关文件
|
||||
2. **中国市场数据**: chinese_finance_utils.py, tdx_utils.py等
|
||||
3. **高级缓存**: adaptive_cache.py, integrated_cache.py等
|
||||
4. **LLM适配器**: llm_adapters/目录
|
||||
5. **API服务**: api/目录和相关服务文件
|
||||
|
|
@ -0,0 +1,124 @@
|
|||
# TradingAgents Documentation
|
||||
|
||||
## 📚 Documentation Structure
|
||||
|
||||
This documentation is organized into language-specific directories to serve different user communities:
|
||||
|
||||
### 🇺🇸 English Documentation (`en-US/`)
|
||||
**Status**: ✅ Included in version control
|
||||
|
||||
Contains comprehensive guides for English-speaking users:
|
||||
- **Configuration Guide** (`configuration_guide.md`) - Detailed instructions for modifying system configurations and agent prompts
|
||||
- **Quick Reference** (`quick_reference.md`) - Quick lookup card for common modifications and file locations
|
||||
- **Prompt Templates** (`prompt_templates.md`) - Ready-to-use prompt templates for various agent roles
|
||||
|
||||
### 🇨🇳 Chinese Documentation (`zh-CN/`)
|
||||
**Status**: 🚫 Excluded from version control (local development only)
|
||||
|
||||
Contains detailed guides in Chinese for local development and customization:
|
||||
- **配置指南** (`configuration_guide.md`) - 详细的配置修改和提示词定制指南
|
||||
- **快速参考** (`quick_reference.md`) - 新手友好的快速查找卡片
|
||||
- **提示词模板库** (`prompt_templates.md`) - 可直接使用的提示词模板
|
||||
|
||||
## 🎯 Quick Start
|
||||
|
||||
### For English Users
|
||||
Navigate to [`en-US/`](en-US/) directory for:
|
||||
- System configuration instructions
|
||||
- Prompt customization guides
|
||||
- Template libraries
|
||||
- Troubleshooting tips
|
||||
|
||||
### For Chinese Users
|
||||
Navigate to `zh-CN/` directory (local development) for:
|
||||
- 系统配置说明
|
||||
- 提示词定制指南
|
||||
- 模板库
|
||||
- 故障排除技巧
|
||||
|
||||
## 📖 Available Guides
|
||||
|
||||
| Guide | English | Chinese | Description |
|
||||
|-------|---------|---------|-------------|
|
||||
| **Configuration Guide** | [📖 View](en-US/configuration_guide.md) | 📖 View (Local) | Complete guide for modifying configurations and prompts |
|
||||
| **Quick Reference** | [🚀 View](en-US/quick_reference.md) | 🚀 View (Local) | Quick lookup for common modifications |
|
||||
| **Prompt Templates** | [🎯 View](en-US/prompt_templates.md) | 🎯 View (Local) | Ready-to-use prompt templates |
|
||||
|
||||
## 🔧 Key Topics Covered
|
||||
|
||||
### Configuration Management
|
||||
- LLM provider settings (OpenAI, Google, Anthropic)
|
||||
- **Google Models**: Full support for Gemini 2.0/2.5 series ⭐ **Currently Configured**
|
||||
- **Current Setup**: Using `gemini-2.0-flash` for both deep and quick thinking
|
||||
- Debate and discussion parameters
|
||||
- Cache and performance settings
|
||||
- API configuration and limits
|
||||
|
||||
### Agent Customization
|
||||
- Market Analyst prompts
|
||||
- Fundamentals Analyst prompts
|
||||
- News and Social Media Analyst prompts
|
||||
- Bull/Bear Researcher prompts
|
||||
- Trader decision prompts
|
||||
- Reflection system prompts
|
||||
|
||||
### Advanced Features
|
||||
- Multi-language support
|
||||
- Risk management templates
|
||||
- Performance optimization
|
||||
- Custom prompt creation
|
||||
- Environment-specific configurations
|
||||
|
||||
## 🚀 Getting Started
|
||||
|
||||
1. **Choose Your Language**: Select the appropriate documentation directory
|
||||
2. **Start with Quick Reference**: Get familiar with key file locations
|
||||
3. **Read Configuration Guide**: Understand the system architecture
|
||||
4. **Use Prompt Templates**: Copy and customize templates for your needs
|
||||
5. **Test Changes**: Always test modifications in a safe environment
|
||||
|
||||
## 🛠️ Development Workflow
|
||||
|
||||
### For Contributors
|
||||
1. **English Documentation**:
|
||||
- Modify files in `en-US/` directory
|
||||
- Commit changes to version control
|
||||
- These will be available to all users
|
||||
|
||||
2. **Chinese Documentation**:
|
||||
- Modify files in `zh-CN/` directory
|
||||
- Keep changes local (not committed)
|
||||
- Use for local development and testing
|
||||
|
||||
### Version Control Policy
|
||||
- ✅ **Include**: `en-US/` directory and all English documentation
|
||||
- 🚫 **Exclude**: `zh-CN/` directory (configured in `.gitignore`)
|
||||
- ✅ **Include**: This README file for navigation
|
||||
|
||||
## 📝 Contributing
|
||||
|
||||
When contributing to documentation:
|
||||
|
||||
1. **Update English docs** for features that should be shared with the community
|
||||
2. **Keep Chinese docs local** for development and customization purposes
|
||||
3. **Maintain consistency** between language versions when possible
|
||||
4. **Test all examples** before documenting them
|
||||
|
||||
## 🔗 Related Resources
|
||||
|
||||
- **Project Repository**: Main TradingAgents codebase
|
||||
- **Configuration Files**: `tradingagents/default_config.py`, `main.py`
|
||||
- **Agent Files**: `tradingagents/agents/` directory
|
||||
- **Test Files**: `tests/` directory (local only)
|
||||
|
||||
## 📞 Support
|
||||
|
||||
For questions about:
|
||||
- **Configuration**: See Configuration Guide
|
||||
- **Prompts**: See Prompt Templates
|
||||
- **Quick Help**: See Quick Reference
|
||||
- **Issues**: Submit to project repository
|
||||
|
||||
---
|
||||
|
||||
💡 **Note**: This documentation structure allows for both community sharing (English) and local customization (Chinese) while maintaining clean version control.
|
||||
|
|
@ -0,0 +1,478 @@
|
|||
# TradingAgents Configuration and Prompt Modification Guide
|
||||
|
||||
## 📖 Overview
|
||||
|
||||
This document provides a comprehensive guide for new users to modify configurations and customize prompts in the TradingAgents project. Through this guide, you will learn:
|
||||
- How to modify system configuration parameters
|
||||
- How to customize prompts for various agents
|
||||
- How to add new features and configurations
|
||||
|
||||
## 🔧 Configuration File Locations and Descriptions
|
||||
|
||||
### 1. Main Configuration Files
|
||||
|
||||
#### 📁 `tradingagents/default_config.py`
|
||||
**Purpose**: Core configuration file defining all default parameters
|
||||
|
||||
```python
|
||||
DEFAULT_CONFIG = {
|
||||
# Directory configuration
|
||||
"project_dir": "Project root directory path",
|
||||
"results_dir": "Results output directory",
|
||||
"data_dir": "Data storage directory",
|
||||
"data_cache_dir": "Cache directory",
|
||||
|
||||
# LLM model configuration
|
||||
"llm_provider": "openai", # LLM provider: "openai", "google", "anthropic"
|
||||
"deep_think_llm": "o4-mini", # Deep thinking model
|
||||
"quick_think_llm": "gpt-4o-mini", # Quick thinking model
|
||||
"backend_url": "https://api.openai.com/v1", # API backend URL
|
||||
|
||||
# Debate and discussion settings
|
||||
"max_debate_rounds": 1, # Maximum debate rounds
|
||||
"max_risk_discuss_rounds": 1, # Maximum risk discussion rounds
|
||||
"max_recur_limit": 100, # Maximum recursion limit
|
||||
|
||||
# Tool settings
|
||||
"online_tools": True, # Whether to use online tools
|
||||
}
|
||||
```
|
||||
|
||||
**Modification Method**:
|
||||
1. Directly edit the `tradingagents/default_config.py` file
|
||||
2. Modify the corresponding configuration values
|
||||
3. Restart the application for changes to take effect
|
||||
|
||||
#### 📁 `main.py`
|
||||
**Purpose**: Runtime configuration override, allows temporary parameter adjustments without modifying default config
|
||||
|
||||
```python
|
||||
# Create custom configuration
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["llm_provider"] = "google" # Use Google models
|
||||
config["backend_url"] = "https://generativelanguage.googleapis.com/v1"
|
||||
config["deep_think_llm"] = "gemini-2.0-flash" # Deep thinking model
|
||||
config["quick_think_llm"] = "gemini-2.0-flash" # Quick thinking model
|
||||
config["max_debate_rounds"] = 2 # Increase debate rounds
|
||||
config["online_tools"] = True # Enable online tools
|
||||
```
|
||||
|
||||
**Modification Method**:
|
||||
1. Edit the config section in `main.py`
|
||||
2. Add or modify configuration items to override
|
||||
3. Save and run
|
||||
|
||||
### 2. Dynamic Configuration Management
|
||||
|
||||
#### 📁 `tradingagents/dataflows/config.py`
|
||||
**Purpose**: Provides dynamic configuration get/set functionality
|
||||
|
||||
```python
|
||||
# Get current configuration
|
||||
config = get_config()
|
||||
|
||||
# Dynamically modify configuration
|
||||
set_config({
|
||||
"llm_provider": "anthropic",
|
||||
"max_debate_rounds": 3
|
||||
})
|
||||
```
|
||||
|
||||
## 🤖 Agent Prompt Modification Guide
|
||||
|
||||
### 1. Analyst Prompts
|
||||
|
||||
#### 📁 Market Analyst (`tradingagents/agents/analysts/market_analyst.py`)
|
||||
|
||||
**Location**: `system_message` variable at lines 24-50
|
||||
|
||||
**Current Prompt**:
|
||||
```python
|
||||
system_message = (
|
||||
"""You are a trading assistant tasked with analyzing financial markets.
|
||||
Your role is to select the **most relevant indicators** for a given market
|
||||
condition or trading strategy from the following list..."""
|
||||
)
|
||||
```
|
||||
|
||||
**Modification Example**:
|
||||
```python
|
||||
system_message = (
|
||||
"""You are a professional market analyst specializing in financial market analysis.
|
||||
Your task is to select the most relevant indicators from the following list,
|
||||
providing analysis for specific market conditions or trading strategies.
|
||||
Goal: Choose up to 8 indicators that provide complementary insights without redundancy..."""
|
||||
)
|
||||
```
|
||||
|
||||
#### 📁 Fundamentals Analyst (`tradingagents/agents/analysts/fundamentals_analyst.py`)
|
||||
|
||||
**Location**: `system_message` variable at lines 23-26
|
||||
|
||||
**Key Modification Points**:
|
||||
- Analysis depth requirements
|
||||
- Report format requirements
|
||||
- Focus financial metrics
|
||||
|
||||
#### 📁 News Analyst (`tradingagents/agents/analysts/news_analyst.py`)
|
||||
|
||||
**Location**: `system_message` variable at lines 20-23
|
||||
|
||||
**Key Modification Points**:
|
||||
- News source preferences
|
||||
- Analysis time range
|
||||
- Types of news to focus on
|
||||
|
||||
#### 📁 Social Media Analyst (`tradingagents/agents/analysts/social_media_analyst.py`)
|
||||
|
||||
**Location**: `system_message` variable at lines 19-22
|
||||
|
||||
**Key Modification Points**:
|
||||
- Sentiment analysis depth
|
||||
- Social media platform preferences
|
||||
- Sentiment weight settings
|
||||
|
||||
### 2. Researcher Prompts
|
||||
|
||||
#### 📁 Bull Researcher (`tradingagents/agents/researchers/bull_researcher.py`)
|
||||
|
||||
**Location**: `prompt` variable at lines 25-43
|
||||
|
||||
**Current Prompt Structure**:
|
||||
```python
|
||||
prompt = f"""You are a Bull Analyst advocating for investing in the stock.
|
||||
|
||||
Key points to focus on:
|
||||
- Growth Potential: Highlight market opportunities, revenue projections, and scalability
|
||||
- Competitive Advantages: Emphasize unique products, strong branding, or market dominance
|
||||
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence
|
||||
- Bear Counterpoints: Critically analyze bear arguments with specific data and sound reasoning
|
||||
"""
|
||||
```
|
||||
|
||||
**Modification Suggestions**:
|
||||
- Adjust analysis focus
|
||||
- Modify argumentation strategy
|
||||
- Customize rebuttal logic
|
||||
|
||||
#### 📁 Bear Researcher (`tradingagents/agents/researchers/bear_researcher.py`)
|
||||
|
||||
**Key Modification Points**:
|
||||
- Risk identification focus
|
||||
- Pessimistic scenario analysis
|
||||
- Strategy for countering bull arguments
|
||||
|
||||
### 3. Trader Prompts
|
||||
|
||||
#### 📁 Trader (`tradingagents/agents/trader/trader.py`)
|
||||
|
||||
**Location**: System message in `messages` array at lines 30-36
|
||||
|
||||
**Current Prompt**:
|
||||
```python
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""You are a trading agent analyzing market data to make
|
||||
investment decisions. Based on your analysis, provide a specific
|
||||
recommendation to buy, sell, or hold. End with a firm decision and
|
||||
always conclude your response with 'FINAL TRANSACTION PROPOSAL:
|
||||
**BUY/HOLD/SELL**' to confirm your recommendation.""",
|
||||
}
|
||||
```
|
||||
|
||||
**Modification Example**:
|
||||
```python
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""You are a professional trading agent responsible for analyzing
|
||||
market data and making investment decisions.
|
||||
|
||||
Decision Requirements:
|
||||
1. Provide detailed analysis reasoning
|
||||
2. Consider risk management
|
||||
3. Must end with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**'
|
||||
|
||||
Historical Lessons: {past_memory_str}""",
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Risk Management Prompts
|
||||
|
||||
#### 📁 Conservative Debater (`tradingagents/agents/risk_mgmt/conservative_debator.py`)
|
||||
#### 📁 Aggressive Debater (`tradingagents/agents/risk_mgmt/aggresive_debator.py`)
|
||||
#### 📁 Neutral Debater (`tradingagents/agents/risk_mgmt/neutral_debator.py`)
|
||||
|
||||
**Key Modification Points**:
|
||||
- Risk tolerance settings
|
||||
- Debate style adjustments
|
||||
- Decision weight allocation
|
||||
|
||||
### 5. Reflection System Prompts
|
||||
|
||||
#### 📁 Reflection System (`tradingagents/graph/reflection.py`)
|
||||
|
||||
**Location**: `_get_reflection_prompt` method at lines 15-47
|
||||
|
||||
**Current Prompt Structure**:
|
||||
```python
|
||||
return """
|
||||
You are an expert financial analyst tasked with reviewing trading
|
||||
decisions/analysis and providing a comprehensive, step-by-step analysis.
|
||||
|
||||
1. Reasoning: Analyze whether each trading decision was correct
|
||||
2. Improvement: Propose revisions for incorrect decisions
|
||||
3. Summary: Summarize lessons learned from successes and failures
|
||||
4. Query: Extract key insights into concise sentences
|
||||
"""
|
||||
```
|
||||
|
||||
## 🎯 Prompt Modification Best Practices
|
||||
|
||||
### 1. Pre-modification Preparation
|
||||
|
||||
1. **Backup Original Files**:
|
||||
```bash
|
||||
cp tradingagents/agents/trader/trader.py tradingagents/agents/trader/trader.py.backup
|
||||
```
|
||||
|
||||
2. **Understand Agent Roles**: Ensure modifications align with expected agent functionality
|
||||
|
||||
3. **Prepare Test Environment**: Validate modifications in test environment
|
||||
|
||||
### 2. Prompt Modification Techniques
|
||||
|
||||
#### 🔍 **Structured Prompts**
|
||||
```python
|
||||
system_message = f"""
|
||||
Role Definition: You are a {role_name}
|
||||
|
||||
Main Tasks:
|
||||
1. {task_1}
|
||||
2. {task_2}
|
||||
3. {task_3}
|
||||
|
||||
Analysis Requirements:
|
||||
- Depth: {analysis_depth}
|
||||
- Format: {output_format}
|
||||
- Focus: {focus_areas}
|
||||
|
||||
Output Format:
|
||||
{output_template}
|
||||
|
||||
Constraints:
|
||||
- {constraint_1}
|
||||
- {constraint_2}
|
||||
"""
|
||||
```
|
||||
|
||||
#### ⚙️ **Parameterized Prompts**
|
||||
```python
|
||||
def create_analyst_prompt(
|
||||
role="Market Analyst",
|
||||
analysis_depth="Detailed",
|
||||
time_horizon="1 week",
|
||||
risk_tolerance="Moderate",
|
||||
output_language="English"
|
||||
):
|
||||
return f"""
|
||||
You are a professional {role}, please analyze based on the following parameters:
|
||||
|
||||
Analysis Depth: {analysis_depth}
|
||||
Time Horizon: {time_horizon}
|
||||
Risk Preference: {risk_tolerance}
|
||||
Output Language: {output_language}
|
||||
|
||||
Please provide corresponding market analysis and investment recommendations based on these parameters.
|
||||
"""
|
||||
```
|
||||
|
||||
### 3. Common Modification Scenarios
|
||||
|
||||
#### 📈 **Adjusting Analysis Focus**
|
||||
```python
|
||||
# Original: General market analysis
|
||||
system_message = "Analyze overall market trends..."
|
||||
|
||||
# Modified: Focus on specific industry
|
||||
system_message = "Analyze technology stock market trends, focusing on AI, semiconductor, and cloud computing industries..."
|
||||
```
|
||||
|
||||
#### 🎯 **Modifying Decision Style**
|
||||
```python
|
||||
# Original: Conservative
|
||||
"provide conservative investment recommendations..."
|
||||
|
||||
# Modified: Aggressive
|
||||
"provide aggressive growth-oriented investment recommendations with higher risk tolerance..."
|
||||
```
|
||||
|
||||
## 🔧 New Configuration Items
|
||||
|
||||
### 1. Cache Configuration (`tradingagents/dataflows/cache_manager.py`)
|
||||
|
||||
```python
|
||||
# Add new cache configuration in cache_manager.py
|
||||
self.cache_config = {
|
||||
'us_stock_data': {
|
||||
'ttl_hours': 2, # US stock data cached for 2 hours
|
||||
'description': 'US stock historical data'
|
||||
},
|
||||
'china_stock_data': {
|
||||
'ttl_hours': 1, # A-share data cached for 1 hour
|
||||
'description': 'A-share historical data'
|
||||
},
|
||||
# Add new cache type
|
||||
'crypto_data': {
|
||||
'ttl_hours': 0.5, # Crypto data cached for 30 minutes
|
||||
'description': 'Cryptocurrency data'
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. API Configuration
|
||||
|
||||
```python
|
||||
# Add new API configuration in default_config.py
|
||||
DEFAULT_CONFIG = {
|
||||
# Existing configuration...
|
||||
|
||||
# New API configuration
|
||||
"api_keys": {
|
||||
"finnhub": "your_finnhub_api_key",
|
||||
"alpha_vantage": "your_alpha_vantage_key",
|
||||
"polygon": "your_polygon_key"
|
||||
},
|
||||
|
||||
# API limit configuration
|
||||
"api_limits": {
|
||||
"finnhub_calls_per_minute": 60,
|
||||
"alpha_vantage_calls_per_minute": 5,
|
||||
"polygon_calls_per_minute": 100
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 🚀 Quick Start Examples
|
||||
|
||||
### 1. Switch to Google Models
|
||||
|
||||
```python
|
||||
# Edit main.py
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["llm_provider"] = "google"
|
||||
config["backend_url"] = "https://generativelanguage.googleapis.com/v1"
|
||||
config["deep_think_llm"] = "gemini-2.0-flash"
|
||||
config["quick_think_llm"] = "gemini-2.0-flash"
|
||||
```
|
||||
|
||||
#### 🚀 Supported Google Models
|
||||
|
||||
**Fast Thinking Models (Quick Analysis)**:
|
||||
- `gemini-2.0-flash-lite` - Cost efficiency and low latency
|
||||
- `gemini-2.0-flash` - Next generation features, speed, and thinking ⭐ **Recommended**
|
||||
- `gemini-2.5-flash-preview-05-20` - Adaptive thinking, cost efficiency
|
||||
|
||||
**Deep Thinking Models (Complex Analysis)**:
|
||||
- `gemini-2.0-flash-lite` - Cost efficiency and low latency
|
||||
- `gemini-2.0-flash` - Next generation features, speed, and thinking ⭐ **Current Default**
|
||||
- `gemini-2.5-flash-preview-05-20` - Adaptive thinking, cost efficiency
|
||||
- `gemini-2.5-pro-preview-06-05` - Professional-grade performance
|
||||
|
||||
#### 🔑 Google API Key Setup
|
||||
|
||||
**Method 1: Environment Variable (Recommended)**
|
||||
```bash
|
||||
export GOOGLE_API_KEY="your_google_api_key_here"
|
||||
```
|
||||
|
||||
**Method 2: In Code**
|
||||
```python
|
||||
import os
|
||||
os.environ["GOOGLE_API_KEY"] = "your_google_api_key_here"
|
||||
```
|
||||
|
||||
**Method 3: .env File**
|
||||
```
|
||||
# Create .env file in project root
|
||||
GOOGLE_API_KEY=your_google_api_key_here
|
||||
```
|
||||
|
||||
#### 📋 Model Selection Examples
|
||||
|
||||
**High Performance Setup**:
|
||||
```python
|
||||
config["deep_think_llm"] = "gemini-2.5-pro-preview-06-05" # Best reasoning
|
||||
config["quick_think_llm"] = "gemini-2.0-flash" # Fast response
|
||||
```
|
||||
|
||||
**Cost-Optimized Setup**:
|
||||
```python
|
||||
config["deep_think_llm"] = "gemini-2.0-flash-lite" # Economical
|
||||
config["quick_think_llm"] = "gemini-2.0-flash-lite" # Economical
|
||||
```
|
||||
|
||||
**Balanced Setup (Current Default)**:
|
||||
```python
|
||||
config["deep_think_llm"] = "gemini-2.0-flash" # Good performance
|
||||
config["quick_think_llm"] = "gemini-2.0-flash" # Good speed
|
||||
```
|
||||
|
||||
### 2. Add Risk Control
|
||||
|
||||
```python
|
||||
# Edit tradingagents/agents/trader/trader.py
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""You are a professional trading agent with strict risk control awareness.
|
||||
|
||||
Trading Principles:
|
||||
1. Risk first, returns second
|
||||
2. Strict stop-loss, protect capital
|
||||
3. Diversified investment, reduce risk
|
||||
4. Data-driven, rational decisions
|
||||
|
||||
Decision Process:
|
||||
1. Analyze market trends and technical indicators
|
||||
2. Assess fundamental and news impact
|
||||
3. Calculate risk-reward ratio
|
||||
4. Set stop-loss and take-profit points
|
||||
5. Make final trading decision
|
||||
|
||||
Output Requirements:
|
||||
- Must include risk assessment
|
||||
- Must set stop-loss points
|
||||
- Must end with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**'
|
||||
|
||||
Historical Experience: {past_memory_str}""",
|
||||
},
|
||||
context,
|
||||
]
|
||||
```
|
||||
|
||||
## 📝 Important Notes
|
||||
|
||||
1. **Backup Important**: Always backup original files before modification
|
||||
2. **Test Validation**: Validate modifications in test environment
|
||||
3. **Version Control**: Use Git to manage configuration changes
|
||||
4. **Documentation Updates**: Update related documentation promptly
|
||||
5. **Team Collaboration**: Sync configuration changes with team members
|
||||
|
||||
## 🔗 Quick File Index
|
||||
|
||||
| Function | File Path | Description |
|
||||
|----------|-----------|-------------|
|
||||
| Main Config | `tradingagents/default_config.py` | System default configuration |
|
||||
| Runtime Config | `main.py` | Runtime configuration override |
|
||||
| Dynamic Config | `tradingagents/dataflows/config.py` | Configuration management interface |
|
||||
| Market Analyst | `tradingagents/agents/analysts/market_analyst.py` | Technical analysis prompts |
|
||||
| Fundamentals Analyst | `tradingagents/agents/analysts/fundamentals_analyst.py` | Fundamental analysis prompts |
|
||||
| News Analyst | `tradingagents/agents/analysts/news_analyst.py` | News analysis prompts |
|
||||
| Social Media Analyst | `tradingagents/agents/analysts/social_media_analyst.py` | Sentiment analysis prompts |
|
||||
| Bull Researcher | `tradingagents/agents/researchers/bull_researcher.py` | Bull analysis prompts |
|
||||
| Bear Researcher | `tradingagents/agents/researchers/bear_researcher.py` | Bear analysis prompts |
|
||||
| Trader | `tradingagents/agents/trader/trader.py` | Trading decision prompts |
|
||||
| Reflection System | `tradingagents/graph/reflection.py` | Reflection analysis prompts |
|
||||
| Cache Config | `tradingagents/dataflows/cache_manager.py` | Cache management configuration |
|
||||
|
||||
Through this guide, you should be able to easily modify the TradingAgents project's configuration and prompts to meet your specific needs.
|
||||
|
|
@ -0,0 +1,517 @@
|
|||
# TradingAgents Prompt Template Library
|
||||
|
||||
## 📚 Overview
|
||||
|
||||
This document provides prompt templates for various roles in the TradingAgents project. You can copy and use them directly or modify them according to your needs.
|
||||
|
||||
## 🚀 Google Model Integration
|
||||
|
||||
TradingAgents fully supports Google Gemini models. The current configuration uses:
|
||||
- **Deep Thinking**: `gemini-2.0-flash` - For complex analysis and reasoning
|
||||
- **Quick Thinking**: `gemini-2.0-flash` - For fast responses and simple tasks
|
||||
|
||||
**Available Models**:
|
||||
- `gemini-2.0-flash-lite` - Cost-efficient, low latency
|
||||
- `gemini-2.0-flash` - Balanced performance ⭐ **Current Default**
|
||||
- `gemini-2.5-flash-preview-05-20` - Advanced adaptive thinking
|
||||
- `gemini-2.5-pro-preview-06-05` - Professional-grade performance
|
||||
|
||||
**Setup**: Ensure `GOOGLE_API_KEY` environment variable is set.
|
||||
|
||||
## 🎯 Analyst Prompt Templates
|
||||
|
||||
### 1. Market Analyst - Professional Version
|
||||
|
||||
```python
|
||||
system_message = (
|
||||
"""You are a professional market analyst specializing in stock market technical indicator analysis. Your task is to select the most relevant indicators (up to 8) from the following list to provide analysis for specific market conditions or trading strategies.
|
||||
|
||||
Technical Indicator Categories:
|
||||
|
||||
📈 Moving Averages:
|
||||
- close_50_sma: 50-day Simple Moving Average - Medium-term trend indicator for identifying trend direction and dynamic support/resistance
|
||||
- close_200_sma: 200-day Simple Moving Average - Long-term trend benchmark for confirming overall market trend and golden/death cross setups
|
||||
- close_10_ema: 10-day Exponential Moving Average - Short-term trend response for capturing quick momentum changes and potential entry points
|
||||
|
||||
📊 MACD Related Indicators:
|
||||
- macd: MACD Line - Calculates momentum via EMA differences, look for crossovers and divergence as trend change signals
|
||||
- macds: MACD Signal Line - EMA smoothing of MACD line, use crossovers with MACD line to trigger trades
|
||||
- macdh: MACD Histogram - Shows gap between MACD line and signal, visualize momentum strength and spot early divergence
|
||||
|
||||
⚡ Momentum Indicators:
|
||||
- rsi: Relative Strength Index - Measures momentum to flag overbought/oversold conditions, apply 70/30 thresholds and watch for divergence
|
||||
|
||||
📏 Volatility Indicators:
|
||||
- boll: Bollinger Middle Band - 20-day SMA serving as Bollinger Bands basis, acts as dynamic benchmark for price movement
|
||||
- boll_ub: Bollinger Upper Band - Typically 2 standard deviations above middle, signals potential overbought conditions and breakout zones
|
||||
- boll_lb: Bollinger Lower Band - Typically 2 standard deviations below middle, indicates potential oversold conditions
|
||||
- atr: Average True Range - Measures volatility for setting stop-loss levels and adjusting position sizes based on current market volatility
|
||||
|
||||
📊 Volume Indicators:
|
||||
- vwma: Volume Weighted Moving Average - Confirms trends by integrating price action with volume data
|
||||
|
||||
Analysis Requirements:
|
||||
1. Select indicators that provide diverse and complementary information, avoid redundancy
|
||||
2. Briefly explain why these indicators are suitable for the given market environment
|
||||
3. Use exact indicator names for tool calls
|
||||
4. Call get_YFin_data first to retrieve CSV data needed for indicator generation
|
||||
5. Write detailed and nuanced trend observation reports, avoid simply stating "trends are mixed"
|
||||
6. Append a Markdown table at the end of the report to organize key points in an organized and easy-to-read format
|
||||
|
||||
Please provide professional, detailed market analysis."""
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Fundamentals Analyst - Professional Version
|
||||
|
||||
```python
|
||||
system_message = (
|
||||
"""You are a professional fundamental research analyst specializing in company fundamental information analysis. Your task is to write a comprehensive report on the company's fundamental information over the past week.
|
||||
|
||||
Analysis Scope:
|
||||
📊 Financial Document Analysis: Balance sheet, income statement, cash flow statement
|
||||
🏢 Company Profile: Business model, competitive advantages, management quality
|
||||
💰 Basic Financial Metrics: PE, PB, ROE, ROA, gross margin, net margin
|
||||
📈 Financial Historical Trends: Revenue growth, profit growth, debt level changes
|
||||
👥 Insider Sentiment: Management and insider buying/selling behavior
|
||||
💼 Insider Transactions: Trading records of major shareholders and executives
|
||||
|
||||
Analysis Requirements:
|
||||
1. Provide as much detail as possible to help traders make informed decisions
|
||||
2. Don't simply state "trends are mixed", provide detailed and nuanced analysis insights
|
||||
3. Focus on key financial metric changes that may affect stock prices
|
||||
4. Analyze potential implications of insider behavior
|
||||
5. Assess company's financial health and future prospects
|
||||
6. Append a Markdown table at the end of the report to organize key points in an organized and easy-to-read format
|
||||
|
||||
Please write a professional, comprehensive fundamental analysis report."""
|
||||
)
|
||||
```
|
||||
|
||||
### 3. News Analyst - Professional Version
|
||||
|
||||
```python
|
||||
system_message = (
|
||||
"""You are a professional news research analyst specializing in analyzing recent news and trends over the past week. Your task is to write a comprehensive report on the current state of the world relevant to trading and macroeconomics.
|
||||
|
||||
Analysis Scope:
|
||||
🌍 Global Macroeconomic News: Central bank policies, inflation data, GDP growth, employment data
|
||||
📈 Financial Market Dynamics: Stock market performance, bond yields, currency changes, commodity prices
|
||||
🏛️ Policy Impact: Monetary policy, fiscal policy, regulatory changes, trade policy
|
||||
🏭 Industry Trends: Technology, energy, finance, consumer, healthcare and other key industry dynamics
|
||||
⚡ Breaking Events: Geopolitical events, natural disasters, major corporate events
|
||||
|
||||
News Sources:
|
||||
- EODHD news data
|
||||
- Finnhub news data
|
||||
- Google news search
|
||||
- Reddit discussion hotspots
|
||||
|
||||
Analysis Requirements:
|
||||
1. Provide detailed and nuanced analysis insights, avoid simply stating "trends are mixed"
|
||||
2. Focus on important news events that may affect markets
|
||||
3. Analyze potential market impact and trading opportunities of news events
|
||||
4. Identify changing trends in market sentiment
|
||||
5. Assess macroeconomic environment impact on different asset classes
|
||||
6. Append a Markdown table at the end of the report to organize key points in an organized and easy-to-read format
|
||||
|
||||
Please write a professional, comprehensive news analysis report."""
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Social Media Analyst - Professional Version
|
||||
|
||||
```python
|
||||
system_message = (
|
||||
"""You are a professional social media sentiment analyst specializing in analyzing investor sentiment and discussion hotspots on social media platforms. Your task is to write a comprehensive report on specific stock sentiment and discussions on social media.
|
||||
|
||||
Analysis Scope:
|
||||
📱 Social Media Platforms: Reddit, Twitter, StockTwits, etc.
|
||||
💭 Sentiment Analysis: Distribution and trend changes of positive, negative, and neutral sentiment
|
||||
🔥 Hot Topics: Most discussed topics and keywords
|
||||
👥 User Behavior: Retail investor opinions and behavior patterns
|
||||
📊 Sentiment Indicators: Fear & Greed Index, bull/bear ratios, discussion volume changes
|
||||
|
||||
Key Focus Areas:
|
||||
- Investor views on company fundamentals
|
||||
- Reactions to latest earnings and news
|
||||
- Technical analysis opinions and price predictions
|
||||
- Risk factors and concerns
|
||||
- Institutional vs retail investor opinion differences
|
||||
|
||||
Analysis Requirements:
|
||||
1. Quantify sentiment trend changes, provide specific data support
|
||||
2. Identify key sentiment turning points that may affect stock prices
|
||||
3. Analyze correlation between social media sentiment and actual stock performance
|
||||
4. Don't simply state "sentiment is mixed", provide detailed sentiment analysis
|
||||
5. Assess reliability and potential bias of social media sentiment
|
||||
6. Append a Markdown table at the end of the report to organize key points in an organized and easy-to-read format
|
||||
|
||||
Please write a professional, in-depth social media sentiment analysis report."""
|
||||
)
|
||||
```
|
||||
|
||||
## 🔬 Researcher Prompt Templates
|
||||
|
||||
### 1. Bull Researcher - Professional Version
|
||||
|
||||
```python
|
||||
prompt = f"""You are a professional bull analyst responsible for building a strong case for investing in the stock. Your task is to construct a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators.
|
||||
|
||||
🎯 Key Focus Areas:
|
||||
|
||||
📈 Growth Potential:
|
||||
- Highlight company's market opportunities, revenue projections, and scalability
|
||||
- Analyze growth drivers from new products, new markets, new technologies
|
||||
- Assess management's execution capability and strategic planning
|
||||
|
||||
🏆 Competitive Advantages:
|
||||
- Emphasize factors like unique products, strong branding, or dominant market positioning
|
||||
- Analyze moats: technological barriers, network effects, economies of scale
|
||||
- Assess company's relative competitive position in the industry
|
||||
|
||||
📊 Positive Indicators:
|
||||
- Use financial health, industry trends, and recent positive news as evidence
|
||||
- Analyze valuation attractiveness and upside potential
|
||||
- Identify catalyst events and positive factors
|
||||
|
||||
🛡️ Bear Counterpoints:
|
||||
- Critically analyze bear arguments with specific data and sound reasoning
|
||||
- Thoroughly address concerns and show why bull perspective holds stronger merit
|
||||
- Provide alternative explanations and risk mitigation measures
|
||||
|
||||
💬 Debate Style:
|
||||
- Present arguments in conversational style, directly engaging with bear analyst's points
|
||||
- Debate effectively rather than just listing data
|
||||
- Maintain professional but persuasive tone
|
||||
|
||||
Available Resources:
|
||||
- Market research report: {market_research_report}
|
||||
- Social media sentiment report: {sentiment_report}
|
||||
- Latest world affairs news: {news_report}
|
||||
- Company fundamentals report: {fundamentals_report}
|
||||
- Debate conversation history: {history}
|
||||
- Last bear argument: {current_response}
|
||||
- Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
|
||||
Use this information to deliver a compelling bull argument, refute bear concerns, and engage in dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from past lessons and mistakes.
|
||||
|
||||
Please provide professional, persuasive bull analysis and debate."""
|
||||
```
|
||||
|
||||
### 2. Bear Researcher - Professional Version
|
||||
|
||||
```python
|
||||
prompt = f"""You are a professional bear analyst responsible for identifying risks and potential issues with investing in the stock. Your task is to construct an evidence-based cautious case emphasizing risk factors, valuation concerns, and negative market indicators.
|
||||
|
||||
🎯 Key Focus Areas:
|
||||
|
||||
⚠️ Risk Factors:
|
||||
- Identify potential risks in business model, industry, or macroeconomic environment
|
||||
- Analyze competitive threats, technological disruption, regulatory risks
|
||||
- Assess management risks and corporate governance issues
|
||||
|
||||
💰 Valuation Concerns:
|
||||
- Analyze whether current valuation is excessive compared to historical and peer comparisons
|
||||
- Identify bubble signs and unreasonable market expectations
|
||||
- Assess downside risks and potential valuation corrections
|
||||
|
||||
📉 Negative Indicators:
|
||||
- Use financial deterioration, industry headwinds, and negative news as evidence
|
||||
- Analyze technical indicators showing weakness signals
|
||||
- Identify potential catalyst risk events
|
||||
|
||||
🛡️ Bull Counterpoints:
|
||||
- Question bull arguments with specific data and sound reasoning
|
||||
- Point out blind spots and excessive optimism in bull analysis
|
||||
- Provide more conservative scenario analysis
|
||||
|
||||
💬 Debate Style:
|
||||
- Present arguments in conversational style, directly engaging with bull analyst's points
|
||||
- Maintain rational and objective approach, avoid excessive pessimism
|
||||
- Provide strong rebuttals based on facts
|
||||
|
||||
Available Resources:
|
||||
- Market research report: {market_research_report}
|
||||
- Social media sentiment report: {sentiment_report}
|
||||
- Latest world affairs news: {news_report}
|
||||
- Company fundamentals report: {fundamentals_report}
|
||||
- Debate conversation history: {history}
|
||||
- Last bull argument: {current_response}
|
||||
- Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
|
||||
Use this information to provide convincing bear arguments, question bull optimistic expectations, and engage in dynamic debate that demonstrates the reasonableness of the bear position. You must also address reflections and learn from past lessons and mistakes.
|
||||
|
||||
Please provide professional, rational bear analysis and debate."""
|
||||
```
|
||||
|
||||
## 💼 Trader Prompt Templates
|
||||
|
||||
### 1. Conservative Trader
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""You are a professional conservative trading agent with risk control as the top priority. Based on comprehensive analysis from the team of analysts, you need to make prudent investment decisions.
|
||||
|
||||
🛡️ Risk Control Principles:
|
||||
1. Risk first, returns second - Never risk more than you can afford to lose
|
||||
2. Strict stop-loss, protect capital - Set clear stop-loss points and execute strictly
|
||||
3. Diversified investment, reduce risk - Avoid over-concentration in single investments
|
||||
4. Data-driven, rational decisions - Base decisions on objective analysis, not emotions
|
||||
|
||||
📊 Decision Framework:
|
||||
1. Risk Assessment: Evaluate potential losses and probabilities
|
||||
2. Return Analysis: Calculate risk-adjusted expected returns
|
||||
3. Position Management: Determine appropriate investment proportions
|
||||
4. Exit Strategy: Set stop-loss and take-profit points
|
||||
|
||||
📋 Must Include Elements:
|
||||
- Risk level assessment (Low/Medium/High)
|
||||
- Specific stop-loss points
|
||||
- Recommended maximum position ratio
|
||||
- Detailed risk warnings
|
||||
|
||||
💭 Decision Considerations:
|
||||
- Current market environment and volatility
|
||||
- Company fundamental stability
|
||||
- Technical indicator confirmation signals
|
||||
- Macroeconomic and industry risks
|
||||
- Historical experience and lessons: {past_memory_str}
|
||||
|
||||
Based on comprehensive analysis, provide prudent investment recommendations. Must end your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation.
|
||||
|
||||
Please provide professional, cautious trading decision analysis.""",
|
||||
},
|
||||
context,
|
||||
]
|
||||
```
|
||||
|
||||
### 2. Aggressive Trader
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""You are a professional aggressive trading agent focused on capturing high-return opportunities. Based on comprehensive analysis from the team of analysts, you need to make proactive investment decisions.
|
||||
|
||||
🚀 Growth-Oriented Principles:
|
||||
1. Returns priority, moderate risk - Pursue high-return opportunities, accept corresponding risks
|
||||
2. Trend following, momentum investing - Identify and follow strong trends
|
||||
3. Quick action, seize opportunities - Act decisively within opportunity windows
|
||||
4. Data-driven, flexible adjustment - Quickly adjust strategies based on market changes
|
||||
|
||||
📈 Decision Framework:
|
||||
1. Opportunity Identification: Look for high-return potential investment opportunities
|
||||
2. Momentum Analysis: Assess price and volume momentum
|
||||
3. Catalyst Assessment: Identify factors that may drive stock prices
|
||||
4. Timing: Choose optimal entry and exit timing
|
||||
|
||||
📋 Must Include Elements:
|
||||
- Return potential assessment (Conservative/Optimistic/Aggressive)
|
||||
- Key catalyst factors
|
||||
- Recommended target price levels
|
||||
- Momentum confirmation signals
|
||||
|
||||
💭 Decision Considerations:
|
||||
- Technical breakouts and momentum signals
|
||||
- Fundamental improvement catalysts
|
||||
- Market sentiment and capital flows
|
||||
- Industry rotation and thematic investment opportunities
|
||||
- Historical success experience: {past_memory_str}
|
||||
|
||||
Based on comprehensive analysis, provide proactive investment recommendations. Must end your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation.
|
||||
|
||||
Please provide professional, proactive trading decision analysis.""",
|
||||
},
|
||||
context,
|
||||
]
|
||||
```
|
||||
|
||||
### 3. Quantitative Trader
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""You are a professional quantitative trading agent making systematic investment decisions based on data and models. You rely on objective quantitative indicators and statistical analysis to make trading decisions.
|
||||
|
||||
📊 Quantitative Analysis Framework:
|
||||
1. Technical Indicator Quantification: Numerical analysis of RSI, MACD, Bollinger Bands and other indicators
|
||||
2. Statistical Arbitrage: Statistical significance of price deviations from mean
|
||||
3. Momentum Factors: Quantitative measurement of price and volume momentum
|
||||
4. Risk Models: VaR, Sharpe ratio, maximum drawdown and other risk indicators
|
||||
|
||||
🔢 Decision Model:
|
||||
- Multi-factor scoring model: Technical (40%) + Fundamental (30%) + Sentiment (20%) + Macro (10%)
|
||||
- Signal Strength: Strong Buy (>80 points) | Buy (60-80) | Hold (40-60) | Sell (20-40) | Strong Sell (<20)
|
||||
- Confidence Level: Based on historical backtesting and statistical significance
|
||||
|
||||
📈 Quantitative Indicator Weights:
|
||||
Technical Indicators:
|
||||
- RSI Divergence (Weight: 15%)
|
||||
- MACD Golden/Death Cross (Weight: 15%)
|
||||
- Bollinger Band Breakout (Weight: 10%)
|
||||
|
||||
Fundamental Indicators:
|
||||
- PE/PB Relative Valuation (Weight: 15%)
|
||||
- Earnings Growth Trend (Weight: 15%)
|
||||
|
||||
Market Sentiment:
|
||||
- Social Media Sentiment Score (Weight: 10%)
|
||||
- Institutional Fund Flows (Weight: 10%)
|
||||
|
||||
Macro Factors:
|
||||
- Industry Rotation Signals (Weight: 5%)
|
||||
- Overall Market Trend (Weight: 5%)
|
||||
|
||||
📋 Output Requirements:
|
||||
- Comprehensive Score (0-100 points)
|
||||
- Factor score breakdown
|
||||
- Statistical confidence level
|
||||
- Quantitative risk indicators
|
||||
- Historical backtest performance: {past_memory_str}
|
||||
|
||||
Based on quantitative models, provide objective investment recommendations. Must end your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**'.
|
||||
|
||||
Please provide professional, quantitative trading decision analysis.""",
|
||||
},
|
||||
context,
|
||||
]
|
||||
```
|
||||
|
||||
## 🔄 Reflection System Prompt Templates
|
||||
|
||||
### 1. Detailed Reflection Template
|
||||
|
||||
```python
|
||||
def _get_reflection_prompt(self) -> str:
|
||||
return """
|
||||
You are a professional financial analysis expert tasked with reviewing trading decisions/analysis and providing comprehensive, step-by-step analysis.
|
||||
Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines:
|
||||
|
||||
🔍 1. Reasoning Analysis:
|
||||
- For each trading decision, determine whether it was correct or incorrect. A correct decision results in increased returns, while an incorrect decision does the opposite
|
||||
- Analyze contributing factors to each success or mistake, considering:
|
||||
* Market intelligence quality and accuracy
|
||||
* Technical indicator effectiveness and timing
|
||||
* Technical signal strength and confirmation
|
||||
* Price movement analysis accuracy
|
||||
* Overall market data analysis depth
|
||||
* News analysis relevance and impact assessment
|
||||
* Social media and sentiment analysis reliability
|
||||
* Fundamental data analysis comprehensiveness
|
||||
* Weight allocation of each factor in the decision-making process
|
||||
|
||||
📈 2. Improvement Recommendations:
|
||||
- For any incorrect decisions, propose revisions to maximize returns
|
||||
- Provide detailed corrective action lists or improvements, including specific recommendations
|
||||
- Example: Change decision from HOLD to BUY on a specific date
|
||||
|
||||
📚 3. Experience Summary:
|
||||
- Summarize lessons learned from successes and failures
|
||||
- Highlight how these lessons can be applied to future trading scenarios
|
||||
- Draw connections between similar situations to apply gained knowledge
|
||||
|
||||
🎯 4. Key Insight Extraction:
|
||||
- Extract key insights from summary into concise sentences of no more than 1000 tokens
|
||||
- Ensure condensed sentences capture the essence of lessons and reasoning for easy reference
|
||||
|
||||
Strictly adhere to these instructions and ensure your output is detailed, accurate, and actionable. You will also be given objective market descriptions from price movements, technical indicators, news, and sentiment perspectives to provide more context for your analysis.
|
||||
|
||||
Please provide professional, in-depth reflection analysis.
|
||||
"""
|
||||
```
|
||||
|
||||
## 🎨 Custom Prompt Guidelines
|
||||
|
||||
### 1. Prompt Structure Template
|
||||
|
||||
```python
|
||||
def create_custom_prompt(
|
||||
role="Analyst",
|
||||
expertise="Market Analysis",
|
||||
style="Professional",
|
||||
language="English",
|
||||
risk_level="Moderate",
|
||||
output_format="Detailed Report"
|
||||
):
|
||||
return f"""
|
||||
Role Definition: You are a {style} {role}
|
||||
|
||||
🎯 Role Positioning:
|
||||
- Expertise: {expertise}
|
||||
- Analysis Style: {style}
|
||||
- Risk Preference: {risk_level}
|
||||
- Output Language: {language}
|
||||
|
||||
📋 Core Tasks:
|
||||
1. [Specific Task 1]
|
||||
2. [Specific Task 2]
|
||||
3. [Specific Task 3]
|
||||
|
||||
🔍 Analysis Framework:
|
||||
- Data Collection: [Data sources and types]
|
||||
- Analysis Methods: [Analysis tools and methods used]
|
||||
- Risk Assessment: [Risk identification and assessment methods]
|
||||
- Conclusion Formation: [Decision logic and criteria]
|
||||
|
||||
📊 Output Requirements:
|
||||
- Format: {output_format}
|
||||
- Structure: [Specific output structure requirements]
|
||||
- Focus: [Content that needs emphasis]
|
||||
- Constraints: [Content or practices to avoid]
|
||||
|
||||
💡 Important Notes:
|
||||
- [Special Requirement 1]
|
||||
- [Special Requirement 2]
|
||||
- [Special Requirement 3]
|
||||
|
||||
Please provide professional {expertise} analysis based on the above requirements.
|
||||
"""
|
||||
```
|
||||
|
||||
### 2. Multi-language Prompt Template
|
||||
|
||||
```python
|
||||
MULTILINGUAL_PROMPTS = {
|
||||
"en-US": {
|
||||
"role_prefix": "You are a professional",
|
||||
"task_intro": "Your task is to",
|
||||
"analysis_framework": "Analysis Framework:",
|
||||
"output_requirements": "Output Requirements:",
|
||||
"final_decision": "Final Recommendation:"
|
||||
},
|
||||
"zh-CN": {
|
||||
"role_prefix": "您是一位专业的",
|
||||
"task_intro": "您的任务是",
|
||||
"analysis_framework": "分析框架:",
|
||||
"output_requirements": "输出要求:",
|
||||
"final_decision": "最终建议:"
|
||||
},
|
||||
"ja-JP": {
|
||||
"role_prefix": "あなたはプロの",
|
||||
"task_intro": "あなたの任務は",
|
||||
"analysis_framework": "分析フレームワーク:",
|
||||
"output_requirements": "出力要件:",
|
||||
"final_decision": "最終推奨:"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
💡 **Usage Tips**:
|
||||
1. Copy the appropriate template code
|
||||
2. Modify specific content as needed
|
||||
3. Replace original prompts in corresponding files
|
||||
4. Test modification effects
|
||||
5. Further optimize based on results
|
||||
|
||||
📝 **Customization Suggestions**:
|
||||
- Maintain structured and logical prompts
|
||||
- Clearly specify output format and requirements
|
||||
- Include specific analysis frameworks and methods
|
||||
- Consider different market and cultural backgrounds
|
||||
- Regularly optimize prompts based on effectiveness feedback
|
||||
|
|
@ -0,0 +1,251 @@
|
|||
# TradingAgents Quick Reference Card
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### 1. Change LLM Provider
|
||||
```python
|
||||
# Edit main.py
|
||||
config["llm_provider"] = "google" # or "openai", "anthropic"
|
||||
config["backend_url"] = "https://generativelanguage.googleapis.com/v1"
|
||||
config["deep_think_llm"] = "gemini-2.0-flash"
|
||||
config["quick_think_llm"] = "gemini-2.0-flash"
|
||||
```
|
||||
|
||||
### 2. Modify Debate Rounds
|
||||
```python
|
||||
# Edit main.py or default_config.py
|
||||
config["max_debate_rounds"] = 3 # Increase to 3 rounds
|
||||
config["max_risk_discuss_rounds"] = 2 # Risk discussion 2 rounds
|
||||
```
|
||||
|
||||
### 3. Enable/Disable Online Tools
|
||||
```python
|
||||
config["online_tools"] = True # Enable online APIs
|
||||
config["online_tools"] = False # Use local data
|
||||
```
|
||||
|
||||
## 📁 Key File Locations
|
||||
|
||||
| Content to Modify | File Path | Specific Location |
|
||||
|------------------|-----------|-------------------|
|
||||
| **System Config** | `tradingagents/default_config.py` | Entire file |
|
||||
| **Runtime Config** | `main.py` | Lines 15-22 |
|
||||
| **Market Analyst Prompts** | `tradingagents/agents/analysts/market_analyst.py` | Lines 24-50 |
|
||||
| **Fundamentals Analyst Prompts** | `tradingagents/agents/analysts/fundamentals_analyst.py` | Lines 23-26 |
|
||||
| **News Analyst Prompts** | `tradingagents/agents/analysts/news_analyst.py` | Lines 20-23 |
|
||||
| **Social Media Analyst Prompts** | `tradingagents/agents/analysts/social_media_analyst.py` | Lines 19-22 |
|
||||
| **Bull Researcher Prompts** | `tradingagents/agents/researchers/bull_researcher.py` | Lines 25-43 |
|
||||
| **Bear Researcher Prompts** | `tradingagents/agents/researchers/bear_researcher.py` | Lines 25-43 |
|
||||
| **Trader Prompts** | `tradingagents/agents/trader/trader.py` | Lines 30-36 |
|
||||
| **Reflection System Prompts** | `tradingagents/graph/reflection.py` | Lines 15-47 |
|
||||
| **Cache Config** | `tradingagents/dataflows/cache_manager.py` | Lines 20-35 |
|
||||
|
||||
## 🎯 Common Modification Templates
|
||||
|
||||
### 1. Professional Prompt Template
|
||||
```python
|
||||
system_message = f"""
|
||||
You are a professional {role_name} with the following characteristics:
|
||||
|
||||
Expertise Areas:
|
||||
- {domain_1}
|
||||
- {domain_2}
|
||||
- {domain_3}
|
||||
|
||||
Analysis Requirements:
|
||||
1. Provide detailed analysis reasoning
|
||||
2. Include risk warnings
|
||||
3. Summarize key indicators in table format
|
||||
|
||||
Output Format:
|
||||
{output_format}
|
||||
|
||||
Important Notes:
|
||||
- Avoid simply saying "trends are mixed"
|
||||
- Provide specific data support
|
||||
- Consider market-specific factors
|
||||
"""
|
||||
```
|
||||
|
||||
### 2. Risk Control Template
|
||||
```python
|
||||
system_message = f"""
|
||||
You are a risk-conscious {role_name}.
|
||||
|
||||
Risk Control Principles:
|
||||
1. Risk first, returns second
|
||||
2. Strict stop-loss, protect capital
|
||||
3. Diversified investment, reduce risk
|
||||
4. Data-driven, rational decisions
|
||||
|
||||
Must Include:
|
||||
- Risk assessment level (Low/Medium/High)
|
||||
- Recommended stop-loss points
|
||||
- Maximum position suggestion
|
||||
- Risk warning description
|
||||
|
||||
Decision Format:
|
||||
Final Recommendation: **BUY/HOLD/SELL**
|
||||
Risk Level: **Low/Medium/High**
|
||||
Stop-Loss Point: **Specific price**
|
||||
Suggested Position: **Percentage**
|
||||
"""
|
||||
```
|
||||
|
||||
### 3. Technical Analysis Template
|
||||
```python
|
||||
system_message = f"""
|
||||
You are a professional technical analyst focusing on the following indicators:
|
||||
|
||||
Core Indicators:
|
||||
- Moving Averages: SMA, EMA
|
||||
- Momentum Indicators: RSI, MACD
|
||||
- Volatility Indicators: Bollinger Bands, ATR
|
||||
- Volume Indicators: VWMA
|
||||
|
||||
Analysis Framework:
|
||||
1. Trend identification (Up/Down/Sideways)
|
||||
2. Support and resistance levels
|
||||
3. Buy/sell signal identification
|
||||
4. Risk-reward ratio calculation
|
||||
|
||||
Output Requirements:
|
||||
- Clear trend judgment
|
||||
- Specific entry/exit points
|
||||
- Technical indicator divergence analysis
|
||||
- Volume-price relationship analysis
|
||||
"""
|
||||
```
|
||||
|
||||
## ⚙️ Configuration Parameters Quick Reference
|
||||
|
||||
### LLM Configuration
|
||||
```python
|
||||
"llm_provider": "openai" | "google" | "anthropic"
|
||||
"deep_think_llm": "model_name" # Deep thinking model
|
||||
"quick_think_llm": "model_name" # Quick thinking model
|
||||
"backend_url": "API_address"
|
||||
```
|
||||
|
||||
#### Google Models Quick Reference
|
||||
```python
|
||||
# Fast Models: gemini-2.0-flash-lite, gemini-2.0-flash ⭐, gemini-2.5-flash-preview-05-20
|
||||
# Deep Models: gemini-2.0-flash ⭐, gemini-2.5-flash-preview-05-20, gemini-2.5-pro-preview-06-05
|
||||
|
||||
# Google API Setup
|
||||
export GOOGLE_API_KEY="your_key_here"
|
||||
```
|
||||
|
||||
### Debate Configuration
|
||||
```python
|
||||
"max_debate_rounds": 1-5 # Debate rounds
|
||||
"max_risk_discuss_rounds": 1-3 # Risk discussion rounds
|
||||
"max_recur_limit": 100 # Recursion limit
|
||||
```
|
||||
|
||||
### Tool Configuration
|
||||
```python
|
||||
"online_tools": True | False # Whether to use online tools
|
||||
"data_cache_dir": "cache_directory_path"
|
||||
"results_dir": "results_output_directory"
|
||||
```
|
||||
|
||||
### Cache Configuration
|
||||
```python
|
||||
# In cache_manager.py
|
||||
'us_stock_data': {'ttl_hours': 2} # US stock cache 2 hours
|
||||
'china_stock_data': {'ttl_hours': 1} # A-share cache 1 hour
|
||||
```
|
||||
|
||||
## 🔧 Common Commands
|
||||
|
||||
### Test Configuration
|
||||
```bash
|
||||
# Run basic tests
|
||||
cd tests && python test_cache_manager.py
|
||||
|
||||
# Run integration tests
|
||||
cd tests && python test_integration.py
|
||||
|
||||
# Run performance tests
|
||||
cd tests && python test_performance.py
|
||||
```
|
||||
|
||||
### Backup and Restore
|
||||
```bash
|
||||
# Backup configuration files
|
||||
cp tradingagents/default_config.py tradingagents/default_config.py.backup
|
||||
|
||||
# Backup prompt files
|
||||
cp tradingagents/agents/trader/trader.py tradingagents/agents/trader/trader.py.backup
|
||||
|
||||
# Restore files
|
||||
cp tradingagents/default_config.py.backup tradingagents/default_config.py
|
||||
```
|
||||
|
||||
### Git Management
|
||||
```bash
|
||||
# Check modification status
|
||||
git status
|
||||
|
||||
# Commit configuration changes
|
||||
git add tradingagents/default_config.py
|
||||
git commit -m "feat: Update LLM configuration to Google Gemini"
|
||||
|
||||
# Commit prompt changes
|
||||
git add tradingagents/agents/trader/trader.py
|
||||
git commit -m "feat: Optimize trader prompts, add risk control"
|
||||
```
|
||||
|
||||
## 🚨 Important Notes
|
||||
|
||||
### ⚠️ Must Do Before Modification
|
||||
1. **Backup Files**: Always backup original files before modification
|
||||
2. **Test Environment**: Validate modifications in test environment
|
||||
3. **Version Control**: Use Git to track all changes
|
||||
|
||||
### ⚠️ Common Errors
|
||||
1. **Forgot to Restart**: Need to restart application after config changes
|
||||
2. **Path Errors**: Ensure file paths are correct
|
||||
3. **Syntax Errors**: Python syntax must be correct
|
||||
4. **Encoding Issues**: Use UTF-8 encoding for content
|
||||
|
||||
### ⚠️ Performance Considerations
|
||||
1. **Prompt Length**: Avoid overly long prompts (recommend <4000 tokens)
|
||||
2. **API Call Frequency**: Be aware of API call limits
|
||||
3. **Cache Settings**: Set reasonable cache TTL times
|
||||
|
||||
## 🆘 Troubleshooting
|
||||
|
||||
### Issue: Configuration not taking effect
|
||||
```python
|
||||
# Solution: Force reload configuration
|
||||
from tradingagents.dataflows.config import reload_config
|
||||
reload_config()
|
||||
```
|
||||
|
||||
### Issue: API call failures
|
||||
```python
|
||||
# Solution: Check API keys and network connection
|
||||
import os
|
||||
print("OpenAI API Key:", os.getenv("OPENAI_API_KEY", "Not set"))
|
||||
print("Google API Key:", os.getenv("GOOGLE_API_KEY", "Not set"))
|
||||
```
|
||||
|
||||
### Issue: High memory usage
|
||||
```python
|
||||
# Solution: Enable cache cleanup
|
||||
config["cache_settings"]["cache_size_limit_mb"] = 500 # Limit cache size
|
||||
config["cache_settings"]["cache_cleanup_interval"] = 1800 # Clean every 30 minutes
|
||||
```
|
||||
|
||||
## 📞 Getting Help
|
||||
|
||||
1. **View Detailed Documentation**: `docs/en-US/configuration_guide.md`
|
||||
2. **Run Tests**: Test files in `tests/` directory
|
||||
3. **View Examples**: `examples/` directory (if available)
|
||||
4. **GitHub Issues**: Submit issues in project repository
|
||||
|
||||
---
|
||||
|
||||
💡 **Tip**: Recommend bookmarking this document for easy reference!
|
||||
|
|
@ -0,0 +1,383 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
TradingAgents 全量合并脚本
|
||||
将中文版本的所有新功能合并到主项目中
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import difflib
|
||||
|
||||
class FullMerger:
|
||||
"""全量合并器"""
|
||||
|
||||
def __init__(self, source_dir="TradingAgentsCN", target_dir="."):
|
||||
self.source_dir = Path(source_dir)
|
||||
self.target_dir = Path(target_dir)
|
||||
|
||||
# 需要特殊处理的冲突文件
|
||||
self.conflict_files = [
|
||||
"tradingagents/dataflows/cache_manager.py",
|
||||
"tradingagents/dataflows/optimized_us_data.py",
|
||||
"tradingagents/dataflows/interface.py",
|
||||
"tradingagents/default_config.py"
|
||||
]
|
||||
|
||||
# 要忽略的文件和目录
|
||||
self.ignore_patterns = [
|
||||
"__pycache__",
|
||||
"*.pyc",
|
||||
".git",
|
||||
"test_env",
|
||||
"env",
|
||||
"data_cache",
|
||||
"*.csv",
|
||||
"eval_results",
|
||||
"results",
|
||||
"finnhub_data",
|
||||
"enhanced_analysis_reports"
|
||||
]
|
||||
|
||||
def should_ignore(self, path: Path) -> bool:
|
||||
"""检查是否应该忽略此路径"""
|
||||
path_str = str(path)
|
||||
for pattern in self.ignore_patterns:
|
||||
if pattern in path_str:
|
||||
return True
|
||||
return False
|
||||
|
||||
def merge_new_files(self) -> int:
|
||||
"""合并新增文件"""
|
||||
print("📄 合并新增文件...")
|
||||
|
||||
source_tradingagents = self.source_dir / "tradingagents"
|
||||
target_tradingagents = self.target_dir / "tradingagents"
|
||||
|
||||
if not source_tradingagents.exists():
|
||||
print(f"❌ 源目录不存在: {source_tradingagents}")
|
||||
return 0
|
||||
|
||||
merged_count = 0
|
||||
|
||||
# 遍历源目录中的所有文件
|
||||
for source_file in source_tradingagents.rglob("*"):
|
||||
if source_file.is_file() and not self.should_ignore(source_file):
|
||||
# 计算相对路径
|
||||
rel_path = source_file.relative_to(self.source_dir)
|
||||
target_file = self.target_dir / rel_path
|
||||
|
||||
# 检查是否为新文件
|
||||
if not target_file.exists():
|
||||
# 创建目标目录
|
||||
target_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 复制文件
|
||||
shutil.copy2(source_file, target_file)
|
||||
print(f" ✅ 新增: {rel_path}")
|
||||
merged_count += 1
|
||||
|
||||
return merged_count
|
||||
|
||||
def handle_conflict_files(self) -> int:
|
||||
"""处理冲突文件"""
|
||||
print("\n⚠️ 处理冲突文件...")
|
||||
|
||||
handled_count = 0
|
||||
|
||||
for conflict_file in self.conflict_files:
|
||||
source_file = self.source_dir / conflict_file
|
||||
target_file = self.target_dir / conflict_file
|
||||
|
||||
if source_file.exists() and target_file.exists():
|
||||
print(f" 🔄 处理冲突: {conflict_file}")
|
||||
|
||||
# 创建备份
|
||||
backup_file = target_file.with_suffix(target_file.suffix + ".backup")
|
||||
shutil.copy2(target_file, backup_file)
|
||||
print(f" 💾 备份创建: {backup_file.name}")
|
||||
|
||||
# 生成差异报告
|
||||
diff_file = target_file.with_suffix(target_file.suffix + ".diff")
|
||||
self._generate_diff_report(source_file, target_file, diff_file)
|
||||
print(f" 📋 差异报告: {diff_file.name}")
|
||||
|
||||
# 对于某些文件,尝试智能合并
|
||||
if self._smart_merge(source_file, target_file, conflict_file):
|
||||
print(f" ✅ 智能合并成功")
|
||||
else:
|
||||
print(f" ⚠️ 需要手动合并")
|
||||
|
||||
handled_count += 1
|
||||
elif source_file.exists():
|
||||
# 源文件存在但目标文件不存在,直接复制
|
||||
target_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(source_file, target_file)
|
||||
print(f" ✅ 直接复制: {conflict_file}")
|
||||
handled_count += 1
|
||||
|
||||
return handled_count
|
||||
|
||||
def _generate_diff_report(self, source_file: Path, target_file: Path, diff_file: Path):
|
||||
"""生成差异报告"""
|
||||
try:
|
||||
with open(source_file, 'r', encoding='utf-8') as f:
|
||||
source_lines = f.readlines()
|
||||
with open(target_file, 'r', encoding='utf-8') as f:
|
||||
target_lines = f.readlines()
|
||||
|
||||
diff = difflib.unified_diff(
|
||||
target_lines, source_lines,
|
||||
fromfile=f"current/{target_file.name}",
|
||||
tofile=f"chinese_version/{source_file.name}",
|
||||
lineterm=''
|
||||
)
|
||||
|
||||
with open(diff_file, 'w', encoding='utf-8') as f:
|
||||
f.write(f"# 文件差异报告\n")
|
||||
f.write(f"# 当前文件: {target_file}\n")
|
||||
f.write(f"# 中文版文件: {source_file}\n")
|
||||
f.write(f"# 生成时间: {os.popen('date /t').read().strip()}\n\n")
|
||||
f.writelines(diff)
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ 生成差异报告失败: {e}")
|
||||
|
||||
def _smart_merge(self, source_file: Path, target_file: Path, conflict_file: str) -> bool:
|
||||
"""智能合并某些文件"""
|
||||
try:
|
||||
if "default_config.py" in conflict_file:
|
||||
return self._merge_default_config(source_file, target_file)
|
||||
elif "cache_manager.py" in conflict_file:
|
||||
# cache_manager.py 已经是英文版本,保持当前版本
|
||||
print(f" 📝 保持当前英文版本的 cache_manager.py")
|
||||
return True
|
||||
elif "optimized_us_data.py" in conflict_file:
|
||||
# optimized_us_data.py 已经是英文版本,保持当前版本
|
||||
print(f" 📝 保持当前英文版本的 optimized_us_data.py")
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f" ❌ 智能合并失败: {e}")
|
||||
return False
|
||||
|
||||
def _merge_default_config(self, source_file: Path, target_file: Path) -> bool:
|
||||
"""合并默认配置文件"""
|
||||
try:
|
||||
# 读取两个文件
|
||||
with open(source_file, 'r', encoding='utf-8') as f:
|
||||
source_content = f.read()
|
||||
with open(target_file, 'r', encoding='utf-8') as f:
|
||||
target_content = f.read()
|
||||
|
||||
# 简单策略:如果中文版本有新的配置项,添加到目标文件
|
||||
# 这里可以根据需要实现更复杂的合并逻辑
|
||||
|
||||
# 检查中文版本是否有新的配置项
|
||||
if "data_cache_dir" in source_content and "data_cache_dir" not in target_content:
|
||||
# 添加缓存目录配置
|
||||
lines = target_content.split('\n')
|
||||
for i, line in enumerate(lines):
|
||||
if '"data_dir":' in line:
|
||||
# 在data_dir后面添加data_cache_dir
|
||||
cache_dir_line = ' "data_cache_dir": os.path.join('
|
||||
cache_dir_line += '\n os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),'
|
||||
cache_dir_line += '\n "dataflows/data_cache",'
|
||||
cache_dir_line += '\n ),'
|
||||
lines.insert(i + 1, cache_dir_line)
|
||||
break
|
||||
|
||||
# 写回文件
|
||||
with open(target_file, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(lines))
|
||||
|
||||
print(f" ✅ 添加了 data_cache_dir 配置")
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ 合并默认配置失败: {e}")
|
||||
return False
|
||||
|
||||
def update_dependencies(self) -> bool:
|
||||
"""更新依赖项"""
|
||||
print("\n📦 更新依赖项...")
|
||||
|
||||
source_pyproject = self.source_dir / "pyproject.toml"
|
||||
target_pyproject = self.target_dir / "pyproject.toml"
|
||||
|
||||
if not source_pyproject.exists():
|
||||
print(" ⚠️ 源项目没有 pyproject.toml")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 读取源文件的依赖项
|
||||
with open(source_pyproject, 'r', encoding='utf-8') as f:
|
||||
source_content = f.read()
|
||||
|
||||
# 提取新的依赖项
|
||||
new_deps = []
|
||||
if 'pymongo' in source_content:
|
||||
new_deps.append('"pymongo>=4.0.0"')
|
||||
if 'beautifulsoup4' in source_content:
|
||||
new_deps.append('"beautifulsoup4>=4.9.0"')
|
||||
if 'dashscope' in source_content:
|
||||
new_deps.append('"dashscope>=1.0.0"')
|
||||
|
||||
if new_deps:
|
||||
# 读取目标文件
|
||||
with open(target_pyproject, 'r', encoding='utf-8') as f:
|
||||
target_lines = f.readlines()
|
||||
|
||||
# 找到dependencies部分并添加新依赖
|
||||
in_dependencies = False
|
||||
for i, line in enumerate(target_lines):
|
||||
if 'dependencies = [' in line:
|
||||
in_dependencies = True
|
||||
elif in_dependencies and ']' in line:
|
||||
# 在]前添加新依赖
|
||||
for dep in new_deps:
|
||||
target_lines.insert(i, f' {dep},\n')
|
||||
i += 1
|
||||
break
|
||||
|
||||
# 写回文件
|
||||
with open(target_pyproject, 'w', encoding='utf-8') as f:
|
||||
f.writelines(target_lines)
|
||||
|
||||
print(f" ✅ 添加了新依赖: {', '.join(new_deps)}")
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ 更新依赖项失败: {e}")
|
||||
return False
|
||||
|
||||
def create_merge_summary(self, new_files: int, conflict_files: int) -> str:
|
||||
"""创建合并摘要"""
|
||||
summary_file = "MERGE_SUMMARY.md"
|
||||
|
||||
with open(summary_file, 'w', encoding='utf-8') as f:
|
||||
f.write("# TradingAgents 中文版功能全量合并摘要\n\n")
|
||||
f.write(f"合并时间: {os.popen('date /t').read().strip()}\n")
|
||||
f.write(f"合并分支: full-merge-chinese-features\n\n")
|
||||
|
||||
f.write("## 📊 合并统计\n\n")
|
||||
f.write(f"- 新增文件: {new_files} 个\n")
|
||||
f.write(f"- 处理冲突文件: {conflict_files} 个\n\n")
|
||||
|
||||
f.write("## 🆕 主要新增功能\n\n")
|
||||
f.write("### 中国市场数据支持\n")
|
||||
f.write("- `chinese_finance_utils.py` - 中国财经数据聚合工具\n")
|
||||
f.write("- `tdx_utils.py` - 通达信API数据获取\n")
|
||||
f.write("- `optimized_china_data.py` - 优化的A股数据提供器\n")
|
||||
f.write("- `china_market_analyst.py` - 中国市场分析师\n\n")
|
||||
|
||||
f.write("### 数据库集成\n")
|
||||
f.write("- `database_config.py` - 数据库配置管理\n")
|
||||
f.write("- `database_manager.py` - 统一数据库管理器\n")
|
||||
f.write("- `mongodb_storage.py` - MongoDB存储支持\n")
|
||||
f.write("- `db_cache_manager.py` - 数据库缓存管理\n\n")
|
||||
|
||||
f.write("### 高级缓存系统\n")
|
||||
f.write("- `adaptive_cache.py` - 自适应缓存策略\n")
|
||||
f.write("- `integrated_cache.py` - 集成缓存管理\n\n")
|
||||
|
||||
f.write("### LLM适配器扩展\n")
|
||||
f.write("- `llm_adapters/` - LLM适配器框架\n")
|
||||
f.write("- `dashscope_adapter.py` - 阿里云DashScope支持\n\n")
|
||||
|
||||
f.write("### API和服务层\n")
|
||||
f.write("- `api/` - 统一API接口\n")
|
||||
f.write("- `stock_data_service.py` - 股票数据服务\n")
|
||||
f.write("- `realtime_news_utils.py` - 实时新闻工具\n\n")
|
||||
|
||||
f.write("## ⚠️ 需要注意的变更\n\n")
|
||||
f.write("### 新增依赖项\n")
|
||||
f.write("- `pymongo` - MongoDB数据库支持\n")
|
||||
f.write("- `beautifulsoup4` - 网页数据解析\n")
|
||||
f.write("- `dashscope` - 阿里云LLM支持 (可选)\n\n")
|
||||
|
||||
f.write("### 配置文件变更\n")
|
||||
f.write("- 添加了数据库相关配置\n")
|
||||
f.write("- 扩展了缓存配置选项\n")
|
||||
f.write("- 新增了中国市场数据源配置\n\n")
|
||||
|
||||
f.write("## 🧪 测试建议\n\n")
|
||||
f.write("1. **基础功能测试**: 确保原有功能正常工作\n")
|
||||
f.write("2. **新功能测试**: 测试中国市场数据获取\n")
|
||||
f.write("3. **缓存系统测试**: 验证缓存性能和稳定性\n")
|
||||
f.write("4. **数据库集成测试**: 测试MongoDB连接和存储\n")
|
||||
f.write("5. **LLM适配器测试**: 验证多LLM支持\n\n")
|
||||
|
||||
f.write("## 📝 后续工作\n\n")
|
||||
f.write("1. 更新文档以反映新功能\n")
|
||||
f.write("2. 添加新功能的使用示例\n")
|
||||
f.write("3. 完善测试覆盖率\n")
|
||||
f.write("4. 优化性能和稳定性\n\n")
|
||||
|
||||
f.write("## 🔄 如果需要分批PR\n\n")
|
||||
f.write("如果原项目认为全量合并过于复杂,可以按以下顺序分批提交:\n\n")
|
||||
f.write("1. **基础设施**: config/, database相关文件\n")
|
||||
f.write("2. **中国市场数据**: chinese_finance_utils.py, tdx_utils.py等\n")
|
||||
f.write("3. **高级缓存**: adaptive_cache.py, integrated_cache.py等\n")
|
||||
f.write("4. **LLM适配器**: llm_adapters/目录\n")
|
||||
f.write("5. **API服务**: api/目录和相关服务文件\n")
|
||||
|
||||
return summary_file
|
||||
|
||||
def run_full_merge(self) -> bool:
|
||||
"""执行全量合并"""
|
||||
print("🚀 开始全量合并中文版功能...")
|
||||
print("=" * 50)
|
||||
|
||||
# 检查源目录
|
||||
if not self.source_dir.exists():
|
||||
print(f"❌ 源目录不存在: {self.source_dir}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 1. 合并新增文件
|
||||
new_files = self.merge_new_files()
|
||||
|
||||
# 2. 处理冲突文件
|
||||
conflict_files = self.handle_conflict_files()
|
||||
|
||||
# 3. 更新依赖项
|
||||
self.update_dependencies()
|
||||
|
||||
# 4. 创建合并摘要
|
||||
summary_file = self.create_merge_summary(new_files, conflict_files)
|
||||
|
||||
print(f"\n✅ 全量合并完成!")
|
||||
print(f"📊 合并统计:")
|
||||
print(f" 新增文件: {new_files} 个")
|
||||
print(f" 处理冲突文件: {conflict_files} 个")
|
||||
print(f"📋 合并摘要: {summary_file}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 合并过程中出现错误: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
merger = FullMerger()
|
||||
|
||||
if merger.run_full_merge():
|
||||
print(f"\n🎯 下一步操作:")
|
||||
print("1. 检查合并结果和差异文件")
|
||||
print("2. 手动处理需要合并的冲突文件")
|
||||
print("3. 运行测试确保功能正常")
|
||||
print("4. 提交更改: git add . && git commit -m 'feat: merge Chinese version features'")
|
||||
print("5. 推送分支: git push origin full-merge-chinese-features")
|
||||
print("6. 创建PR到原项目")
|
||||
else:
|
||||
print(f"\n❌ 合并失败,请检查错误信息")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,214 @@
|
|||
# TradingAgents 中文版功能合并计划
|
||||
|
||||
## 📊 差异分析
|
||||
|
||||
### 中文版本新增的主要功能模块
|
||||
|
||||
#### 🏗️ 基础设施层
|
||||
- `config/` 目录:配置管理、数据库配置
|
||||
- `api/` 目录:股票API接口
|
||||
|
||||
#### 📊 数据层
|
||||
- `chinese_finance_utils.py` - 中国财经数据聚合
|
||||
- `tdx_utils.py` - 通达信API数据获取
|
||||
- `optimized_china_data.py` - 优化的A股数据提供器
|
||||
- `stock_data_service.py` - 股票数据服务
|
||||
- `realtime_news_utils.py` - 实时新闻工具
|
||||
|
||||
#### 💾 缓存层
|
||||
- `adaptive_cache.py` - 自适应缓存
|
||||
- `integrated_cache.py` - 集成缓存系统
|
||||
- `db_cache_manager.py` - 数据库缓存管理
|
||||
|
||||
#### 🤖 LLM适配层
|
||||
- `llm_adapters/dashscope_adapter.py` - 阿里云DashScope适配器
|
||||
|
||||
#### 🗄️ 数据库层
|
||||
- `database_config.py` - 数据库配置
|
||||
- `database_manager.py` - 数据库管理器
|
||||
- `mongodb_storage.py` - MongoDB存储
|
||||
|
||||
## 🎯 合并策略
|
||||
|
||||
### 阶段1:基础设施合并 (优先级:高)
|
||||
|
||||
**目标**:建立配置和数据库基础设施
|
||||
|
||||
**步骤**:
|
||||
1. 创建 `tradingagents/config/` 目录
|
||||
2. 合并配置管理相关文件
|
||||
3. 合并数据库相关文件
|
||||
4. 更新依赖项
|
||||
|
||||
**风险评估**:低
|
||||
**预计时间**:1-2天
|
||||
|
||||
### 阶段2:中国市场数据支持 (优先级:高)
|
||||
|
||||
**目标**:添加A股和中国市场数据支持
|
||||
|
||||
**步骤**:
|
||||
1. 合并 `chinese_finance_utils.py`
|
||||
2. 合并 `tdx_utils.py`
|
||||
3. 合并 `optimized_china_data.py`
|
||||
4. 测试中国市场数据获取功能
|
||||
|
||||
**风险评估**:中等
|
||||
**预计时间**:2-3天
|
||||
|
||||
### 阶段3:高级缓存系统 (优先级:中)
|
||||
|
||||
**目标**:提升缓存性能和智能化
|
||||
|
||||
**步骤**:
|
||||
1. 合并 `adaptive_cache.py`
|
||||
2. 合并 `integrated_cache.py`
|
||||
3. 合并 `db_cache_manager.py`
|
||||
4. 集成到现有缓存系统
|
||||
|
||||
**风险评估**:中等
|
||||
**预计时间**:2-3天
|
||||
|
||||
### 阶段4:LLM适配器扩展 (优先级:中)
|
||||
|
||||
**目标**:支持更多LLM提供商
|
||||
|
||||
**步骤**:
|
||||
1. 创建 `tradingagents/llm_adapters/` 目录
|
||||
2. 合并 `dashscope_adapter.py`
|
||||
3. 集成到现有LLM系统
|
||||
4. 测试多LLM支持
|
||||
|
||||
**风险评估**:中等
|
||||
**预计时间**:1-2天
|
||||
|
||||
### 阶段5:API和服务层 (优先级:低)
|
||||
|
||||
**目标**:完善API接口和服务
|
||||
|
||||
**步骤**:
|
||||
1. 创建 `tradingagents/api/` 目录
|
||||
2. 合并API相关文件
|
||||
3. 合并服务层文件
|
||||
4. 集成测试
|
||||
|
||||
**风险评估**:低
|
||||
**预计时间**:1-2天
|
||||
|
||||
## 🔧 实施细节
|
||||
|
||||
### 合并前检查清单
|
||||
|
||||
- [ ] 备份当前项目
|
||||
- [ ] 创建合并分支
|
||||
- [ ] 分析依赖冲突
|
||||
- [ ] 准备测试环境
|
||||
- [ ] 制定回滚计划
|
||||
|
||||
### 文件冲突处理
|
||||
|
||||
**已存在的文件**:
|
||||
- `cache_manager.py` - 需要合并功能
|
||||
- `optimized_us_data.py` - 需要合并功能
|
||||
- `interface.py` - 需要合并功能
|
||||
|
||||
**处理策略**:
|
||||
1. 比较文件差异
|
||||
2. 保留最佳功能
|
||||
3. 统一代码风格
|
||||
4. 更新文档
|
||||
|
||||
### 依赖管理
|
||||
|
||||
**新增依赖**:
|
||||
- `pymongo` - MongoDB支持
|
||||
- `beautifulsoup4` - 网页解析
|
||||
- `dashscope` - 阿里云LLM
|
||||
|
||||
**处理方式**:
|
||||
- 更新 `pyproject.toml`
|
||||
- 添加可选依赖组
|
||||
- 更新安装文档
|
||||
|
||||
## 🧪 测试策略
|
||||
|
||||
### 单元测试
|
||||
- 每个阶段完成后进行单元测试
|
||||
- 重点测试新功能和集成点
|
||||
- 确保向后兼容性
|
||||
|
||||
### 集成测试
|
||||
- 测试数据流完整性
|
||||
- 测试缓存系统性能
|
||||
- 测试多市场数据获取
|
||||
|
||||
### 性能测试
|
||||
- 对比合并前后性能
|
||||
- 测试缓存命中率
|
||||
- 测试内存使用情况
|
||||
|
||||
## 📝 文档更新
|
||||
|
||||
### 需要更新的文档
|
||||
- 配置指南
|
||||
- API文档
|
||||
- 安装指南
|
||||
- 使用示例
|
||||
|
||||
### 新增文档
|
||||
- 中国市场数据使用指南
|
||||
- 数据库配置指南
|
||||
- 多LLM配置指南
|
||||
|
||||
## 🚨 风险控制
|
||||
|
||||
### 主要风险
|
||||
1. **功能冲突**:新旧功能可能存在冲突
|
||||
2. **性能影响**:新功能可能影响现有性能
|
||||
3. **依赖冲突**:新依赖可能与现有依赖冲突
|
||||
4. **稳定性**:新功能可能引入不稳定因素
|
||||
|
||||
### 缓解措施
|
||||
1. **分阶段合并**:降低单次合并风险
|
||||
2. **充分测试**:每个阶段都进行完整测试
|
||||
3. **版本控制**:使用Git分支管理合并过程
|
||||
4. **回滚计划**:准备快速回滚方案
|
||||
|
||||
## 📅 时间计划
|
||||
|
||||
| 阶段 | 预计时间 | 累计时间 |
|
||||
|------|----------|----------|
|
||||
| 阶段1:基础设施 | 1-2天 | 1-2天 |
|
||||
| 阶段2:中国市场数据 | 2-3天 | 3-5天 |
|
||||
| 阶段3:高级缓存 | 2-3天 | 5-8天 |
|
||||
| 阶段4:LLM适配器 | 1-2天 | 6-10天 |
|
||||
| 阶段5:API服务 | 1-2天 | 7-12天 |
|
||||
| 测试和文档 | 2-3天 | 9-15天 |
|
||||
|
||||
**总预计时间**:9-15天
|
||||
|
||||
## ✅ 成功标准
|
||||
|
||||
### 功能标准
|
||||
- [ ] 所有原有功能正常工作
|
||||
- [ ] 新功能正确集成
|
||||
- [ ] 性能无明显下降
|
||||
- [ ] 文档完整更新
|
||||
|
||||
### 质量标准
|
||||
- [ ] 代码风格统一
|
||||
- [ ] 测试覆盖率不降低
|
||||
- [ ] 无明显技术债务
|
||||
- [ ] 向后兼容性保持
|
||||
|
||||
## 🎯 后续优化
|
||||
|
||||
### 短期优化
|
||||
- 代码重构和优化
|
||||
- 性能调优
|
||||
- 文档完善
|
||||
|
||||
### 长期规划
|
||||
- 功能扩展
|
||||
- 架构优化
|
||||
- 社区贡献
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_china_market_analyst(llm, toolkit):
|
||||
"""创建中国市场分析师"""
|
||||
|
||||
def china_market_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
|
||||
# 中国股票分析工具
|
||||
tools = [
|
||||
toolkit.get_china_stock_data,
|
||||
toolkit.get_china_market_overview,
|
||||
toolkit.get_YFin_data, # 备用数据源
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"""您是一位专业的中国股市分析师,专门分析A股、港股等中国资本市场。您具备深厚的中国股市知识和丰富的本土投资经验。
|
||||
|
||||
您的专业领域包括:
|
||||
1. **A股市场分析**: 深度理解A股的独特性,包括涨跌停制度、T+1交易、融资融券等
|
||||
2. **中国经济政策**: 熟悉货币政策、财政政策对股市的影响机制
|
||||
3. **行业板块轮动**: 掌握中国特色的板块轮动规律和热点切换
|
||||
4. **监管环境**: 了解证监会政策、退市制度、注册制等监管变化
|
||||
5. **市场情绪**: 理解中国投资者的行为特征和情绪波动
|
||||
|
||||
分析重点:
|
||||
- **技术面分析**: 使用通达信数据进行精确的技术指标分析
|
||||
- **基本面分析**: 结合中国会计准则和财报特点进行分析
|
||||
- **政策面分析**: 评估政策变化对个股和板块的影响
|
||||
- **资金面分析**: 分析北向资金、融资融券、大宗交易等资金流向
|
||||
- **市场风格**: 判断当前是成长风格还是价值风格占优
|
||||
|
||||
中国股市特色考虑:
|
||||
- 涨跌停板限制对交易策略的影响
|
||||
- ST股票的特殊风险和机会
|
||||
- 科创板、创业板的差异化分析
|
||||
- 国企改革、混改等主题投资机会
|
||||
- 中美关系、地缘政治对中概股的影响
|
||||
|
||||
请基于通达信API提供的实时数据和技术指标,结合中国股市的特殊性,撰写专业的中文分析报告。
|
||||
确保在报告末尾附上Markdown表格总结关键发现和投资建议。"""
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"您是一位专业的AI助手,与其他分析师协作进行股票分析。"
|
||||
" 使用提供的工具获取和分析数据。"
|
||||
" 如果您无法完全回答,没关系;其他分析师会补充您的分析。"
|
||||
" 专注于您的专业领域,提供高质量的分析见解。"
|
||||
" 您可以访问以下工具:{tool_names}。\n{system_message}"
|
||||
"当前分析日期:{current_date},分析标的:{ticker}。请用中文撰写所有分析内容。",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"china_market_report": report,
|
||||
"sender": "ChinaMarketAnalyst",
|
||||
}
|
||||
|
||||
return china_market_analyst_node
|
||||
|
||||
|
||||
def create_china_stock_screener(llm, toolkit):
|
||||
"""创建中国股票筛选器"""
|
||||
|
||||
def china_stock_screener_node(state):
|
||||
current_date = state["trade_date"]
|
||||
|
||||
tools = [
|
||||
toolkit.get_china_market_overview,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"""您是一位专业的中国股票筛选专家,负责从A股市场中筛选出具有投资价值的股票。
|
||||
|
||||
筛选维度包括:
|
||||
1. **基本面筛选**:
|
||||
- 财务指标:ROE、ROA、净利润增长率、营收增长率
|
||||
- 估值指标:PE、PB、PEG、PS比率
|
||||
- 财务健康:资产负债率、流动比率、速动比率
|
||||
|
||||
2. **技术面筛选**:
|
||||
- 趋势指标:均线系统、MACD、KDJ
|
||||
- 动量指标:RSI、威廉指标、CCI
|
||||
- 成交量指标:量价关系、换手率
|
||||
|
||||
3. **市场面筛选**:
|
||||
- 资金流向:主力资金净流入、北向资金偏好
|
||||
- 机构持仓:基金重仓、社保持仓、QFII持仓
|
||||
- 市场热度:概念板块活跃度、题材炒作程度
|
||||
|
||||
4. **政策面筛选**:
|
||||
- 政策受益:国家政策扶持行业
|
||||
- 改革红利:国企改革、混改标的
|
||||
- 监管影响:监管政策变化的影响
|
||||
|
||||
筛选策略:
|
||||
- **价值投资**: 低估值、高分红、稳定增长
|
||||
- **成长投资**: 高增长、新兴行业、技术创新
|
||||
- **主题投资**: 政策驱动、事件催化、概念炒作
|
||||
- **周期投资**: 经济周期、行业周期、季节性
|
||||
|
||||
请基于当前市场环境和政策背景,提供专业的股票筛选建议。"""
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"您是一位专业的股票筛选专家。"
|
||||
" 使用提供的工具分析市场概况。"
|
||||
" 您可以访问以下工具:{tool_names}。\n{system_message}"
|
||||
"当前日期:{current_date}。请用中文撰写分析内容。",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"stock_screening_report": result.content,
|
||||
"sender": "ChinaStockScreener",
|
||||
}
|
||||
|
||||
return china_stock_screener_node
|
||||
|
|
@ -0,0 +1,295 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
股票数据API接口
|
||||
提供便捷的股票数据获取接口,支持完整的降级机制
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# 添加dataflows目录到路径
|
||||
dataflows_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'dataflows')
|
||||
if dataflows_path not in sys.path:
|
||||
sys.path.append(dataflows_path)
|
||||
|
||||
try:
|
||||
from stock_data_service import get_stock_data_service
|
||||
SERVICE_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
print(f"⚠️ 股票数据服务不可用: {e}")
|
||||
SERVICE_AVAILABLE = False
|
||||
|
||||
def get_stock_info(stock_code: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取单个股票的基础信息
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码(如 '000001')
|
||||
|
||||
Returns:
|
||||
Dict: 股票基础信息
|
||||
|
||||
Example:
|
||||
>>> info = get_stock_info('000001')
|
||||
>>> print(info['name']) # 平安银行
|
||||
"""
|
||||
if not SERVICE_AVAILABLE:
|
||||
return {
|
||||
'error': '股票数据服务不可用',
|
||||
'code': stock_code,
|
||||
'suggestion': '请检查服务配置'
|
||||
}
|
||||
|
||||
service = get_stock_data_service()
|
||||
result = service.get_stock_basic_info(stock_code)
|
||||
|
||||
if result is None:
|
||||
return {
|
||||
'error': f'未找到股票{stock_code}的信息',
|
||||
'code': stock_code,
|
||||
'suggestion': '请检查股票代码是否正确'
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def get_all_stocks() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取所有股票的基础信息
|
||||
|
||||
Returns:
|
||||
List[Dict]: 所有股票的基础信息列表
|
||||
|
||||
Example:
|
||||
>>> stocks = get_all_stocks()
|
||||
>>> print(f"共有{len(stocks)}只股票")
|
||||
"""
|
||||
if not SERVICE_AVAILABLE:
|
||||
return [{
|
||||
'error': '股票数据服务不可用',
|
||||
'suggestion': '请检查服务配置'
|
||||
}]
|
||||
|
||||
service = get_stock_data_service()
|
||||
result = service.get_stock_basic_info()
|
||||
|
||||
if result is None or (isinstance(result, dict) and 'error' in result):
|
||||
return [{
|
||||
'error': '无法获取股票列表',
|
||||
'suggestion': '请检查网络连接和数据库配置'
|
||||
}]
|
||||
|
||||
return result if isinstance(result, list) else [result]
|
||||
|
||||
def get_stock_data(stock_code: str, start_date: str = None, end_date: str = None) -> str:
|
||||
"""
|
||||
获取股票历史数据(带降级机制)
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
start_date: 开始日期(格式:YYYY-MM-DD),默认为30天前
|
||||
end_date: 结束日期(格式:YYYY-MM-DD),默认为今天
|
||||
|
||||
Returns:
|
||||
str: 股票数据的字符串表示或错误信息
|
||||
|
||||
Example:
|
||||
>>> data = get_stock_data('000001', '2024-01-01', '2024-01-31')
|
||||
>>> print(data)
|
||||
"""
|
||||
if not SERVICE_AVAILABLE:
|
||||
return "❌ 股票数据服务不可用,请检查服务配置"
|
||||
|
||||
# 设置默认日期
|
||||
if end_date is None:
|
||||
end_date = datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
if start_date is None:
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d')
|
||||
|
||||
service = get_stock_data_service()
|
||||
return service.get_stock_data_with_fallback(stock_code, start_date, end_date)
|
||||
|
||||
def search_stocks(keyword: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
根据关键词搜索股票
|
||||
|
||||
Args:
|
||||
keyword: 搜索关键词(股票代码或名称的一部分)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 匹配的股票信息列表
|
||||
|
||||
Example:
|
||||
>>> results = search_stocks('平安')
|
||||
>>> for stock in results:
|
||||
... print(f"{stock['code']}: {stock['name']}")
|
||||
"""
|
||||
all_stocks = get_all_stocks()
|
||||
|
||||
if not all_stocks or (len(all_stocks) == 1 and 'error' in all_stocks[0]):
|
||||
return all_stocks
|
||||
|
||||
# 搜索匹配的股票
|
||||
matches = []
|
||||
keyword_lower = keyword.lower()
|
||||
|
||||
for stock in all_stocks:
|
||||
if 'error' in stock:
|
||||
continue
|
||||
|
||||
code = stock.get('code', '').lower()
|
||||
name = stock.get('name', '').lower()
|
||||
|
||||
if keyword_lower in code or keyword_lower in name:
|
||||
matches.append(stock)
|
||||
|
||||
return matches
|
||||
|
||||
def get_market_summary() -> Dict[str, Any]:
|
||||
"""
|
||||
获取市场概览信息
|
||||
|
||||
Returns:
|
||||
Dict: 市场统计信息
|
||||
|
||||
Example:
|
||||
>>> summary = get_market_summary()
|
||||
>>> print(f"沪市股票数量: {summary['shanghai_count']}")
|
||||
"""
|
||||
all_stocks = get_all_stocks()
|
||||
|
||||
if not all_stocks or (len(all_stocks) == 1 and 'error' in all_stocks[0]):
|
||||
return {
|
||||
'error': '无法获取市场数据',
|
||||
'suggestion': '请检查网络连接和数据库配置'
|
||||
}
|
||||
|
||||
# 统计市场信息
|
||||
shanghai_count = 0
|
||||
shenzhen_count = 0
|
||||
category_stats = {}
|
||||
|
||||
for stock in all_stocks:
|
||||
if 'error' in stock:
|
||||
continue
|
||||
|
||||
market = stock.get('market', '')
|
||||
category = stock.get('category', '未知')
|
||||
|
||||
if market == '上海':
|
||||
shanghai_count += 1
|
||||
elif market == '深圳':
|
||||
shenzhen_count += 1
|
||||
|
||||
category_stats[category] = category_stats.get(category, 0) + 1
|
||||
|
||||
return {
|
||||
'total_count': len([s for s in all_stocks if 'error' not in s]),
|
||||
'shanghai_count': shanghai_count,
|
||||
'shenzhen_count': shenzhen_count,
|
||||
'category_stats': category_stats,
|
||||
'data_source': all_stocks[0].get('source', 'unknown') if all_stocks else 'unknown',
|
||||
'updated_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
def check_service_status() -> Dict[str, Any]:
|
||||
"""
|
||||
检查服务状态
|
||||
|
||||
Returns:
|
||||
Dict: 服务状态信息
|
||||
|
||||
Example:
|
||||
>>> status = check_service_status()
|
||||
>>> print(f"MongoDB状态: {status['mongodb_status']}")
|
||||
"""
|
||||
if not SERVICE_AVAILABLE:
|
||||
return {
|
||||
'service_available': False,
|
||||
'error': '股票数据服务不可用',
|
||||
'suggestion': '请检查服务配置和依赖'
|
||||
}
|
||||
|
||||
service = get_stock_data_service()
|
||||
|
||||
# 检查MongoDB状态
|
||||
mongodb_status = 'disconnected'
|
||||
if service.db_manager and service.db_manager.mongodb_db:
|
||||
try:
|
||||
# 尝试执行一个简单的查询来测试连接
|
||||
service.db_manager.mongodb_db.list_collection_names()
|
||||
mongodb_status = 'connected'
|
||||
except Exception:
|
||||
mongodb_status = 'error'
|
||||
|
||||
# 检查通达信API状态
|
||||
tdx_status = 'unavailable'
|
||||
if service.tdx_provider:
|
||||
try:
|
||||
# 尝试获取一个股票名称来测试API
|
||||
test_name = service.tdx_provider._get_stock_name('000001')
|
||||
if test_name and test_name != '000001':
|
||||
tdx_status = 'available'
|
||||
else:
|
||||
tdx_status = 'limited'
|
||||
except Exception:
|
||||
tdx_status = 'error'
|
||||
|
||||
return {
|
||||
'service_available': True,
|
||||
'mongodb_status': mongodb_status,
|
||||
'tdx_api_status': tdx_status,
|
||||
'enhanced_fetcher_available': hasattr(service, '_get_from_tdx_api'),
|
||||
'fallback_available': True,
|
||||
'checked_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# 便捷的别名函数
|
||||
get_stock = get_stock_info # 别名
|
||||
get_stocks = get_all_stocks # 别名
|
||||
search = search_stocks # 别名
|
||||
status = check_service_status # 别名
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 简单的命令行测试
|
||||
print("🔍 股票数据API测试")
|
||||
print("=" * 50)
|
||||
|
||||
# 检查服务状态
|
||||
print("\n📊 服务状态检查:")
|
||||
status_info = check_service_status()
|
||||
for key, value in status_info.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
# 测试获取单个股票信息
|
||||
print("\n🏢 获取平安银行信息:")
|
||||
stock_info = get_stock_info('000001')
|
||||
if 'error' not in stock_info:
|
||||
print(f" 代码: {stock_info.get('code')}")
|
||||
print(f" 名称: {stock_info.get('name')}")
|
||||
print(f" 市场: {stock_info.get('market')}")
|
||||
print(f" 类别: {stock_info.get('category')}")
|
||||
print(f" 数据源: {stock_info.get('source')}")
|
||||
else:
|
||||
print(f" 错误: {stock_info.get('error')}")
|
||||
|
||||
# 测试搜索功能
|
||||
print("\n🔍 搜索'平安'相关股票:")
|
||||
search_results = search_stocks('平安')
|
||||
for i, stock in enumerate(search_results[:3]): # 只显示前3个结果
|
||||
if 'error' not in stock:
|
||||
print(f" {i+1}. {stock.get('code')}: {stock.get('name')}")
|
||||
|
||||
# 测试市场概览
|
||||
print("\n📈 市场概览:")
|
||||
summary = get_market_summary()
|
||||
if 'error' not in summary:
|
||||
print(f" 总股票数: {summary.get('total_count')}")
|
||||
print(f" 沪市股票: {summary.get('shanghai_count')}")
|
||||
print(f" 深市股票: {summary.get('shenzhen_count')}")
|
||||
print(f" 数据源: {summary.get('data_source')}")
|
||||
else:
|
||||
print(f" 错误: {summary.get('error')}")
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
"""
|
||||
配置管理模块
|
||||
"""
|
||||
|
||||
from .config_manager import config_manager, token_tracker, ModelConfig, PricingConfig, UsageRecord
|
||||
|
||||
__all__ = [
|
||||
'config_manager',
|
||||
'token_tracker',
|
||||
'ModelConfig',
|
||||
'PricingConfig',
|
||||
'UsageRecord'
|
||||
]
|
||||
|
|
@ -0,0 +1,574 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
配置管理器
|
||||
管理API密钥、模型配置、费率设置等
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
try:
|
||||
from .mongodb_storage import MongoDBStorage
|
||||
MONGODB_AVAILABLE = True
|
||||
except ImportError:
|
||||
MONGODB_AVAILABLE = False
|
||||
MongoDBStorage = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
"""模型配置"""
|
||||
provider: str # 供应商:dashscope, openai, google, etc.
|
||||
model_name: str # 模型名称
|
||||
api_key: str # API密钥
|
||||
base_url: Optional[str] = None # 自定义API地址
|
||||
max_tokens: int = 4000 # 最大token数
|
||||
temperature: float = 0.7 # 温度参数
|
||||
enabled: bool = True # 是否启用
|
||||
|
||||
|
||||
@dataclass
|
||||
class PricingConfig:
|
||||
"""定价配置"""
|
||||
provider: str # 供应商
|
||||
model_name: str # 模型名称
|
||||
input_price_per_1k: float # 输入token价格(每1000个token)
|
||||
output_price_per_1k: float # 输出token价格(每1000个token)
|
||||
currency: str = "CNY" # 货币单位
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageRecord:
|
||||
"""使用记录"""
|
||||
timestamp: str # 时间戳
|
||||
provider: str # 供应商
|
||||
model_name: str # 模型名称
|
||||
input_tokens: int # 输入token数
|
||||
output_tokens: int # 输出token数
|
||||
cost: float # 成本
|
||||
session_id: str # 会话ID
|
||||
analysis_type: str # 分析类型
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""配置管理器"""
|
||||
|
||||
def __init__(self, config_dir: str = "config"):
|
||||
self.config_dir = Path(config_dir)
|
||||
self.config_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.models_file = self.config_dir / "models.json"
|
||||
self.pricing_file = self.config_dir / "pricing.json"
|
||||
self.usage_file = self.config_dir / "usage.json"
|
||||
self.settings_file = self.config_dir / "settings.json"
|
||||
|
||||
# 加载.env文件(保持向后兼容)
|
||||
self._load_env_file()
|
||||
|
||||
# 初始化MongoDB存储(如果可用)
|
||||
self.mongodb_storage = None
|
||||
self._init_mongodb_storage()
|
||||
|
||||
self._init_default_configs()
|
||||
|
||||
def _load_env_file(self):
|
||||
"""加载.env文件(保持向后兼容)"""
|
||||
# 尝试从项目根目录加载.env文件
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
env_file = project_root / ".env"
|
||||
|
||||
if env_file.exists():
|
||||
load_dotenv(env_file, override=True)
|
||||
|
||||
def _get_env_api_key(self, provider: str) -> str:
|
||||
"""从环境变量获取API密钥"""
|
||||
env_key_map = {
|
||||
"dashscope": "DASHSCOPE_API_KEY",
|
||||
"openai": "OPENAI_API_KEY",
|
||||
"google": "GOOGLE_API_KEY",
|
||||
"anthropic": "ANTHROPIC_API_KEY",
|
||||
"deepseek": "DEEPSEEK_API_KEY"
|
||||
}
|
||||
|
||||
env_key = env_key_map.get(provider.lower())
|
||||
if env_key:
|
||||
return os.getenv(env_key, "")
|
||||
return ""
|
||||
|
||||
def _init_mongodb_storage(self):
|
||||
"""初始化MongoDB存储"""
|
||||
if not MONGODB_AVAILABLE:
|
||||
return
|
||||
|
||||
# 检查是否启用MongoDB存储
|
||||
use_mongodb = os.getenv("USE_MONGODB_STORAGE", "false").lower() == "true"
|
||||
if not use_mongodb:
|
||||
return
|
||||
|
||||
try:
|
||||
connection_string = os.getenv("MONGODB_CONNECTION_STRING")
|
||||
database_name = os.getenv("MONGODB_DATABASE_NAME", "tradingagents")
|
||||
|
||||
self.mongodb_storage = MongoDBStorage(
|
||||
connection_string=connection_string,
|
||||
database_name=database_name
|
||||
)
|
||||
|
||||
if self.mongodb_storage.is_connected():
|
||||
print("✅ MongoDB存储已启用")
|
||||
else:
|
||||
self.mongodb_storage = None
|
||||
print("⚠️ MongoDB连接失败,将使用JSON文件存储")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ MongoDB初始化失败: {e}")
|
||||
self.mongodb_storage = None
|
||||
|
||||
def _init_default_configs(self):
|
||||
"""初始化默认配置"""
|
||||
# 默认模型配置
|
||||
if not self.models_file.exists():
|
||||
default_models = [
|
||||
ModelConfig(
|
||||
provider="dashscope",
|
||||
model_name="qwen-turbo",
|
||||
api_key="",
|
||||
max_tokens=4000,
|
||||
temperature=0.7
|
||||
),
|
||||
ModelConfig(
|
||||
provider="dashscope",
|
||||
model_name="qwen-plus-latest",
|
||||
api_key="",
|
||||
max_tokens=8000,
|
||||
temperature=0.7
|
||||
),
|
||||
ModelConfig(
|
||||
provider="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
api_key="",
|
||||
max_tokens=4000,
|
||||
temperature=0.7,
|
||||
enabled=False
|
||||
),
|
||||
ModelConfig(
|
||||
provider="openai",
|
||||
model_name="gpt-4",
|
||||
api_key="",
|
||||
max_tokens=8000,
|
||||
temperature=0.7,
|
||||
enabled=False
|
||||
),
|
||||
ModelConfig(
|
||||
provider="google",
|
||||
model_name="gemini-pro",
|
||||
api_key="",
|
||||
max_tokens=4000,
|
||||
temperature=0.7,
|
||||
enabled=False
|
||||
)
|
||||
]
|
||||
self.save_models(default_models)
|
||||
|
||||
# 默认定价配置
|
||||
if not self.pricing_file.exists():
|
||||
default_pricing = [
|
||||
# 阿里百炼定价 (人民币)
|
||||
PricingConfig("dashscope", "qwen-turbo", 0.002, 0.006, "CNY"),
|
||||
PricingConfig("dashscope", "qwen-plus-latest", 0.004, 0.012, "CNY"),
|
||||
PricingConfig("dashscope", "qwen-max", 0.02, 0.06, "CNY"),
|
||||
|
||||
# OpenAI定价 (美元)
|
||||
PricingConfig("openai", "gpt-3.5-turbo", 0.0015, 0.002, "USD"),
|
||||
PricingConfig("openai", "gpt-4", 0.03, 0.06, "USD"),
|
||||
PricingConfig("openai", "gpt-4-turbo", 0.01, 0.03, "USD"),
|
||||
|
||||
# Google定价 (美元)
|
||||
PricingConfig("google", "gemini-pro", 0.00025, 0.0005, "USD"),
|
||||
PricingConfig("google", "gemini-pro-vision", 0.00025, 0.0005, "USD"),
|
||||
]
|
||||
self.save_pricing(default_pricing)
|
||||
|
||||
# 默认设置
|
||||
if not self.settings_file.exists():
|
||||
# 导入默认数据目录配置
|
||||
import os
|
||||
default_data_dir = os.path.join(os.path.expanduser("~"), "Documents", "TradingAgents", "data")
|
||||
|
||||
default_settings = {
|
||||
"default_provider": "dashscope",
|
||||
"default_model": "qwen-turbo",
|
||||
"enable_cost_tracking": True,
|
||||
"cost_alert_threshold": 100.0, # 成本警告阈值
|
||||
"currency_preference": "CNY",
|
||||
"auto_save_usage": True,
|
||||
"max_usage_records": 10000,
|
||||
"data_dir": default_data_dir, # 数据目录配置
|
||||
"cache_dir": os.path.join(default_data_dir, "cache"), # 缓存目录
|
||||
"results_dir": os.path.join(os.path.expanduser("~"), "Documents", "TradingAgents", "results"), # 结果目录
|
||||
"auto_create_dirs": True # 自动创建目录
|
||||
}
|
||||
self.save_settings(default_settings)
|
||||
|
||||
def load_models(self) -> List[ModelConfig]:
|
||||
"""加载模型配置,优先使用.env中的API密钥"""
|
||||
try:
|
||||
with open(self.models_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
models = [ModelConfig(**item) for item in data]
|
||||
|
||||
# 合并.env中的API密钥(优先级更高)
|
||||
for model in models:
|
||||
env_api_key = self._get_env_api_key(model.provider)
|
||||
if env_api_key:
|
||||
model.api_key = env_api_key
|
||||
# 如果.env中有API密钥,自动启用该模型
|
||||
if not model.enabled:
|
||||
model.enabled = True
|
||||
|
||||
return models
|
||||
except Exception as e:
|
||||
print(f"加载模型配置失败: {e}")
|
||||
return []
|
||||
|
||||
def save_models(self, models: List[ModelConfig]):
|
||||
"""保存模型配置"""
|
||||
try:
|
||||
data = [asdict(model) for model in models]
|
||||
with open(self.models_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"保存模型配置失败: {e}")
|
||||
|
||||
def load_pricing(self) -> List[PricingConfig]:
|
||||
"""加载定价配置"""
|
||||
try:
|
||||
with open(self.pricing_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return [PricingConfig(**item) for item in data]
|
||||
except Exception as e:
|
||||
print(f"加载定价配置失败: {e}")
|
||||
return []
|
||||
|
||||
def save_pricing(self, pricing: List[PricingConfig]):
|
||||
"""保存定价配置"""
|
||||
try:
|
||||
data = [asdict(price) for price in pricing]
|
||||
with open(self.pricing_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"保存定价配置失败: {e}")
|
||||
|
||||
def load_usage_records(self) -> List[UsageRecord]:
|
||||
"""加载使用记录"""
|
||||
try:
|
||||
if not self.usage_file.exists():
|
||||
return []
|
||||
with open(self.usage_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return [UsageRecord(**item) for item in data]
|
||||
except Exception as e:
|
||||
print(f"加载使用记录失败: {e}")
|
||||
return []
|
||||
|
||||
def save_usage_records(self, records: List[UsageRecord]):
|
||||
"""保存使用记录"""
|
||||
try:
|
||||
data = [asdict(record) for record in records]
|
||||
with open(self.usage_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"保存使用记录失败: {e}")
|
||||
|
||||
def add_usage_record(self, provider: str, model_name: str, input_tokens: int,
|
||||
output_tokens: int, session_id: str, analysis_type: str = "stock_analysis"):
|
||||
"""添加使用记录"""
|
||||
# 计算成本
|
||||
cost = self.calculate_cost(provider, model_name, input_tokens, output_tokens)
|
||||
|
||||
record = UsageRecord(
|
||||
timestamp=datetime.now().isoformat(),
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
session_id=session_id,
|
||||
analysis_type=analysis_type
|
||||
)
|
||||
|
||||
# 优先使用MongoDB存储
|
||||
if self.mongodb_storage and self.mongodb_storage.is_connected():
|
||||
success = self.mongodb_storage.save_usage_record(record)
|
||||
if success:
|
||||
return record
|
||||
else:
|
||||
print("⚠️ MongoDB保存失败,回退到JSON文件存储")
|
||||
|
||||
# 回退到JSON文件存储
|
||||
records = self.load_usage_records()
|
||||
records.append(record)
|
||||
|
||||
# 限制记录数量
|
||||
settings = self.load_settings()
|
||||
max_records = settings.get("max_usage_records", 10000)
|
||||
if len(records) > max_records:
|
||||
records = records[-max_records:]
|
||||
|
||||
self.save_usage_records(records)
|
||||
return record
|
||||
|
||||
def calculate_cost(self, provider: str, model_name: str, input_tokens: int, output_tokens: int) -> float:
|
||||
"""计算使用成本"""
|
||||
pricing_configs = self.load_pricing()
|
||||
|
||||
for pricing in pricing_configs:
|
||||
if pricing.provider == provider and pricing.model_name == model_name:
|
||||
input_cost = (input_tokens / 1000) * pricing.input_price_per_1k
|
||||
output_cost = (output_tokens / 1000) * pricing.output_price_per_1k
|
||||
return round(input_cost + output_cost, 6)
|
||||
|
||||
return 0.0
|
||||
|
||||
def load_settings(self) -> Dict[str, Any]:
|
||||
"""加载设置,合并.env中的配置"""
|
||||
try:
|
||||
with open(self.settings_file, 'r', encoding='utf-8') as f:
|
||||
settings = json.load(f)
|
||||
except Exception as e:
|
||||
print(f"加载设置失败: {e}")
|
||||
settings = {}
|
||||
|
||||
# 合并.env中的其他配置
|
||||
env_settings = {
|
||||
"finnhub_api_key": os.getenv("FINNHUB_API_KEY", ""),
|
||||
"reddit_client_id": os.getenv("REDDIT_CLIENT_ID", ""),
|
||||
"reddit_client_secret": os.getenv("REDDIT_CLIENT_SECRET", ""),
|
||||
"reddit_user_agent": os.getenv("REDDIT_USER_AGENT", ""),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", ""),
|
||||
"log_level": os.getenv("TRADINGAGENTS_LOG_LEVEL", "INFO"),
|
||||
"data_dir": os.getenv("TRADINGAGENTS_DATA_DIR", ""), # 数据目录环境变量
|
||||
"cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", ""), # 缓存目录环境变量
|
||||
}
|
||||
|
||||
# 只有当环境变量存在且不为空时才覆盖
|
||||
for key, value in env_settings.items():
|
||||
if value:
|
||||
settings[key] = value
|
||||
|
||||
return settings
|
||||
|
||||
def get_env_config_status(self) -> Dict[str, Any]:
|
||||
"""获取.env配置状态"""
|
||||
return {
|
||||
"env_file_exists": (Path(__file__).parent.parent.parent / ".env").exists(),
|
||||
"api_keys": {
|
||||
"dashscope": bool(os.getenv("DASHSCOPE_API_KEY")),
|
||||
"openai": bool(os.getenv("OPENAI_API_KEY")),
|
||||
"google": bool(os.getenv("GOOGLE_API_KEY")),
|
||||
"anthropic": bool(os.getenv("ANTHROPIC_API_KEY")),
|
||||
"finnhub": bool(os.getenv("FINNHUB_API_KEY")),
|
||||
},
|
||||
"other_configs": {
|
||||
"reddit_configured": bool(os.getenv("REDDIT_CLIENT_ID") and os.getenv("REDDIT_CLIENT_SECRET")),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||
"log_level": os.getenv("TRADINGAGENTS_LOG_LEVEL", "INFO"),
|
||||
}
|
||||
}
|
||||
|
||||
def save_settings(self, settings: Dict[str, Any]):
|
||||
"""保存设置"""
|
||||
try:
|
||||
with open(self.settings_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(settings, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"保存设置失败: {e}")
|
||||
|
||||
def get_enabled_models(self) -> List[ModelConfig]:
|
||||
"""获取启用的模型"""
|
||||
models = self.load_models()
|
||||
return [model for model in models if model.enabled and model.api_key]
|
||||
|
||||
def get_model_by_name(self, provider: str, model_name: str) -> Optional[ModelConfig]:
|
||||
"""根据名称获取模型配置"""
|
||||
models = self.load_models()
|
||||
for model in models:
|
||||
if model.provider == provider and model.model_name == model_name:
|
||||
return model
|
||||
return None
|
||||
|
||||
def get_usage_statistics(self, days: int = 30) -> Dict[str, Any]:
|
||||
"""获取使用统计"""
|
||||
# 优先使用MongoDB获取统计
|
||||
if self.mongodb_storage and self.mongodb_storage.is_connected():
|
||||
try:
|
||||
# 从MongoDB获取基础统计
|
||||
stats = self.mongodb_storage.get_usage_statistics(days)
|
||||
# 获取供应商统计
|
||||
provider_stats = self.mongodb_storage.get_provider_statistics(days)
|
||||
|
||||
if stats:
|
||||
stats["provider_stats"] = provider_stats
|
||||
stats["records_count"] = stats.get("total_requests", 0)
|
||||
return stats
|
||||
except Exception as e:
|
||||
print(f"⚠️ MongoDB统计获取失败,回退到JSON文件: {e}")
|
||||
|
||||
# 回退到JSON文件统计
|
||||
records = self.load_usage_records()
|
||||
|
||||
# 过滤最近N天的记录
|
||||
from datetime import datetime, timedelta
|
||||
cutoff_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
recent_records = []
|
||||
for record in records:
|
||||
try:
|
||||
record_date = datetime.fromisoformat(record.timestamp)
|
||||
if record_date >= cutoff_date:
|
||||
recent_records.append(record)
|
||||
except:
|
||||
continue
|
||||
|
||||
# 统计数据
|
||||
total_cost = sum(record.cost for record in recent_records)
|
||||
total_input_tokens = sum(record.input_tokens for record in recent_records)
|
||||
total_output_tokens = sum(record.output_tokens for record in recent_records)
|
||||
|
||||
# 按供应商统计
|
||||
provider_stats = {}
|
||||
for record in recent_records:
|
||||
if record.provider not in provider_stats:
|
||||
provider_stats[record.provider] = {
|
||||
"cost": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"requests": 0
|
||||
}
|
||||
provider_stats[record.provider]["cost"] += record.cost
|
||||
provider_stats[record.provider]["input_tokens"] += record.input_tokens
|
||||
provider_stats[record.provider]["output_tokens"] += record.output_tokens
|
||||
provider_stats[record.provider]["requests"] += 1
|
||||
|
||||
return {
|
||||
"period_days": days,
|
||||
"total_cost": round(total_cost, 4),
|
||||
"total_input_tokens": total_input_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"total_requests": len(recent_records),
|
||||
"provider_stats": provider_stats,
|
||||
"records_count": len(recent_records)
|
||||
}
|
||||
|
||||
def get_data_dir(self) -> str:
|
||||
"""获取数据目录路径"""
|
||||
settings = self.load_settings()
|
||||
data_dir = settings.get("data_dir")
|
||||
if not data_dir:
|
||||
# 如果没有配置,使用默认路径
|
||||
data_dir = os.path.join(os.path.expanduser("~"), "Documents", "TradingAgents", "data")
|
||||
return data_dir
|
||||
|
||||
def set_data_dir(self, data_dir: str):
|
||||
"""设置数据目录路径"""
|
||||
settings = self.load_settings()
|
||||
settings["data_dir"] = data_dir
|
||||
# 同时更新缓存目录
|
||||
settings["cache_dir"] = os.path.join(data_dir, "cache")
|
||||
self.save_settings(settings)
|
||||
|
||||
# 如果启用自动创建目录,则创建目录
|
||||
if settings.get("auto_create_dirs", True):
|
||||
self.ensure_directories_exist()
|
||||
|
||||
def ensure_directories_exist(self):
|
||||
"""确保必要的目录存在"""
|
||||
settings = self.load_settings()
|
||||
|
||||
directories = [
|
||||
settings.get("data_dir"),
|
||||
settings.get("cache_dir"),
|
||||
settings.get("results_dir"),
|
||||
os.path.join(settings.get("data_dir", ""), "finnhub_data"),
|
||||
os.path.join(settings.get("data_dir", ""), "finnhub_data", "news_data"),
|
||||
os.path.join(settings.get("data_dir", ""), "finnhub_data", "insider_sentiment"),
|
||||
os.path.join(settings.get("data_dir", ""), "finnhub_data", "insider_transactions")
|
||||
]
|
||||
|
||||
for directory in directories:
|
||||
if directory and not os.path.exists(directory):
|
||||
try:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
print(f"✅ 创建目录: {directory}")
|
||||
except Exception as e:
|
||||
print(f"❌ 创建目录失败 {directory}: {e}")
|
||||
|
||||
|
||||
class TokenTracker:
|
||||
"""Token使用跟踪器"""
|
||||
|
||||
def __init__(self, config_manager: ConfigManager):
|
||||
self.config_manager = config_manager
|
||||
|
||||
def track_usage(self, provider: str, model_name: str, input_tokens: int,
|
||||
output_tokens: int, session_id: str = None, analysis_type: str = "stock_analysis"):
|
||||
"""跟踪Token使用"""
|
||||
if session_id is None:
|
||||
session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
|
||||
# 检查是否启用成本跟踪
|
||||
settings = self.config_manager.load_settings()
|
||||
if not settings.get("enable_cost_tracking", True):
|
||||
return None
|
||||
|
||||
# 添加使用记录
|
||||
record = self.config_manager.add_usage_record(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
session_id=session_id,
|
||||
analysis_type=analysis_type
|
||||
)
|
||||
|
||||
# 检查成本警告
|
||||
self._check_cost_alert(record.cost)
|
||||
|
||||
return record
|
||||
|
||||
def _check_cost_alert(self, current_cost: float):
|
||||
"""检查成本警告"""
|
||||
settings = self.config_manager.load_settings()
|
||||
threshold = settings.get("cost_alert_threshold", 100.0)
|
||||
|
||||
# 获取今日总成本
|
||||
today_stats = self.config_manager.get_usage_statistics(1)
|
||||
total_today = today_stats["total_cost"]
|
||||
|
||||
if total_today >= threshold:
|
||||
print(f"⚠️ 成本警告: 今日成本已达到 ¥{total_today:.4f},超过阈值 ¥{threshold}")
|
||||
|
||||
def get_session_cost(self, session_id: str) -> float:
|
||||
"""获取会话成本"""
|
||||
records = self.config_manager.load_usage_records()
|
||||
session_cost = sum(record.cost for record in records if record.session_id == session_id)
|
||||
return session_cost
|
||||
|
||||
def estimate_cost(self, provider: str, model_name: str, estimated_input_tokens: int,
|
||||
estimated_output_tokens: int) -> float:
|
||||
"""估算成本"""
|
||||
return self.config_manager.calculate_cost(
|
||||
provider, model_name, estimated_input_tokens, estimated_output_tokens
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
# 全局配置管理器实例
|
||||
config_manager = ConfigManager()
|
||||
token_tracker = TokenTracker(config_manager)
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
数据库配置管理模块
|
||||
统一管理MongoDB和Redis的连接配置
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class DatabaseConfig:
|
||||
"""数据库配置管理类"""
|
||||
|
||||
@staticmethod
|
||||
def get_mongodb_config() -> Dict[str, Any]:
|
||||
"""
|
||||
获取MongoDB配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: MongoDB配置字典
|
||||
|
||||
Raises:
|
||||
ValueError: 当必要的配置未设置时
|
||||
"""
|
||||
connection_string = os.getenv('MONGODB_CONNECTION_STRING')
|
||||
if not connection_string:
|
||||
raise ValueError(
|
||||
"MongoDB连接字符串未配置。请设置环境变量 MONGODB_CONNECTION_STRING\n"
|
||||
"例如: MONGODB_CONNECTION_STRING=mongodb://localhost:27017/"
|
||||
)
|
||||
|
||||
return {
|
||||
'connection_string': connection_string,
|
||||
'database': os.getenv('MONGODB_DATABASE', 'tradingagents'),
|
||||
'auth_source': os.getenv('MONGODB_AUTH_SOURCE', 'admin')
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_redis_config() -> Dict[str, Any]:
|
||||
"""
|
||||
获取Redis配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Redis配置字典
|
||||
|
||||
Raises:
|
||||
ValueError: 当必要的配置未设置时
|
||||
"""
|
||||
# 优先使用连接字符串
|
||||
connection_string = os.getenv('REDIS_CONNECTION_STRING')
|
||||
if connection_string:
|
||||
return {
|
||||
'connection_string': connection_string,
|
||||
'database': int(os.getenv('REDIS_DATABASE', 0))
|
||||
}
|
||||
|
||||
# 使用分离的配置参数
|
||||
host = os.getenv('REDIS_HOST')
|
||||
port = os.getenv('REDIS_PORT')
|
||||
|
||||
if not host or not port:
|
||||
raise ValueError(
|
||||
"Redis连接配置未完整设置。请设置以下环境变量之一:\n"
|
||||
"1. REDIS_CONNECTION_STRING=redis://localhost:6379/0\n"
|
||||
"2. REDIS_HOST + REDIS_PORT (例如: REDIS_HOST=localhost, REDIS_PORT=6379)"
|
||||
)
|
||||
|
||||
return {
|
||||
'host': host,
|
||||
'port': int(port),
|
||||
'password': os.getenv('REDIS_PASSWORD'),
|
||||
'database': int(os.getenv('REDIS_DATABASE', 0))
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def validate_config() -> Dict[str, bool]:
|
||||
"""
|
||||
验证数据库配置是否完整
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: 验证结果
|
||||
"""
|
||||
result = {
|
||||
'mongodb_valid': False,
|
||||
'redis_valid': False
|
||||
}
|
||||
|
||||
try:
|
||||
DatabaseConfig.get_mongodb_config()
|
||||
result['mongodb_valid'] = True
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
DatabaseConfig.get_redis_config()
|
||||
result['redis_valid'] = True
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_config_status() -> str:
|
||||
"""
|
||||
获取配置状态的友好描述
|
||||
|
||||
Returns:
|
||||
str: 配置状态描述
|
||||
"""
|
||||
validation = DatabaseConfig.validate_config()
|
||||
|
||||
if validation['mongodb_valid'] and validation['redis_valid']:
|
||||
return "✅ 所有数据库配置正常"
|
||||
elif validation['mongodb_valid']:
|
||||
return "⚠️ MongoDB配置正常,Redis配置缺失"
|
||||
elif validation['redis_valid']:
|
||||
return "⚠️ Redis配置正常,MongoDB配置缺失"
|
||||
else:
|
||||
return "❌ 数据库配置缺失,请检查环境变量"
|
||||
|
|
@ -0,0 +1,360 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
智能数据库管理器
|
||||
自动检测MongoDB和Redis可用性,提供降级方案
|
||||
使用项目现有的.env配置
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
class DatabaseManager:
|
||||
"""智能数据库管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# 加载.env配置
|
||||
self._load_env_config()
|
||||
|
||||
# 数据库连接状态
|
||||
self.mongodb_available = False
|
||||
self.redis_available = False
|
||||
self.mongodb_client = None
|
||||
self.redis_client = None
|
||||
|
||||
# 检测数据库可用性
|
||||
self._detect_databases()
|
||||
|
||||
# 初始化连接
|
||||
self._initialize_connections()
|
||||
|
||||
self.logger.info(f"数据库管理器初始化完成 - MongoDB: {self.mongodb_available}, Redis: {self.redis_available}")
|
||||
|
||||
def _load_env_config(self):
|
||||
"""从.env文件加载配置"""
|
||||
# 尝试加载python-dotenv
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except ImportError:
|
||||
self.logger.info("python-dotenv未安装,直接读取环境变量")
|
||||
|
||||
# 读取启用开关
|
||||
self.mongodb_enabled = os.getenv("MONGODB_ENABLED", "false").lower() == "true"
|
||||
self.redis_enabled = os.getenv("REDIS_ENABLED", "false").lower() == "true"
|
||||
|
||||
# 从环境变量读取MongoDB配置
|
||||
self.mongodb_config = {
|
||||
"enabled": self.mongodb_enabled,
|
||||
"host": os.getenv("MONGODB_HOST", "localhost"),
|
||||
"port": int(os.getenv("MONGODB_PORT", "27017")),
|
||||
"username": os.getenv("MONGODB_USERNAME"),
|
||||
"password": os.getenv("MONGODB_PASSWORD"),
|
||||
"database": os.getenv("MONGODB_DATABASE", "tradingagents"),
|
||||
"auth_source": os.getenv("MONGODB_AUTH_SOURCE", "admin"),
|
||||
"timeout": 2000
|
||||
}
|
||||
|
||||
# 从环境变量读取Redis配置
|
||||
self.redis_config = {
|
||||
"enabled": self.redis_enabled,
|
||||
"host": os.getenv("REDIS_HOST", "localhost"),
|
||||
"port": int(os.getenv("REDIS_PORT", "6379")),
|
||||
"password": os.getenv("REDIS_PASSWORD"),
|
||||
"db": int(os.getenv("REDIS_DB", "0")),
|
||||
"timeout": 2
|
||||
}
|
||||
|
||||
self.logger.info(f"MongoDB启用: {self.mongodb_enabled}")
|
||||
self.logger.info(f"Redis启用: {self.redis_enabled}")
|
||||
if self.mongodb_enabled:
|
||||
self.logger.info(f"MongoDB配置: {self.mongodb_config['host']}:{self.mongodb_config['port']}")
|
||||
if self.redis_enabled:
|
||||
self.logger.info(f"Redis配置: {self.redis_config['host']}:{self.redis_config['port']}")
|
||||
|
||||
|
||||
|
||||
def _detect_mongodb(self) -> Tuple[bool, str]:
|
||||
"""检测MongoDB是否可用"""
|
||||
# 首先检查是否启用
|
||||
if not self.mongodb_enabled:
|
||||
return False, "MongoDB未启用 (MONGODB_ENABLED=false)"
|
||||
|
||||
try:
|
||||
import pymongo
|
||||
from pymongo import MongoClient
|
||||
|
||||
# 构建连接参数
|
||||
connect_kwargs = {
|
||||
"host": self.mongodb_config["host"],
|
||||
"port": self.mongodb_config["port"],
|
||||
"serverSelectionTimeoutMS": self.mongodb_config["timeout"],
|
||||
"connectTimeoutMS": self.mongodb_config["timeout"]
|
||||
}
|
||||
|
||||
# 如果有用户名和密码,添加认证
|
||||
if self.mongodb_config["username"] and self.mongodb_config["password"]:
|
||||
connect_kwargs.update({
|
||||
"username": self.mongodb_config["username"],
|
||||
"password": self.mongodb_config["password"],
|
||||
"authSource": self.mongodb_config["auth_source"]
|
||||
})
|
||||
|
||||
client = MongoClient(**connect_kwargs)
|
||||
|
||||
# 测试连接
|
||||
client.server_info()
|
||||
client.close()
|
||||
|
||||
return True, "MongoDB连接成功"
|
||||
|
||||
except ImportError:
|
||||
return False, "pymongo未安装"
|
||||
except Exception as e:
|
||||
return False, f"MongoDB连接失败: {str(e)}"
|
||||
|
||||
def _detect_redis(self) -> Tuple[bool, str]:
|
||||
"""检测Redis是否可用"""
|
||||
# 首先检查是否启用
|
||||
if not self.redis_enabled:
|
||||
return False, "Redis未启用 (REDIS_ENABLED=false)"
|
||||
|
||||
try:
|
||||
import redis
|
||||
|
||||
# 构建连接参数
|
||||
connect_kwargs = {
|
||||
"host": self.redis_config["host"],
|
||||
"port": self.redis_config["port"],
|
||||
"db": self.redis_config["db"],
|
||||
"socket_timeout": self.redis_config["timeout"],
|
||||
"socket_connect_timeout": self.redis_config["timeout"]
|
||||
}
|
||||
|
||||
# 如果有密码,添加密码
|
||||
if self.redis_config["password"]:
|
||||
connect_kwargs["password"] = self.redis_config["password"]
|
||||
|
||||
client = redis.Redis(**connect_kwargs)
|
||||
|
||||
# 测试连接
|
||||
client.ping()
|
||||
|
||||
return True, "Redis连接成功"
|
||||
|
||||
except ImportError:
|
||||
return False, "redis未安装"
|
||||
except Exception as e:
|
||||
return False, f"Redis连接失败: {str(e)}"
|
||||
|
||||
def _detect_databases(self):
|
||||
"""检测所有数据库"""
|
||||
self.logger.info("开始检测数据库可用性...")
|
||||
|
||||
# 检测MongoDB
|
||||
mongodb_available, mongodb_msg = self._detect_mongodb()
|
||||
self.mongodb_available = mongodb_available
|
||||
|
||||
if mongodb_available:
|
||||
self.logger.info(f"✅ MongoDB: {mongodb_msg}")
|
||||
else:
|
||||
self.logger.info(f"❌ MongoDB: {mongodb_msg}")
|
||||
|
||||
# 检测Redis
|
||||
redis_available, redis_msg = self._detect_redis()
|
||||
self.redis_available = redis_available
|
||||
|
||||
if redis_available:
|
||||
self.logger.info(f"✅ Redis: {redis_msg}")
|
||||
else:
|
||||
self.logger.info(f"❌ Redis: {redis_msg}")
|
||||
|
||||
# 更新配置
|
||||
self._update_config_based_on_detection()
|
||||
|
||||
def _update_config_based_on_detection(self):
|
||||
"""根据检测结果更新配置"""
|
||||
# 确定缓存后端
|
||||
if self.redis_available:
|
||||
self.primary_backend = "redis"
|
||||
elif self.mongodb_available:
|
||||
self.primary_backend = "mongodb"
|
||||
else:
|
||||
self.primary_backend = "file"
|
||||
|
||||
self.logger.info(f"主要缓存后端: {self.primary_backend}")
|
||||
|
||||
def _initialize_connections(self):
|
||||
"""初始化数据库连接"""
|
||||
# 初始化MongoDB连接
|
||||
if self.mongodb_available:
|
||||
try:
|
||||
import pymongo
|
||||
|
||||
# 构建连接参数
|
||||
connect_kwargs = {
|
||||
"host": self.mongodb_config["host"],
|
||||
"port": self.mongodb_config["port"],
|
||||
"serverSelectionTimeoutMS": self.mongodb_config["timeout"]
|
||||
}
|
||||
|
||||
# 如果有用户名和密码,添加认证
|
||||
if self.mongodb_config["username"] and self.mongodb_config["password"]:
|
||||
connect_kwargs.update({
|
||||
"username": self.mongodb_config["username"],
|
||||
"password": self.mongodb_config["password"],
|
||||
"authSource": self.mongodb_config["auth_source"]
|
||||
})
|
||||
|
||||
self.mongodb_client = pymongo.MongoClient(**connect_kwargs)
|
||||
self.logger.info("MongoDB客户端初始化成功")
|
||||
except Exception as e:
|
||||
self.logger.error(f"MongoDB客户端初始化失败: {e}")
|
||||
self.mongodb_available = False
|
||||
|
||||
# 初始化Redis连接
|
||||
if self.redis_available:
|
||||
try:
|
||||
import redis
|
||||
|
||||
# 构建连接参数
|
||||
connect_kwargs = {
|
||||
"host": self.redis_config["host"],
|
||||
"port": self.redis_config["port"],
|
||||
"db": self.redis_config["db"],
|
||||
"socket_timeout": self.redis_config["timeout"]
|
||||
}
|
||||
|
||||
# 如果有密码,添加密码
|
||||
if self.redis_config["password"]:
|
||||
connect_kwargs["password"] = self.redis_config["password"]
|
||||
|
||||
self.redis_client = redis.Redis(**connect_kwargs)
|
||||
self.logger.info("Redis客户端初始化成功")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Redis客户端初始化失败: {e}")
|
||||
self.redis_available = False
|
||||
|
||||
def get_mongodb_client(self):
|
||||
"""获取MongoDB客户端"""
|
||||
if self.mongodb_available and self.mongodb_client:
|
||||
return self.mongodb_client
|
||||
return None
|
||||
|
||||
def get_redis_client(self):
|
||||
"""获取Redis客户端"""
|
||||
if self.redis_available and self.redis_client:
|
||||
return self.redis_client
|
||||
return None
|
||||
|
||||
def is_mongodb_available(self) -> bool:
|
||||
"""检查MongoDB是否可用"""
|
||||
return self.mongodb_available
|
||||
|
||||
def is_redis_available(self) -> bool:
|
||||
"""检查Redis是否可用"""
|
||||
return self.redis_available
|
||||
|
||||
def is_database_available(self) -> bool:
|
||||
"""检查是否有任何数据库可用"""
|
||||
return self.mongodb_available or self.redis_available
|
||||
|
||||
def get_cache_backend(self) -> str:
|
||||
"""获取当前缓存后端"""
|
||||
return self.primary_backend
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""获取配置信息"""
|
||||
return {
|
||||
"mongodb": self.mongodb_config,
|
||||
"redis": self.redis_config,
|
||||
"primary_backend": self.primary_backend,
|
||||
"mongodb_available": self.mongodb_available,
|
||||
"redis_available": self.redis_available
|
||||
}
|
||||
|
||||
def get_status_report(self) -> Dict[str, Any]:
|
||||
"""获取状态报告"""
|
||||
return {
|
||||
"database_available": self.is_database_available(),
|
||||
"mongodb": {
|
||||
"available": self.mongodb_available,
|
||||
"host": self.mongodb_config["host"],
|
||||
"port": self.mongodb_config["port"]
|
||||
},
|
||||
"redis": {
|
||||
"available": self.redis_available,
|
||||
"host": self.redis_config["host"],
|
||||
"port": self.redis_config["port"]
|
||||
},
|
||||
"cache_backend": self.get_cache_backend(),
|
||||
"fallback_enabled": True # 总是启用降级
|
||||
}
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
stats = {
|
||||
"mongodb_available": self.mongodb_available,
|
||||
"redis_available": self.redis_available,
|
||||
"redis_keys": 0,
|
||||
"redis_memory": "N/A"
|
||||
}
|
||||
|
||||
# Redis统计
|
||||
if self.redis_available and self.redis_client:
|
||||
try:
|
||||
info = self.redis_client.info()
|
||||
stats["redis_keys"] = self.redis_client.dbsize()
|
||||
stats["redis_memory"] = info.get("used_memory_human", "N/A")
|
||||
except Exception as e:
|
||||
self.logger.error(f"获取Redis统计失败: {e}")
|
||||
|
||||
return stats
|
||||
|
||||
def cache_clear_pattern(self, pattern: str) -> int:
|
||||
"""清理匹配模式的缓存"""
|
||||
cleared_count = 0
|
||||
|
||||
if self.redis_available and self.redis_client:
|
||||
try:
|
||||
keys = self.redis_client.keys(pattern)
|
||||
if keys:
|
||||
cleared_count += self.redis_client.delete(*keys)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Redis缓存清理失败: {e}")
|
||||
|
||||
return cleared_count
|
||||
|
||||
|
||||
# 全局数据库管理器实例
|
||||
_database_manager = None
|
||||
|
||||
def get_database_manager() -> DatabaseManager:
|
||||
"""获取全局数据库管理器实例"""
|
||||
global _database_manager
|
||||
if _database_manager is None:
|
||||
_database_manager = DatabaseManager()
|
||||
return _database_manager
|
||||
|
||||
def is_mongodb_available() -> bool:
|
||||
"""检查MongoDB是否可用"""
|
||||
return get_database_manager().is_mongodb_available()
|
||||
|
||||
def is_redis_available() -> bool:
|
||||
"""检查Redis是否可用"""
|
||||
return get_database_manager().is_redis_available()
|
||||
|
||||
def get_cache_backend() -> str:
|
||||
"""获取当前缓存后端"""
|
||||
return get_database_manager().get_cache_backend()
|
||||
|
||||
def get_mongodb_client():
|
||||
"""获取MongoDB客户端"""
|
||||
return get_database_manager().get_mongodb_client()
|
||||
|
||||
def get_redis_client():
|
||||
"""获取Redis客户端"""
|
||||
return get_database_manager().get_redis_client()
|
||||
|
|
@ -0,0 +1,285 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
MongoDB存储适配器
|
||||
用于将token使用记录存储到MongoDB数据库
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import asdict
|
||||
from .config_manager import UsageRecord
|
||||
|
||||
try:
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
|
||||
MONGODB_AVAILABLE = True
|
||||
except ImportError:
|
||||
MONGODB_AVAILABLE = False
|
||||
MongoClient = None
|
||||
|
||||
|
||||
class MongoDBStorage:
|
||||
"""MongoDB存储适配器"""
|
||||
|
||||
def __init__(self, connection_string: str = None, database_name: str = "tradingagents"):
|
||||
if not MONGODB_AVAILABLE:
|
||||
raise ImportError("pymongo is not installed. Please install it with: pip install pymongo")
|
||||
|
||||
# 修复硬编码问题 - 如果没有提供连接字符串且环境变量也未设置,则抛出错误
|
||||
self.connection_string = connection_string or os.getenv("MONGODB_CONNECTION_STRING")
|
||||
if not self.connection_string:
|
||||
raise ValueError(
|
||||
"MongoDB连接字符串未配置。请通过以下方式之一进行配置:\n"
|
||||
"1. 设置环境变量 MONGODB_CONNECTION_STRING\n"
|
||||
"2. 在初始化时传入 connection_string 参数\n"
|
||||
"例如: MONGODB_CONNECTION_STRING=mongodb://localhost:27017/"
|
||||
)
|
||||
|
||||
self.database_name = database_name
|
||||
self.collection_name = "token_usage"
|
||||
|
||||
self.client = None
|
||||
self.db = None
|
||||
self.collection = None
|
||||
self._connected = False
|
||||
|
||||
# 尝试连接
|
||||
self._connect()
|
||||
|
||||
def _connect(self):
|
||||
"""连接到MongoDB"""
|
||||
try:
|
||||
self.client = MongoClient(
|
||||
self.connection_string,
|
||||
serverSelectionTimeoutMS=5000 # 5秒超时
|
||||
)
|
||||
# 测试连接
|
||||
self.client.admin.command('ping')
|
||||
|
||||
self.db = self.client[self.database_name]
|
||||
self.collection = self.db[self.collection_name]
|
||||
|
||||
# 创建索引以提高查询性能
|
||||
self._create_indexes()
|
||||
|
||||
self._connected = True
|
||||
print(f"✅ MongoDB连接成功: {self.database_name}.{self.collection_name}")
|
||||
|
||||
except (ConnectionFailure, ServerSelectionTimeoutError) as e:
|
||||
print(f"❌ MongoDB连接失败: {e}")
|
||||
print("将使用本地JSON文件存储")
|
||||
self._connected = False
|
||||
except Exception as e:
|
||||
print(f"❌ MongoDB初始化失败: {e}")
|
||||
self._connected = False
|
||||
|
||||
def _create_indexes(self):
|
||||
"""创建数据库索引"""
|
||||
try:
|
||||
# 创建复合索引
|
||||
self.collection.create_index([
|
||||
("timestamp", -1), # 按时间倒序
|
||||
("provider", 1),
|
||||
("model_name", 1)
|
||||
])
|
||||
|
||||
# 创建会话ID索引
|
||||
self.collection.create_index("session_id")
|
||||
|
||||
# 创建分析类型索引
|
||||
self.collection.create_index("analysis_type")
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建MongoDB索引失败: {e}")
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""检查是否连接到MongoDB"""
|
||||
return self._connected
|
||||
|
||||
def save_usage_record(self, record: UsageRecord) -> bool:
|
||||
"""保存单个使用记录到MongoDB"""
|
||||
if not self._connected:
|
||||
return False
|
||||
|
||||
try:
|
||||
# 转换为字典格式
|
||||
record_dict = asdict(record)
|
||||
|
||||
# 添加MongoDB特有的字段
|
||||
record_dict['_created_at'] = datetime.now()
|
||||
|
||||
# 插入记录
|
||||
result = self.collection.insert_one(record_dict)
|
||||
|
||||
if result.inserted_id:
|
||||
return True
|
||||
else:
|
||||
print("MongoDB插入失败:未返回插入ID")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"保存记录到MongoDB失败: {e}")
|
||||
return False
|
||||
|
||||
def load_usage_records(self, limit: int = 10000, days: int = None) -> List[UsageRecord]:
|
||||
"""从MongoDB加载使用记录"""
|
||||
if not self._connected:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 构建查询条件
|
||||
query = {}
|
||||
if days:
|
||||
from datetime import timedelta
|
||||
cutoff_date = datetime.now() - timedelta(days=days)
|
||||
query['timestamp'] = {'$gte': cutoff_date.isoformat()}
|
||||
|
||||
# 查询记录,按时间倒序
|
||||
cursor = self.collection.find(query).sort('timestamp', -1).limit(limit)
|
||||
|
||||
records = []
|
||||
for doc in cursor:
|
||||
# 移除MongoDB特有的字段
|
||||
doc.pop('_id', None)
|
||||
doc.pop('_created_at', None)
|
||||
|
||||
# 转换为UsageRecord对象
|
||||
try:
|
||||
record = UsageRecord(**doc)
|
||||
records.append(record)
|
||||
except Exception as e:
|
||||
print(f"解析记录失败: {e}, 记录: {doc}")
|
||||
continue
|
||||
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
print(f"从MongoDB加载记录失败: {e}")
|
||||
return []
|
||||
|
||||
def get_usage_statistics(self, days: int = 30) -> Dict[str, Any]:
|
||||
"""从MongoDB获取使用统计"""
|
||||
if not self._connected:
|
||||
return {}
|
||||
|
||||
try:
|
||||
from datetime import timedelta
|
||||
cutoff_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
# 聚合查询
|
||||
pipeline = [
|
||||
{
|
||||
'$match': {
|
||||
'timestamp': {'$gte': cutoff_date.isoformat()}
|
||||
}
|
||||
},
|
||||
{
|
||||
'$group': {
|
||||
'_id': None,
|
||||
'total_cost': {'$sum': '$cost'},
|
||||
'total_input_tokens': {'$sum': '$input_tokens'},
|
||||
'total_output_tokens': {'$sum': '$output_tokens'},
|
||||
'total_requests': {'$sum': 1}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
result = list(self.collection.aggregate(pipeline))
|
||||
|
||||
if result:
|
||||
stats = result[0]
|
||||
return {
|
||||
'period_days': days,
|
||||
'total_cost': round(stats.get('total_cost', 0), 4),
|
||||
'total_input_tokens': stats.get('total_input_tokens', 0),
|
||||
'total_output_tokens': stats.get('total_output_tokens', 0),
|
||||
'total_requests': stats.get('total_requests', 0)
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'period_days': days,
|
||||
'total_cost': 0,
|
||||
'total_input_tokens': 0,
|
||||
'total_output_tokens': 0,
|
||||
'total_requests': 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取MongoDB统计失败: {e}")
|
||||
return {}
|
||||
|
||||
def get_provider_statistics(self, days: int = 30) -> Dict[str, Dict[str, Any]]:
|
||||
"""按供应商获取统计信息"""
|
||||
if not self._connected:
|
||||
return {}
|
||||
|
||||
try:
|
||||
from datetime import timedelta
|
||||
cutoff_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
# 按供应商聚合
|
||||
pipeline = [
|
||||
{
|
||||
'$match': {
|
||||
'timestamp': {'$gte': cutoff_date.isoformat()}
|
||||
}
|
||||
},
|
||||
{
|
||||
'$group': {
|
||||
'_id': '$provider',
|
||||
'cost': {'$sum': '$cost'},
|
||||
'input_tokens': {'$sum': '$input_tokens'},
|
||||
'output_tokens': {'$sum': '$output_tokens'},
|
||||
'requests': {'$sum': 1}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
results = list(self.collection.aggregate(pipeline))
|
||||
|
||||
provider_stats = {}
|
||||
for result in results:
|
||||
provider = result['_id']
|
||||
provider_stats[provider] = {
|
||||
'cost': round(result.get('cost', 0), 4),
|
||||
'input_tokens': result.get('input_tokens', 0),
|
||||
'output_tokens': result.get('output_tokens', 0),
|
||||
'requests': result.get('requests', 0)
|
||||
}
|
||||
|
||||
return provider_stats
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取供应商统计失败: {e}")
|
||||
return {}
|
||||
|
||||
def cleanup_old_records(self, days: int = 90) -> int:
|
||||
"""清理旧记录"""
|
||||
if not self._connected:
|
||||
return 0
|
||||
|
||||
try:
|
||||
from datetime import timedelta
|
||||
cutoff_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
result = self.collection.delete_many({
|
||||
'timestamp': {'$lt': cutoff_date.isoformat()}
|
||||
})
|
||||
|
||||
deleted_count = result.deleted_count
|
||||
if deleted_count > 0:
|
||||
print(f"清理了 {deleted_count} 条超过 {days} 天的记录")
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
print(f"清理旧记录失败: {e}")
|
||||
return 0
|
||||
|
||||
def close(self):
|
||||
"""关闭MongoDB连接"""
|
||||
if self.client:
|
||||
self.client.close()
|
||||
self._connected = False
|
||||
print("MongoDB连接已关闭")
|
||||
|
|
@ -0,0 +1,383 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
自适应缓存系统
|
||||
根据数据库可用性自动选择最佳缓存策略
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import pandas as pd
|
||||
|
||||
from ..config.database_manager import get_database_manager
|
||||
|
||||
class AdaptiveCacheSystem:
|
||||
"""自适应缓存系统"""
|
||||
|
||||
def __init__(self, cache_dir: str = "data/cache"):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# 获取数据库管理器
|
||||
self.db_manager = get_database_manager()
|
||||
|
||||
# 设置缓存目录
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 获取配置
|
||||
self.config = self.db_manager.get_config()
|
||||
self.cache_config = self.config["cache"]
|
||||
|
||||
# 初始化缓存后端
|
||||
self.primary_backend = self.cache_config["primary_backend"]
|
||||
self.fallback_enabled = self.cache_config["fallback_enabled"]
|
||||
|
||||
self.logger.info(f"自适应缓存系统初始化 - 主要后端: {self.primary_backend}")
|
||||
|
||||
def _get_cache_key(self, symbol: str, start_date: str = "", end_date: str = "",
|
||||
data_source: str = "default", data_type: str = "stock_data") -> str:
|
||||
"""生成缓存键"""
|
||||
key_data = f"{symbol}_{start_date}_{end_date}_{data_source}_{data_type}"
|
||||
return hashlib.md5(key_data.encode()).hexdigest()
|
||||
|
||||
def _get_ttl_seconds(self, symbol: str, data_type: str = "stock_data") -> int:
|
||||
"""获取TTL秒数"""
|
||||
# 判断市场类型
|
||||
if len(symbol) == 6 and symbol.isdigit():
|
||||
market = "china"
|
||||
else:
|
||||
market = "us"
|
||||
|
||||
# 获取TTL配置
|
||||
ttl_key = f"{market}_{data_type}"
|
||||
ttl_seconds = self.cache_config["ttl_settings"].get(ttl_key, 7200)
|
||||
return ttl_seconds
|
||||
|
||||
def _is_cache_valid(self, cache_time: datetime, ttl_seconds: int) -> bool:
|
||||
"""检查缓存是否有效"""
|
||||
if cache_time is None:
|
||||
return False
|
||||
|
||||
expiry_time = cache_time + timedelta(seconds=ttl_seconds)
|
||||
return datetime.now() < expiry_time
|
||||
|
||||
def _save_to_file(self, cache_key: str, data: Any, metadata: Dict) -> bool:
|
||||
"""保存到文件缓存"""
|
||||
try:
|
||||
cache_file = self.cache_dir / f"{cache_key}.pkl"
|
||||
cache_data = {
|
||||
'data': data,
|
||||
'metadata': metadata,
|
||||
'timestamp': datetime.now(),
|
||||
'backend': 'file'
|
||||
}
|
||||
|
||||
with open(cache_file, 'wb') as f:
|
||||
pickle.dump(cache_data, f)
|
||||
|
||||
self.logger.debug(f"文件缓存保存成功: {cache_key}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"文件缓存保存失败: {e}")
|
||||
return False
|
||||
|
||||
def _load_from_file(self, cache_key: str) -> Optional[Dict]:
|
||||
"""从文件缓存加载"""
|
||||
try:
|
||||
cache_file = self.cache_dir / f"{cache_key}.pkl"
|
||||
if not cache_file.exists():
|
||||
return None
|
||||
|
||||
with open(cache_file, 'rb') as f:
|
||||
cache_data = pickle.load(f)
|
||||
|
||||
self.logger.debug(f"文件缓存加载成功: {cache_key}")
|
||||
return cache_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"文件缓存加载失败: {e}")
|
||||
return None
|
||||
|
||||
def _save_to_redis(self, cache_key: str, data: Any, metadata: Dict, ttl_seconds: int) -> bool:
|
||||
"""保存到Redis缓存"""
|
||||
redis_client = self.db_manager.get_redis_client()
|
||||
if not redis_client:
|
||||
return False
|
||||
|
||||
try:
|
||||
cache_data = {
|
||||
'data': data,
|
||||
'metadata': metadata,
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'backend': 'redis'
|
||||
}
|
||||
|
||||
serialized_data = pickle.dumps(cache_data)
|
||||
redis_client.setex(cache_key, ttl_seconds, serialized_data)
|
||||
|
||||
self.logger.debug(f"Redis缓存保存成功: {cache_key}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Redis缓存保存失败: {e}")
|
||||
return False
|
||||
|
||||
def _load_from_redis(self, cache_key: str) -> Optional[Dict]:
|
||||
"""从Redis缓存加载"""
|
||||
redis_client = self.db_manager.get_redis_client()
|
||||
if not redis_client:
|
||||
return None
|
||||
|
||||
try:
|
||||
serialized_data = redis_client.get(cache_key)
|
||||
if not serialized_data:
|
||||
return None
|
||||
|
||||
cache_data = pickle.loads(serialized_data)
|
||||
|
||||
# 转换时间戳
|
||||
if isinstance(cache_data['timestamp'], str):
|
||||
cache_data['timestamp'] = datetime.fromisoformat(cache_data['timestamp'])
|
||||
|
||||
self.logger.debug(f"Redis缓存加载成功: {cache_key}")
|
||||
return cache_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Redis缓存加载失败: {e}")
|
||||
return None
|
||||
|
||||
def _save_to_mongodb(self, cache_key: str, data: Any, metadata: Dict, ttl_seconds: int) -> bool:
|
||||
"""保存到MongoDB缓存"""
|
||||
mongodb_client = self.db_manager.get_mongodb_client()
|
||||
if not mongodb_client:
|
||||
return False
|
||||
|
||||
try:
|
||||
db = mongodb_client.tradingagents
|
||||
collection = db.cache
|
||||
|
||||
# 序列化数据
|
||||
if isinstance(data, pd.DataFrame):
|
||||
serialized_data = data.to_json()
|
||||
data_type = 'dataframe'
|
||||
else:
|
||||
serialized_data = pickle.dumps(data).hex()
|
||||
data_type = 'pickle'
|
||||
|
||||
cache_doc = {
|
||||
'_id': cache_key,
|
||||
'data': serialized_data,
|
||||
'data_type': data_type,
|
||||
'metadata': metadata,
|
||||
'timestamp': datetime.now(),
|
||||
'expires_at': datetime.now() + timedelta(seconds=ttl_seconds),
|
||||
'backend': 'mongodb'
|
||||
}
|
||||
|
||||
collection.replace_one({'_id': cache_key}, cache_doc, upsert=True)
|
||||
|
||||
self.logger.debug(f"MongoDB缓存保存成功: {cache_key}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"MongoDB缓存保存失败: {e}")
|
||||
return False
|
||||
|
||||
def _load_from_mongodb(self, cache_key: str) -> Optional[Dict]:
|
||||
"""从MongoDB缓存加载"""
|
||||
mongodb_client = self.db_manager.get_mongodb_client()
|
||||
if not mongodb_client:
|
||||
return None
|
||||
|
||||
try:
|
||||
db = mongodb_client.tradingagents
|
||||
collection = db.cache
|
||||
|
||||
doc = collection.find_one({'_id': cache_key})
|
||||
if not doc:
|
||||
return None
|
||||
|
||||
# 检查是否过期
|
||||
if doc.get('expires_at') and doc['expires_at'] < datetime.now():
|
||||
collection.delete_one({'_id': cache_key})
|
||||
return None
|
||||
|
||||
# 反序列化数据
|
||||
if doc['data_type'] == 'dataframe':
|
||||
data = pd.read_json(doc['data'])
|
||||
else:
|
||||
data = pickle.loads(bytes.fromhex(doc['data']))
|
||||
|
||||
cache_data = {
|
||||
'data': data,
|
||||
'metadata': doc['metadata'],
|
||||
'timestamp': doc['timestamp'],
|
||||
'backend': 'mongodb'
|
||||
}
|
||||
|
||||
self.logger.debug(f"MongoDB缓存加载成功: {cache_key}")
|
||||
return cache_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"MongoDB缓存加载失败: {e}")
|
||||
return None
|
||||
|
||||
def save_data(self, symbol: str, data: Any, start_date: str = "", end_date: str = "",
|
||||
data_source: str = "default", data_type: str = "stock_data") -> str:
|
||||
"""保存数据到缓存"""
|
||||
# 生成缓存键
|
||||
cache_key = self._get_cache_key(symbol, start_date, end_date, data_source, data_type)
|
||||
|
||||
# 准备元数据
|
||||
metadata = {
|
||||
'symbol': symbol,
|
||||
'start_date': start_date,
|
||||
'end_date': end_date,
|
||||
'data_source': data_source,
|
||||
'data_type': data_type
|
||||
}
|
||||
|
||||
# 获取TTL
|
||||
ttl_seconds = self._get_ttl_seconds(symbol, data_type)
|
||||
|
||||
# 根据主要后端保存
|
||||
success = False
|
||||
|
||||
if self.primary_backend == "redis":
|
||||
success = self._save_to_redis(cache_key, data, metadata, ttl_seconds)
|
||||
elif self.primary_backend == "mongodb":
|
||||
success = self._save_to_mongodb(cache_key, data, metadata, ttl_seconds)
|
||||
elif self.primary_backend == "file":
|
||||
success = self._save_to_file(cache_key, data, metadata)
|
||||
|
||||
# 如果主要后端失败,使用降级策略
|
||||
if not success and self.fallback_enabled:
|
||||
self.logger.warning(f"主要后端({self.primary_backend})保存失败,使用文件缓存降级")
|
||||
success = self._save_to_file(cache_key, data, metadata)
|
||||
|
||||
if success:
|
||||
self.logger.info(f"数据缓存成功: {symbol} -> {cache_key} (后端: {self.primary_backend})")
|
||||
else:
|
||||
self.logger.error(f"数据缓存失败: {symbol}")
|
||||
|
||||
return cache_key
|
||||
|
||||
def load_data(self, cache_key: str) -> Optional[Any]:
|
||||
"""从缓存加载数据"""
|
||||
cache_data = None
|
||||
|
||||
# 根据主要后端加载
|
||||
if self.primary_backend == "redis":
|
||||
cache_data = self._load_from_redis(cache_key)
|
||||
elif self.primary_backend == "mongodb":
|
||||
cache_data = self._load_from_mongodb(cache_key)
|
||||
elif self.primary_backend == "file":
|
||||
cache_data = self._load_from_file(cache_key)
|
||||
|
||||
# 如果主要后端失败,尝试降级
|
||||
if not cache_data and self.fallback_enabled:
|
||||
self.logger.debug(f"主要后端({self.primary_backend})加载失败,尝试文件缓存")
|
||||
cache_data = self._load_from_file(cache_key)
|
||||
|
||||
if not cache_data:
|
||||
return None
|
||||
|
||||
# 检查缓存是否有效(仅对文件缓存,数据库缓存有自己的TTL机制)
|
||||
if cache_data.get('backend') == 'file':
|
||||
symbol = cache_data['metadata'].get('symbol', '')
|
||||
data_type = cache_data['metadata'].get('data_type', 'stock_data')
|
||||
ttl_seconds = self._get_ttl_seconds(symbol, data_type)
|
||||
|
||||
if not self._is_cache_valid(cache_data['timestamp'], ttl_seconds):
|
||||
self.logger.debug(f"文件缓存已过期: {cache_key}")
|
||||
return None
|
||||
|
||||
return cache_data['data']
|
||||
|
||||
def find_cached_data(self, symbol: str, start_date: str = "", end_date: str = "",
|
||||
data_source: str = "default", data_type: str = "stock_data") -> Optional[str]:
|
||||
"""查找缓存的数据"""
|
||||
cache_key = self._get_cache_key(symbol, start_date, end_date, data_source, data_type)
|
||||
|
||||
# 检查缓存是否存在且有效
|
||||
if self.load_data(cache_key) is not None:
|
||||
return cache_key
|
||||
|
||||
return None
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
stats = {
|
||||
'primary_backend': self.primary_backend,
|
||||
'fallback_enabled': self.fallback_enabled,
|
||||
'database_available': self.db_manager.is_database_available(),
|
||||
'mongodb_available': self.db_manager.is_mongodb_available(),
|
||||
'redis_available': self.db_manager.is_redis_available(),
|
||||
'file_cache_directory': str(self.cache_dir),
|
||||
'file_cache_count': len(list(self.cache_dir.glob("*.pkl"))),
|
||||
}
|
||||
|
||||
# Redis统计
|
||||
redis_client = self.db_manager.get_redis_client()
|
||||
if redis_client:
|
||||
try:
|
||||
redis_info = redis_client.info()
|
||||
stats['redis_memory_used'] = redis_info.get('used_memory_human', 'N/A')
|
||||
stats['redis_keys'] = redis_client.dbsize()
|
||||
except:
|
||||
stats['redis_status'] = 'Error'
|
||||
|
||||
# MongoDB统计
|
||||
mongodb_client = self.db_manager.get_mongodb_client()
|
||||
if mongodb_client:
|
||||
try:
|
||||
db = mongodb_client.tradingagents
|
||||
stats['mongodb_cache_count'] = db.cache.count_documents({})
|
||||
except:
|
||||
stats['mongodb_status'] = 'Error'
|
||||
|
||||
return stats
|
||||
|
||||
def clear_expired_cache(self):
|
||||
"""清理过期缓存"""
|
||||
self.logger.info("开始清理过期缓存...")
|
||||
|
||||
# 清理文件缓存
|
||||
cleared_files = 0
|
||||
for cache_file in self.cache_dir.glob("*.pkl"):
|
||||
try:
|
||||
with open(cache_file, 'rb') as f:
|
||||
cache_data = pickle.load(f)
|
||||
|
||||
symbol = cache_data['metadata'].get('symbol', '')
|
||||
data_type = cache_data['metadata'].get('data_type', 'stock_data')
|
||||
ttl_seconds = self._get_ttl_seconds(symbol, data_type)
|
||||
|
||||
if not self._is_cache_valid(cache_data['timestamp'], ttl_seconds):
|
||||
cache_file.unlink()
|
||||
cleared_files += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"清理缓存文件失败 {cache_file}: {e}")
|
||||
|
||||
self.logger.info(f"文件缓存清理完成,删除 {cleared_files} 个过期文件")
|
||||
|
||||
# MongoDB会自动清理过期文档(通过expires_at字段)
|
||||
# Redis会自动清理过期键
|
||||
|
||||
|
||||
# 全局缓存系统实例
|
||||
_cache_system = None
|
||||
|
||||
def get_cache_system() -> AdaptiveCacheSystem:
|
||||
"""获取全局自适应缓存系统实例"""
|
||||
global _cache_system
|
||||
if _cache_system is None:
|
||||
_cache_system = AdaptiveCacheSystem()
|
||||
return _cache_system
|
||||
|
|
@ -0,0 +1,498 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Stock Data Cache Manager
|
||||
Supports local caching of stock data to reduce API calls and improve response speed
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, Union
|
||||
import hashlib
|
||||
|
||||
|
||||
class StockDataCache:
|
||||
"""Stock Data Cache Manager - Supports optimized caching for US and Chinese stock data"""
|
||||
|
||||
def __init__(self, cache_dir: str = None):
|
||||
"""
|
||||
Initialize cache manager
|
||||
|
||||
Args:
|
||||
cache_dir: Cache directory path, defaults to tradingagents/dataflows/data_cache
|
||||
"""
|
||||
if cache_dir is None:
|
||||
# Get current file directory
|
||||
current_dir = Path(__file__).parent
|
||||
cache_dir = current_dir / "data_cache"
|
||||
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Create subdirectories - categorized by market
|
||||
self.us_stock_dir = self.cache_dir / "us_stocks"
|
||||
self.china_stock_dir = self.cache_dir / "china_stocks"
|
||||
self.us_news_dir = self.cache_dir / "us_news"
|
||||
self.china_news_dir = self.cache_dir / "china_news"
|
||||
self.us_fundamentals_dir = self.cache_dir / "us_fundamentals"
|
||||
self.china_fundamentals_dir = self.cache_dir / "china_fundamentals"
|
||||
self.metadata_dir = self.cache_dir / "metadata"
|
||||
|
||||
# Create all directories
|
||||
for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
|
||||
self.china_news_dir, self.us_fundamentals_dir,
|
||||
self.china_fundamentals_dir, self.metadata_dir]:
|
||||
dir_path.mkdir(exist_ok=True)
|
||||
|
||||
# Cache configuration - different TTL settings for different markets
|
||||
self.cache_config = {
|
||||
'us_stock_data': {
|
||||
'ttl_hours': 2, # US stock data cached for 2 hours (considering API limits)
|
||||
'max_files': 1000,
|
||||
'description': 'US stock historical data'
|
||||
},
|
||||
'china_stock_data': {
|
||||
'ttl_hours': 1, # A-share data cached for 1 hour (high real-time requirement)
|
||||
'max_files': 1000,
|
||||
'description': 'A-share historical data'
|
||||
},
|
||||
'us_news': {
|
||||
'ttl_hours': 6, # US stock news cached for 6 hours
|
||||
'max_files': 500,
|
||||
'description': 'US stock news data'
|
||||
},
|
||||
'china_news': {
|
||||
'ttl_hours': 4, # A-share news cached for 4 hours
|
||||
'max_files': 500,
|
||||
'description': 'A-share news data'
|
||||
},
|
||||
'us_fundamentals': {
|
||||
'ttl_hours': 24, # US stock fundamentals cached for 24 hours
|
||||
'max_files': 200,
|
||||
'description': 'US stock fundamentals data'
|
||||
},
|
||||
'china_fundamentals': {
|
||||
'ttl_hours': 12, # A-share fundamentals cached for 12 hours
|
||||
'max_files': 200,
|
||||
'description': 'A-share fundamentals data'
|
||||
}
|
||||
}
|
||||
|
||||
print(f"📁 Cache manager initialized, cache directory: {self.cache_dir}")
|
||||
print(f"🗄️ Database cache manager initialized")
|
||||
print(f" US stock data: ✅ Configured")
|
||||
print(f" A-share data: ✅ Configured")
|
||||
|
||||
def _determine_market_type(self, symbol: str) -> str:
|
||||
"""Determine market type based on stock symbol"""
|
||||
import re
|
||||
|
||||
# Check if it's Chinese A-share (6-digit number)
|
||||
if re.match(r'^\d{6}$', str(symbol)):
|
||||
return 'china'
|
||||
else:
|
||||
return 'us'
|
||||
|
||||
def _generate_cache_key(self, data_type: str, symbol: str, **kwargs) -> str:
|
||||
"""Generate cache key"""
|
||||
# Create a string containing all parameters
|
||||
params_str = f"{data_type}_{symbol}"
|
||||
for key, value in sorted(kwargs.items()):
|
||||
params_str += f"_{key}_{value}"
|
||||
|
||||
# Use MD5 to generate short unique identifier
|
||||
cache_key = hashlib.md5(params_str.encode()).hexdigest()[:12]
|
||||
return f"{symbol}_{data_type}_{cache_key}"
|
||||
|
||||
def _get_cache_path(self, data_type: str, cache_key: str, file_format: str = "json", symbol: str = None) -> Path:
|
||||
"""Get cache file path - supports market classification"""
|
||||
if symbol:
|
||||
market_type = self._determine_market_type(symbol)
|
||||
else:
|
||||
# Try to extract market type from cache key
|
||||
market_type = 'us' if not cache_key.startswith(('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')) else 'china'
|
||||
|
||||
# Select directory based on data type and market type
|
||||
if data_type == "stock_data":
|
||||
base_dir = self.china_stock_dir if market_type == 'china' else self.us_stock_dir
|
||||
elif data_type == "news":
|
||||
base_dir = self.china_news_dir if market_type == 'china' else self.us_news_dir
|
||||
elif data_type == "fundamentals":
|
||||
base_dir = self.china_fundamentals_dir if market_type == 'china' else self.us_fundamentals_dir
|
||||
else:
|
||||
base_dir = self.cache_dir
|
||||
|
||||
return base_dir / f"{cache_key}.{file_format}"
|
||||
|
||||
def _get_metadata_path(self, cache_key: str) -> Path:
|
||||
"""Get metadata file path"""
|
||||
return self.metadata_dir / f"{cache_key}_meta.json"
|
||||
|
||||
def _save_metadata(self, cache_key: str, metadata: Dict[str, Any]):
|
||||
"""Save cache metadata"""
|
||||
metadata_path = self._get_metadata_path(cache_key)
|
||||
try:
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, ensure_ascii=False, indent=2, default=str)
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to save metadata: {e}")
|
||||
|
||||
def _load_metadata(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""Load cache metadata"""
|
||||
metadata_path = self._get_metadata_path(cache_key)
|
||||
if not metadata_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to load metadata: {e}")
|
||||
return None
|
||||
|
||||
def save_stock_data(self, symbol: str, data: Union[str, pd.DataFrame],
|
||||
start_date: str, end_date: str, data_source: str = "unknown") -> str:
|
||||
"""
|
||||
Save stock data to cache
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol
|
||||
data: Stock data (string or DataFrame)
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
data_source: Data source name
|
||||
|
||||
Returns:
|
||||
Cache key
|
||||
"""
|
||||
try:
|
||||
# Generate cache key
|
||||
cache_key = self._generate_cache_key(
|
||||
"stock_data", symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
data_source=data_source
|
||||
)
|
||||
|
||||
# Determine file format and save data
|
||||
if isinstance(data, pd.DataFrame):
|
||||
# Save DataFrame as pickle for better performance
|
||||
cache_path = self._get_cache_path("stock_data", cache_key, "pkl", symbol)
|
||||
data.to_pickle(cache_path)
|
||||
data_type = "dataframe"
|
||||
else:
|
||||
# Save string data as JSON
|
||||
cache_path = self._get_cache_path("stock_data", cache_key, "json", symbol)
|
||||
with open(cache_path, 'w', encoding='utf-8') as f:
|
||||
json.dump({"data": data}, f, ensure_ascii=False, indent=2)
|
||||
data_type = "string"
|
||||
|
||||
# Save metadata
|
||||
market_type = self._determine_market_type(symbol)
|
||||
metadata = {
|
||||
"symbol": symbol,
|
||||
"data_type": data_type,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"data_source": data_source,
|
||||
"market_type": market_type,
|
||||
"cache_time": datetime.now().isoformat(),
|
||||
"file_path": str(cache_path),
|
||||
"cache_key": cache_key
|
||||
}
|
||||
self._save_metadata(cache_key, metadata)
|
||||
|
||||
print(f"💾 Stock data cached: {symbol} ({market_type.upper()}) -> {cache_key}")
|
||||
return cache_key
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to save stock data cache: {e}")
|
||||
return None
|
||||
|
||||
def load_stock_data(self, cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
|
||||
"""
|
||||
Load stock data from cache
|
||||
|
||||
Args:
|
||||
cache_key: Cache key
|
||||
|
||||
Returns:
|
||||
Stock data or None if not found
|
||||
"""
|
||||
try:
|
||||
# Load metadata
|
||||
metadata = self._load_metadata(cache_key)
|
||||
if not metadata:
|
||||
print(f"⚠️ Cache metadata not found: {cache_key}")
|
||||
return None
|
||||
|
||||
# Get file path
|
||||
cache_path = Path(metadata["file_path"])
|
||||
if not cache_path.exists():
|
||||
print(f"⚠️ Cache file not found: {cache_path}")
|
||||
return None
|
||||
|
||||
# Load data based on type
|
||||
if metadata["data_type"] == "dataframe":
|
||||
data = pd.read_pickle(cache_path)
|
||||
else:
|
||||
with open(cache_path, 'r', encoding='utf-8') as f:
|
||||
json_data = json.load(f)
|
||||
data = json_data["data"]
|
||||
|
||||
print(f"📖 Stock data loaded from cache: {metadata['symbol']} -> {cache_key}")
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to load stock data from cache: {e}")
|
||||
return None
|
||||
|
||||
def find_cached_stock_data(self, symbol: str, start_date: str, end_date: str,
|
||||
data_source: str = "unknown") -> Optional[str]:
|
||||
"""
|
||||
Find cached stock data
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
data_source: Data source name
|
||||
|
||||
Returns:
|
||||
Cache key if found, None otherwise
|
||||
"""
|
||||
# Generate expected cache key
|
||||
cache_key = self._generate_cache_key(
|
||||
"stock_data", symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
data_source=data_source
|
||||
)
|
||||
|
||||
# Check if metadata exists
|
||||
metadata = self._load_metadata(cache_key)
|
||||
if metadata:
|
||||
cache_path = Path(metadata["file_path"])
|
||||
if cache_path.exists():
|
||||
return cache_key
|
||||
|
||||
return None
|
||||
|
||||
def is_cache_valid(self, cache_key: str, symbol: str = None, data_type: str = "stock_data") -> bool:
|
||||
"""
|
||||
Check if cache is still valid
|
||||
|
||||
Args:
|
||||
cache_key: Cache key
|
||||
symbol: Stock symbol (for market type determination)
|
||||
data_type: Data type
|
||||
|
||||
Returns:
|
||||
True if cache is valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Load metadata
|
||||
metadata = self._load_metadata(cache_key)
|
||||
if not metadata:
|
||||
return False
|
||||
|
||||
# Check if file exists
|
||||
cache_path = Path(metadata["file_path"])
|
||||
if not cache_path.exists():
|
||||
return False
|
||||
|
||||
# Determine market type and get TTL
|
||||
if symbol:
|
||||
market_type = self._determine_market_type(symbol)
|
||||
else:
|
||||
market_type = metadata.get("market_type", "us")
|
||||
|
||||
cache_type_key = f"{market_type}_{data_type}"
|
||||
if cache_type_key not in self.cache_config:
|
||||
cache_type_key = "us_stock_data" # Default fallback
|
||||
|
||||
ttl_hours = self.cache_config[cache_type_key]["ttl_hours"]
|
||||
|
||||
# Check if cache has expired
|
||||
cache_time = datetime.fromisoformat(metadata["cache_time"])
|
||||
expiry_time = cache_time + timedelta(hours=ttl_hours)
|
||||
|
||||
is_valid = datetime.now() < expiry_time
|
||||
if not is_valid:
|
||||
print(f"⏰ Cache expired: {cache_key} (cached at {cache_time}, TTL: {ttl_hours}h)")
|
||||
|
||||
return is_valid
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to check cache validity: {e}")
|
||||
return False
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics
|
||||
|
||||
Returns:
|
||||
Dictionary containing cache statistics
|
||||
"""
|
||||
try:
|
||||
stats = {
|
||||
"cache_dir": str(self.cache_dir),
|
||||
"total_files": 0,
|
||||
"total_size_mb": 0,
|
||||
"stock_data_count": 0,
|
||||
"news_count": 0,
|
||||
"fundamentals_count": 0,
|
||||
"us_data_count": 0,
|
||||
"china_data_count": 0
|
||||
}
|
||||
|
||||
# Count files in each directory
|
||||
for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
|
||||
self.china_news_dir, self.us_fundamentals_dir,
|
||||
self.china_fundamentals_dir, self.metadata_dir]:
|
||||
if dir_path.exists():
|
||||
files = list(dir_path.glob("*"))
|
||||
stats["total_files"] += len(files)
|
||||
|
||||
# Calculate total size
|
||||
for file_path in files:
|
||||
if file_path.is_file():
|
||||
stats["total_size_mb"] += file_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
# Count by data type
|
||||
if self.us_stock_dir.exists():
|
||||
stats["stock_data_count"] += len(list(self.us_stock_dir.glob("*")))
|
||||
stats["us_data_count"] += len(list(self.us_stock_dir.glob("*")))
|
||||
|
||||
if self.china_stock_dir.exists():
|
||||
stats["stock_data_count"] += len(list(self.china_stock_dir.glob("*")))
|
||||
stats["china_data_count"] += len(list(self.china_stock_dir.glob("*")))
|
||||
|
||||
if self.us_news_dir.exists():
|
||||
stats["news_count"] += len(list(self.us_news_dir.glob("*")))
|
||||
stats["us_data_count"] += len(list(self.us_news_dir.glob("*")))
|
||||
|
||||
if self.china_news_dir.exists():
|
||||
stats["news_count"] += len(list(self.china_news_dir.glob("*")))
|
||||
stats["china_data_count"] += len(list(self.china_news_dir.glob("*")))
|
||||
|
||||
if self.us_fundamentals_dir.exists():
|
||||
stats["fundamentals_count"] += len(list(self.us_fundamentals_dir.glob("*")))
|
||||
stats["us_data_count"] += len(list(self.us_fundamentals_dir.glob("*")))
|
||||
|
||||
if self.china_fundamentals_dir.exists():
|
||||
stats["fundamentals_count"] += len(list(self.china_fundamentals_dir.glob("*")))
|
||||
stats["china_data_count"] += len(list(self.china_fundamentals_dir.glob("*")))
|
||||
|
||||
# Round size to 2 decimal places
|
||||
stats["total_size_mb"] = round(stats["total_size_mb"], 2)
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to get cache statistics: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def cleanup_expired_cache(self):
|
||||
"""Clean up expired cache files"""
|
||||
try:
|
||||
cleaned_count = 0
|
||||
|
||||
# Check all metadata files
|
||||
if self.metadata_dir.exists():
|
||||
for metadata_file in self.metadata_dir.glob("*_meta.json"):
|
||||
try:
|
||||
cache_key = metadata_file.stem.replace("_meta", "")
|
||||
|
||||
# Load metadata
|
||||
with open(metadata_file, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# Check if cache is expired
|
||||
if not self.is_cache_valid(cache_key, metadata.get("symbol"), "stock_data"):
|
||||
# Remove cache file
|
||||
cache_path = Path(metadata["file_path"])
|
||||
if cache_path.exists():
|
||||
cache_path.unlink()
|
||||
|
||||
# Remove metadata file
|
||||
metadata_file.unlink()
|
||||
cleaned_count += 1
|
||||
print(f"🗑️ Cleaned expired cache: {cache_key}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to clean cache file {metadata_file}: {e}")
|
||||
|
||||
print(f"✅ Cache cleanup completed, removed {cleaned_count} expired files")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to cleanup cache: {e}")
|
||||
|
||||
|
||||
# Global cache instance
|
||||
_global_cache = None
|
||||
|
||||
def get_cache(cache_dir: str = None) -> StockDataCache:
|
||||
"""
|
||||
Get global cache instance
|
||||
|
||||
Args:
|
||||
cache_dir: Cache directory path
|
||||
|
||||
Returns:
|
||||
StockDataCache instance
|
||||
"""
|
||||
global _global_cache
|
||||
if _global_cache is None:
|
||||
_global_cache = StockDataCache(cache_dir)
|
||||
return _global_cache
|
||||
|
||||
|
||||
# Convenience functions
|
||||
def save_stock_data(symbol: str, data: Union[str, pd.DataFrame],
|
||||
start_date: str, end_date: str, data_source: str = "unknown") -> str:
|
||||
"""Save stock data to cache (convenience function)"""
|
||||
cache = get_cache()
|
||||
return cache.save_stock_data(symbol, data, start_date, end_date, data_source)
|
||||
|
||||
|
||||
def load_stock_data(cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
|
||||
"""Load stock data from cache (convenience function)"""
|
||||
cache = get_cache()
|
||||
return cache.load_stock_data(cache_key)
|
||||
|
||||
|
||||
def find_cached_stock_data(symbol: str, start_date: str, end_date: str,
|
||||
data_source: str = "unknown") -> Optional[str]:
|
||||
"""Find cached stock data (convenience function)"""
|
||||
cache = get_cache()
|
||||
return cache.find_cached_stock_data(symbol, start_date, end_date, data_source)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test the cache manager
|
||||
print("🧪 Testing Stock Data Cache Manager...")
|
||||
|
||||
# Initialize cache
|
||||
cache = StockDataCache()
|
||||
|
||||
# Test data
|
||||
test_data = "Sample stock data for AAPL"
|
||||
cache_key = cache.save_stock_data("AAPL", test_data, "2024-01-01", "2024-01-31", "test")
|
||||
|
||||
# Load data
|
||||
loaded_data = cache.load_stock_data(cache_key)
|
||||
print(f"Loaded data: {loaded_data}")
|
||||
|
||||
# Check cache validity
|
||||
is_valid = cache.is_cache_valid(cache_key, "AAPL")
|
||||
print(f"Cache valid: {is_valid}")
|
||||
|
||||
# Get statistics
|
||||
stats = cache.get_cache_stats()
|
||||
print(f"Cache stats: {stats}")
|
||||
|
||||
print("✅ Cache manager test completed!")
|
||||
|
|
@ -0,0 +1,498 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Stock Data Cache Manager
|
||||
Supports local caching of stock data to reduce API calls and improve response speed
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, Union
|
||||
import hashlib
|
||||
|
||||
|
||||
class StockDataCache:
|
||||
"""Stock Data Cache Manager - Supports optimized caching for US and Chinese stock data"""
|
||||
|
||||
def __init__(self, cache_dir: str = None):
|
||||
"""
|
||||
Initialize cache manager
|
||||
|
||||
Args:
|
||||
cache_dir: Cache directory path, defaults to tradingagents/dataflows/data_cache
|
||||
"""
|
||||
if cache_dir is None:
|
||||
# Get current file directory
|
||||
current_dir = Path(__file__).parent
|
||||
cache_dir = current_dir / "data_cache"
|
||||
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Create subdirectories - categorized by market
|
||||
self.us_stock_dir = self.cache_dir / "us_stocks"
|
||||
self.china_stock_dir = self.cache_dir / "china_stocks"
|
||||
self.us_news_dir = self.cache_dir / "us_news"
|
||||
self.china_news_dir = self.cache_dir / "china_news"
|
||||
self.us_fundamentals_dir = self.cache_dir / "us_fundamentals"
|
||||
self.china_fundamentals_dir = self.cache_dir / "china_fundamentals"
|
||||
self.metadata_dir = self.cache_dir / "metadata"
|
||||
|
||||
# Create all directories
|
||||
for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
|
||||
self.china_news_dir, self.us_fundamentals_dir,
|
||||
self.china_fundamentals_dir, self.metadata_dir]:
|
||||
dir_path.mkdir(exist_ok=True)
|
||||
|
||||
# Cache configuration - different TTL settings for different markets
|
||||
self.cache_config = {
|
||||
'us_stock_data': {
|
||||
'ttl_hours': 2, # US stock data cached for 2 hours (considering API limits)
|
||||
'max_files': 1000,
|
||||
'description': 'US stock historical data'
|
||||
},
|
||||
'china_stock_data': {
|
||||
'ttl_hours': 1, # A-share data cached for 1 hour (high real-time requirement)
|
||||
'max_files': 1000,
|
||||
'description': 'A-share historical data'
|
||||
},
|
||||
'us_news': {
|
||||
'ttl_hours': 6, # US stock news cached for 6 hours
|
||||
'max_files': 500,
|
||||
'description': 'US stock news data'
|
||||
},
|
||||
'china_news': {
|
||||
'ttl_hours': 4, # A-share news cached for 4 hours
|
||||
'max_files': 500,
|
||||
'description': 'A-share news data'
|
||||
},
|
||||
'us_fundamentals': {
|
||||
'ttl_hours': 24, # US stock fundamentals cached for 24 hours
|
||||
'max_files': 200,
|
||||
'description': 'US stock fundamentals data'
|
||||
},
|
||||
'china_fundamentals': {
|
||||
'ttl_hours': 12, # A-share fundamentals cached for 12 hours
|
||||
'max_files': 200,
|
||||
'description': 'A-share fundamentals data'
|
||||
}
|
||||
}
|
||||
|
||||
print(f"📁 Cache manager initialized, cache directory: {self.cache_dir}")
|
||||
print(f"🗄️ Database cache manager initialized")
|
||||
print(f" US stock data: ✅ Configured")
|
||||
print(f" A-share data: ✅ Configured")
|
||||
|
||||
def _determine_market_type(self, symbol: str) -> str:
|
||||
"""Determine market type based on stock symbol"""
|
||||
import re
|
||||
|
||||
# Check if it's Chinese A-share (6-digit number)
|
||||
if re.match(r'^\d{6}$', str(symbol)):
|
||||
return 'china'
|
||||
else:
|
||||
return 'us'
|
||||
|
||||
def _generate_cache_key(self, data_type: str, symbol: str, **kwargs) -> str:
|
||||
"""Generate cache key"""
|
||||
# Create a string containing all parameters
|
||||
params_str = f"{data_type}_{symbol}"
|
||||
for key, value in sorted(kwargs.items()):
|
||||
params_str += f"_{key}_{value}"
|
||||
|
||||
# Use MD5 to generate short unique identifier
|
||||
cache_key = hashlib.md5(params_str.encode()).hexdigest()[:12]
|
||||
return f"{symbol}_{data_type}_{cache_key}"
|
||||
|
||||
def _get_cache_path(self, data_type: str, cache_key: str, file_format: str = "json", symbol: str = None) -> Path:
|
||||
"""Get cache file path - supports market classification"""
|
||||
if symbol:
|
||||
market_type = self._determine_market_type(symbol)
|
||||
else:
|
||||
# Try to extract market type from cache key
|
||||
market_type = 'us' if not cache_key.startswith(('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')) else 'china'
|
||||
|
||||
# Select directory based on data type and market type
|
||||
if data_type == "stock_data":
|
||||
base_dir = self.china_stock_dir if market_type == 'china' else self.us_stock_dir
|
||||
elif data_type == "news":
|
||||
base_dir = self.china_news_dir if market_type == 'china' else self.us_news_dir
|
||||
elif data_type == "fundamentals":
|
||||
base_dir = self.china_fundamentals_dir if market_type == 'china' else self.us_fundamentals_dir
|
||||
else:
|
||||
base_dir = self.cache_dir
|
||||
|
||||
return base_dir / f"{cache_key}.{file_format}"
|
||||
|
||||
def _get_metadata_path(self, cache_key: str) -> Path:
|
||||
"""Get metadata file path"""
|
||||
return self.metadata_dir / f"{cache_key}_meta.json"
|
||||
|
||||
def _save_metadata(self, cache_key: str, metadata: Dict[str, Any]):
|
||||
"""Save cache metadata"""
|
||||
metadata_path = self._get_metadata_path(cache_key)
|
||||
try:
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, ensure_ascii=False, indent=2, default=str)
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to save metadata: {e}")
|
||||
|
||||
def _load_metadata(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""Load cache metadata"""
|
||||
metadata_path = self._get_metadata_path(cache_key)
|
||||
if not metadata_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to load metadata: {e}")
|
||||
return None
|
||||
|
||||
def save_stock_data(self, symbol: str, data: Union[str, pd.DataFrame],
|
||||
start_date: str, end_date: str, data_source: str = "unknown") -> str:
|
||||
"""
|
||||
Save stock data to cache
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol
|
||||
data: Stock data (string or DataFrame)
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
data_source: Data source name
|
||||
|
||||
Returns:
|
||||
Cache key
|
||||
"""
|
||||
try:
|
||||
# Generate cache key
|
||||
cache_key = self._generate_cache_key(
|
||||
"stock_data", symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
data_source=data_source
|
||||
)
|
||||
|
||||
# Determine file format and save data
|
||||
if isinstance(data, pd.DataFrame):
|
||||
# Save DataFrame as pickle for better performance
|
||||
cache_path = self._get_cache_path("stock_data", cache_key, "pkl", symbol)
|
||||
data.to_pickle(cache_path)
|
||||
data_type = "dataframe"
|
||||
else:
|
||||
# Save string data as JSON
|
||||
cache_path = self._get_cache_path("stock_data", cache_key, "json", symbol)
|
||||
with open(cache_path, 'w', encoding='utf-8') as f:
|
||||
json.dump({"data": data}, f, ensure_ascii=False, indent=2)
|
||||
data_type = "string"
|
||||
|
||||
# Save metadata
|
||||
market_type = self._determine_market_type(symbol)
|
||||
metadata = {
|
||||
"symbol": symbol,
|
||||
"data_type": data_type,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"data_source": data_source,
|
||||
"market_type": market_type,
|
||||
"cache_time": datetime.now().isoformat(),
|
||||
"file_path": str(cache_path),
|
||||
"cache_key": cache_key
|
||||
}
|
||||
self._save_metadata(cache_key, metadata)
|
||||
|
||||
print(f"💾 Stock data cached: {symbol} ({market_type.upper()}) -> {cache_key}")
|
||||
return cache_key
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to save stock data cache: {e}")
|
||||
return None
|
||||
|
||||
def load_stock_data(self, cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
|
||||
"""
|
||||
Load stock data from cache
|
||||
|
||||
Args:
|
||||
cache_key: Cache key
|
||||
|
||||
Returns:
|
||||
Stock data or None if not found
|
||||
"""
|
||||
try:
|
||||
# Load metadata
|
||||
metadata = self._load_metadata(cache_key)
|
||||
if not metadata:
|
||||
print(f"⚠️ Cache metadata not found: {cache_key}")
|
||||
return None
|
||||
|
||||
# Get file path
|
||||
cache_path = Path(metadata["file_path"])
|
||||
if not cache_path.exists():
|
||||
print(f"⚠️ Cache file not found: {cache_path}")
|
||||
return None
|
||||
|
||||
# Load data based on type
|
||||
if metadata["data_type"] == "dataframe":
|
||||
data = pd.read_pickle(cache_path)
|
||||
else:
|
||||
with open(cache_path, 'r', encoding='utf-8') as f:
|
||||
json_data = json.load(f)
|
||||
data = json_data["data"]
|
||||
|
||||
print(f"📖 Stock data loaded from cache: {metadata['symbol']} -> {cache_key}")
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to load stock data from cache: {e}")
|
||||
return None
|
||||
|
||||
def find_cached_stock_data(self, symbol: str, start_date: str, end_date: str,
|
||||
data_source: str = "unknown") -> Optional[str]:
|
||||
"""
|
||||
Find cached stock data
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
data_source: Data source name
|
||||
|
||||
Returns:
|
||||
Cache key if found, None otherwise
|
||||
"""
|
||||
# Generate expected cache key
|
||||
cache_key = self._generate_cache_key(
|
||||
"stock_data", symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
data_source=data_source
|
||||
)
|
||||
|
||||
# Check if metadata exists
|
||||
metadata = self._load_metadata(cache_key)
|
||||
if metadata:
|
||||
cache_path = Path(metadata["file_path"])
|
||||
if cache_path.exists():
|
||||
return cache_key
|
||||
|
||||
return None
|
||||
|
||||
def is_cache_valid(self, cache_key: str, symbol: str = None, data_type: str = "stock_data") -> bool:
|
||||
"""
|
||||
Check if cache is still valid
|
||||
|
||||
Args:
|
||||
cache_key: Cache key
|
||||
symbol: Stock symbol (for market type determination)
|
||||
data_type: Data type
|
||||
|
||||
Returns:
|
||||
True if cache is valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Load metadata
|
||||
metadata = self._load_metadata(cache_key)
|
||||
if not metadata:
|
||||
return False
|
||||
|
||||
# Check if file exists
|
||||
cache_path = Path(metadata["file_path"])
|
||||
if not cache_path.exists():
|
||||
return False
|
||||
|
||||
# Determine market type and get TTL
|
||||
if symbol:
|
||||
market_type = self._determine_market_type(symbol)
|
||||
else:
|
||||
market_type = metadata.get("market_type", "us")
|
||||
|
||||
cache_type_key = f"{market_type}_{data_type}"
|
||||
if cache_type_key not in self.cache_config:
|
||||
cache_type_key = "us_stock_data" # Default fallback
|
||||
|
||||
ttl_hours = self.cache_config[cache_type_key]["ttl_hours"]
|
||||
|
||||
# Check if cache has expired
|
||||
cache_time = datetime.fromisoformat(metadata["cache_time"])
|
||||
expiry_time = cache_time + timedelta(hours=ttl_hours)
|
||||
|
||||
is_valid = datetime.now() < expiry_time
|
||||
if not is_valid:
|
||||
print(f"⏰ Cache expired: {cache_key} (cached at {cache_time}, TTL: {ttl_hours}h)")
|
||||
|
||||
return is_valid
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to check cache validity: {e}")
|
||||
return False
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics
|
||||
|
||||
Returns:
|
||||
Dictionary containing cache statistics
|
||||
"""
|
||||
try:
|
||||
stats = {
|
||||
"cache_dir": str(self.cache_dir),
|
||||
"total_files": 0,
|
||||
"total_size_mb": 0,
|
||||
"stock_data_count": 0,
|
||||
"news_count": 0,
|
||||
"fundamentals_count": 0,
|
||||
"us_data_count": 0,
|
||||
"china_data_count": 0
|
||||
}
|
||||
|
||||
# Count files in each directory
|
||||
for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
|
||||
self.china_news_dir, self.us_fundamentals_dir,
|
||||
self.china_fundamentals_dir, self.metadata_dir]:
|
||||
if dir_path.exists():
|
||||
files = list(dir_path.glob("*"))
|
||||
stats["total_files"] += len(files)
|
||||
|
||||
# Calculate total size
|
||||
for file_path in files:
|
||||
if file_path.is_file():
|
||||
stats["total_size_mb"] += file_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
# Count by data type
|
||||
if self.us_stock_dir.exists():
|
||||
stats["stock_data_count"] += len(list(self.us_stock_dir.glob("*")))
|
||||
stats["us_data_count"] += len(list(self.us_stock_dir.glob("*")))
|
||||
|
||||
if self.china_stock_dir.exists():
|
||||
stats["stock_data_count"] += len(list(self.china_stock_dir.glob("*")))
|
||||
stats["china_data_count"] += len(list(self.china_stock_dir.glob("*")))
|
||||
|
||||
if self.us_news_dir.exists():
|
||||
stats["news_count"] += len(list(self.us_news_dir.glob("*")))
|
||||
stats["us_data_count"] += len(list(self.us_news_dir.glob("*")))
|
||||
|
||||
if self.china_news_dir.exists():
|
||||
stats["news_count"] += len(list(self.china_news_dir.glob("*")))
|
||||
stats["china_data_count"] += len(list(self.china_news_dir.glob("*")))
|
||||
|
||||
if self.us_fundamentals_dir.exists():
|
||||
stats["fundamentals_count"] += len(list(self.us_fundamentals_dir.glob("*")))
|
||||
stats["us_data_count"] += len(list(self.us_fundamentals_dir.glob("*")))
|
||||
|
||||
if self.china_fundamentals_dir.exists():
|
||||
stats["fundamentals_count"] += len(list(self.china_fundamentals_dir.glob("*")))
|
||||
stats["china_data_count"] += len(list(self.china_fundamentals_dir.glob("*")))
|
||||
|
||||
# Round size to 2 decimal places
|
||||
stats["total_size_mb"] = round(stats["total_size_mb"], 2)
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to get cache statistics: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def cleanup_expired_cache(self):
|
||||
"""Clean up expired cache files"""
|
||||
try:
|
||||
cleaned_count = 0
|
||||
|
||||
# Check all metadata files
|
||||
if self.metadata_dir.exists():
|
||||
for metadata_file in self.metadata_dir.glob("*_meta.json"):
|
||||
try:
|
||||
cache_key = metadata_file.stem.replace("_meta", "")
|
||||
|
||||
# Load metadata
|
||||
with open(metadata_file, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# Check if cache is expired
|
||||
if not self.is_cache_valid(cache_key, metadata.get("symbol"), "stock_data"):
|
||||
# Remove cache file
|
||||
cache_path = Path(metadata["file_path"])
|
||||
if cache_path.exists():
|
||||
cache_path.unlink()
|
||||
|
||||
# Remove metadata file
|
||||
metadata_file.unlink()
|
||||
cleaned_count += 1
|
||||
print(f"🗑️ Cleaned expired cache: {cache_key}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to clean cache file {metadata_file}: {e}")
|
||||
|
||||
print(f"✅ Cache cleanup completed, removed {cleaned_count} expired files")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to cleanup cache: {e}")
|
||||
|
||||
|
||||
# Global cache instance
|
||||
_global_cache = None
|
||||
|
||||
def get_cache(cache_dir: str = None) -> StockDataCache:
|
||||
"""
|
||||
Get global cache instance
|
||||
|
||||
Args:
|
||||
cache_dir: Cache directory path
|
||||
|
||||
Returns:
|
||||
StockDataCache instance
|
||||
"""
|
||||
global _global_cache
|
||||
if _global_cache is None:
|
||||
_global_cache = StockDataCache(cache_dir)
|
||||
return _global_cache
|
||||
|
||||
|
||||
# Convenience functions
|
||||
def save_stock_data(symbol: str, data: Union[str, pd.DataFrame],
|
||||
start_date: str, end_date: str, data_source: str = "unknown") -> str:
|
||||
"""Save stock data to cache (convenience function)"""
|
||||
cache = get_cache()
|
||||
return cache.save_stock_data(symbol, data, start_date, end_date, data_source)
|
||||
|
||||
|
||||
def load_stock_data(cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
|
||||
"""Load stock data from cache (convenience function)"""
|
||||
cache = get_cache()
|
||||
return cache.load_stock_data(cache_key)
|
||||
|
||||
|
||||
def find_cached_stock_data(symbol: str, start_date: str, end_date: str,
|
||||
data_source: str = "unknown") -> Optional[str]:
|
||||
"""Find cached stock data (convenience function)"""
|
||||
cache = get_cache()
|
||||
return cache.find_cached_stock_data(symbol, start_date, end_date, data_source)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test the cache manager
|
||||
print("🧪 Testing Stock Data Cache Manager...")
|
||||
|
||||
# Initialize cache
|
||||
cache = StockDataCache()
|
||||
|
||||
# Test data
|
||||
test_data = "Sample stock data for AAPL"
|
||||
cache_key = cache.save_stock_data("AAPL", test_data, "2024-01-01", "2024-01-31", "test")
|
||||
|
||||
# Load data
|
||||
loaded_data = cache.load_stock_data(cache_key)
|
||||
print(f"Loaded data: {loaded_data}")
|
||||
|
||||
# Check cache validity
|
||||
is_valid = cache.is_cache_valid(cache_key, "AAPL")
|
||||
print(f"Cache valid: {is_valid}")
|
||||
|
||||
# Get statistics
|
||||
stats = cache.get_cache_stats()
|
||||
print(f"Cache stats: {stats}")
|
||||
|
||||
print("✅ Cache manager test completed!")
|
||||
|
|
@ -0,0 +1,863 @@
|
|||
# 文件差异报告
|
||||
# 当前文件: tradingagents\dataflows\cache_manager.py
|
||||
# 中文版文件: TradingAgentsCN\tradingagents\dataflows\cache_manager.py
|
||||
# 生成时间: 周日 2025/07/06
|
||||
|
||||
--- current/cache_manager.py+++ chinese_version/cache_manager.py@@ -1,7 +1,7 @@ #!/usr/bin/env python3
|
||||
"""
|
||||
-Stock Data Cache Manager
|
||||
-Supports local caching of stock data to reduce API calls and improve response speed
|
||||
+股票数据缓存管理器
|
||||
+支持本地缓存股票数据,减少API调用,提高响应速度
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -15,24 +15,24 @@
|
||||
|
||||
class StockDataCache:
|
||||
- """Stock Data Cache Manager - Supports optimized caching for US and Chinese stock data"""
|
||||
+ """股票数据缓存管理器 - 支持美股和A股数据缓存优化"""
|
||||
|
||||
def __init__(self, cache_dir: str = None):
|
||||
"""
|
||||
- Initialize cache manager
|
||||
+ 初始化缓存管理器
|
||||
|
||||
Args:
|
||||
- cache_dir: Cache directory path, defaults to tradingagents/dataflows/data_cache
|
||||
+ cache_dir: 缓存目录路径,默认为 tradingagents/dataflows/data_cache
|
||||
"""
|
||||
if cache_dir is None:
|
||||
- # Get current file directory
|
||||
+ # 获取当前文件所在目录
|
||||
current_dir = Path(__file__).parent
|
||||
cache_dir = current_dir / "data_cache"
|
||||
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(exist_ok=True)
|
||||
|
||||
- # Create subdirectories - categorized by market
|
||||
+ # 创建子目录 - 按市场分类
|
||||
self.us_stock_dir = self.cache_dir / "us_stocks"
|
||||
self.china_stock_dir = self.cache_dir / "china_stocks"
|
||||
self.us_news_dir = self.cache_dir / "us_news"
|
||||
@@ -41,81 +41,81 @@ self.china_fundamentals_dir = self.cache_dir / "china_fundamentals"
|
||||
self.metadata_dir = self.cache_dir / "metadata"
|
||||
|
||||
- # Create all directories
|
||||
+ # 创建所有目录
|
||||
for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
|
||||
self.china_news_dir, self.us_fundamentals_dir,
|
||||
self.china_fundamentals_dir, self.metadata_dir]:
|
||||
dir_path.mkdir(exist_ok=True)
|
||||
|
||||
- # Cache configuration - different TTL settings for different markets
|
||||
+ # 缓存配置 - 针对不同市场设置不同的TTL
|
||||
self.cache_config = {
|
||||
'us_stock_data': {
|
||||
- 'ttl_hours': 2, # US stock data cached for 2 hours (considering API limits)
|
||||
+ 'ttl_hours': 2, # 美股数据缓存2小时(考虑到API限制)
|
||||
'max_files': 1000,
|
||||
- 'description': 'US stock historical data'
|
||||
+ 'description': '美股历史数据'
|
||||
},
|
||||
'china_stock_data': {
|
||||
- 'ttl_hours': 1, # A-share data cached for 1 hour (high real-time requirement)
|
||||
+ 'ttl_hours': 1, # A股数据缓存1小时(实时性要求高)
|
||||
'max_files': 1000,
|
||||
- 'description': 'A-share historical data'
|
||||
+ 'description': 'A股历史数据'
|
||||
},
|
||||
'us_news': {
|
||||
- 'ttl_hours': 6, # US stock news cached for 6 hours
|
||||
+ 'ttl_hours': 6, # 美股新闻缓存6小时
|
||||
'max_files': 500,
|
||||
- 'description': 'US stock news data'
|
||||
+ 'description': '美股新闻数据'
|
||||
},
|
||||
'china_news': {
|
||||
- 'ttl_hours': 4, # A-share news cached for 4 hours
|
||||
+ 'ttl_hours': 4, # A股新闻缓存4小时
|
||||
'max_files': 500,
|
||||
- 'description': 'A-share news data'
|
||||
+ 'description': 'A股新闻数据'
|
||||
},
|
||||
'us_fundamentals': {
|
||||
- 'ttl_hours': 24, # US stock fundamentals cached for 24 hours
|
||||
+ 'ttl_hours': 24, # 美股基本面数据缓存24小时
|
||||
'max_files': 200,
|
||||
- 'description': 'US stock fundamentals data'
|
||||
+ 'description': '美股基本面数据'
|
||||
},
|
||||
'china_fundamentals': {
|
||||
- 'ttl_hours': 12, # A-share fundamentals cached for 12 hours
|
||||
+ 'ttl_hours': 12, # A股基本面数据缓存12小时
|
||||
'max_files': 200,
|
||||
- 'description': 'A-share fundamentals data'
|
||||
+ 'description': 'A股基本面数据'
|
||||
}
|
||||
}
|
||||
|
||||
- print(f"📁 Cache manager initialized, cache directory: {self.cache_dir}")
|
||||
- print(f"🗄️ Database cache manager initialized")
|
||||
- print(f" US stock data: ✅ Configured")
|
||||
- print(f" A-share data: ✅ Configured")
|
||||
+ print(f"📁 缓存管理器初始化完成,缓存目录: {self.cache_dir}")
|
||||
+ print(f"🗄️ 数据库缓存管理器初始化完成")
|
||||
+ print(f" 美股数据: ✅ 已配置")
|
||||
+ print(f" A股数据: ✅ 已配置")
|
||||
|
||||
def _determine_market_type(self, symbol: str) -> str:
|
||||
- """Determine market type based on stock symbol"""
|
||||
+ """根据股票代码确定市场类型"""
|
||||
import re
|
||||
|
||||
- # Check if it's Chinese A-share (6-digit number)
|
||||
+ # 判断是否为中国A股(6位数字)
|
||||
if re.match(r'^\d{6}$', str(symbol)):
|
||||
return 'china'
|
||||
else:
|
||||
return 'us'
|
||||
|
||||
def _generate_cache_key(self, data_type: str, symbol: str, **kwargs) -> str:
|
||||
- """Generate cache key"""
|
||||
- # Create a string containing all parameters
|
||||
+ """生成缓存键"""
|
||||
+ # 创建一个包含所有参数的字符串
|
||||
params_str = f"{data_type}_{symbol}"
|
||||
for key, value in sorted(kwargs.items()):
|
||||
params_str += f"_{key}_{value}"
|
||||
|
||||
- # Use MD5 to generate short unique identifier
|
||||
+ # 使用MD5生成短的唯一标识
|
||||
cache_key = hashlib.md5(params_str.encode()).hexdigest()[:12]
|
||||
return f"{symbol}_{data_type}_{cache_key}"
|
||||
|
||||
def _get_cache_path(self, data_type: str, cache_key: str, file_format: str = "json", symbol: str = None) -> Path:
|
||||
- """Get cache file path - supports market classification"""
|
||||
+ """获取缓存文件路径 - 支持市场分类"""
|
||||
if symbol:
|
||||
market_type = self._determine_market_type(symbol)
|
||||
else:
|
||||
- # Try to extract market type from cache key
|
||||
+ # 从缓存键中尝试提取市场类型
|
||||
market_type = 'us' if not cache_key.startswith(('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')) else 'china'
|
||||
|
||||
- # Select directory based on data type and market type
|
||||
+ # 根据数据类型和市场类型选择目录
|
||||
if data_type == "stock_data":
|
||||
base_dir = self.china_stock_dir if market_type == 'china' else self.us_stock_dir
|
||||
elif data_type == "news":
|
||||
@@ -128,20 +128,19 @@ return base_dir / f"{cache_key}.{file_format}"
|
||||
|
||||
def _get_metadata_path(self, cache_key: str) -> Path:
|
||||
- """Get metadata file path"""
|
||||
+ """获取元数据文件路径"""
|
||||
return self.metadata_dir / f"{cache_key}_meta.json"
|
||||
|
||||
def _save_metadata(self, cache_key: str, metadata: Dict[str, Any]):
|
||||
- """Save cache metadata"""
|
||||
+ """保存元数据"""
|
||||
metadata_path = self._get_metadata_path(cache_key)
|
||||
- try:
|
||||
- with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
- json.dump(metadata, f, ensure_ascii=False, indent=2, default=str)
|
||||
- except Exception as e:
|
||||
- print(f"⚠️ Failed to save metadata: {e}")
|
||||
+ metadata['cached_at'] = datetime.now().isoformat()
|
||||
+
|
||||
+ with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
+ json.dump(metadata, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def _load_metadata(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
- """Load cache metadata"""
|
||||
+ """加载元数据"""
|
||||
metadata_path = self._get_metadata_path(cache_key)
|
||||
if not metadata_path.exists():
|
||||
return None
|
||||
@@ -150,349 +149,355 @@ with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
- print(f"⚠️ Failed to load metadata: {e}")
|
||||
- return None
|
||||
-
|
||||
- def save_stock_data(self, symbol: str, data: Union[str, pd.DataFrame],
|
||||
- start_date: str, end_date: str, data_source: str = "unknown") -> str:
|
||||
- """
|
||||
- Save stock data to cache
|
||||
+ print(f"⚠️ 加载元数据失败: {e}")
|
||||
+ return None
|
||||
+
|
||||
+ def is_cache_valid(self, cache_key: str, max_age_hours: int = None, symbol: str = None, data_type: str = None) -> bool:
|
||||
+ """检查缓存是否有效 - 支持智能TTL配置"""
|
||||
+ metadata = self._load_metadata(cache_key)
|
||||
+ if not metadata:
|
||||
+ return False
|
||||
+
|
||||
+ # 如果没有指定TTL,根据数据类型和市场自动确定
|
||||
+ if max_age_hours is None:
|
||||
+ if symbol and data_type:
|
||||
+ market_type = self._determine_market_type(symbol)
|
||||
+ cache_type = f"{market_type}_{data_type}"
|
||||
+ max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
|
||||
+ else:
|
||||
+ # 从元数据中获取信息
|
||||
+ symbol = metadata.get('symbol', '')
|
||||
+ data_type = metadata.get('data_type', 'stock_data')
|
||||
+ market_type = self._determine_market_type(symbol)
|
||||
+ cache_type = f"{market_type}_{data_type}"
|
||||
+ max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
|
||||
+
|
||||
+ cached_at = datetime.fromisoformat(metadata['cached_at'])
|
||||
+ age = datetime.now() - cached_at
|
||||
+
|
||||
+ is_valid = age.total_seconds() < max_age_hours * 3600
|
||||
+
|
||||
+ if is_valid:
|
||||
+ market_type = self._determine_market_type(metadata.get('symbol', ''))
|
||||
+ cache_type = f"{market_type}_{metadata.get('data_type', 'stock_data')}"
|
||||
+ desc = self.cache_config.get(cache_type, {}).get('description', '数据')
|
||||
+ print(f"✅ 缓存有效: {desc} - {metadata.get('symbol')} (剩余 {max_age_hours - age.total_seconds()/3600:.1f}h)")
|
||||
+
|
||||
+ return is_valid
|
||||
+
|
||||
+ def save_stock_data(self, symbol: str, data: Union[pd.DataFrame, str],
|
||||
+ start_date: str = None, end_date: str = None,
|
||||
+ data_source: str = "unknown") -> str:
|
||||
+ """
|
||||
+ 保存股票数据到缓存 - 支持美股和A股分类存储
|
||||
|
||||
Args:
|
||||
- symbol: Stock symbol
|
||||
- data: Stock data (string or DataFrame)
|
||||
- start_date: Start date
|
||||
- end_date: End date
|
||||
- data_source: Data source name
|
||||
+ symbol: 股票代码
|
||||
+ data: 股票数据(DataFrame或字符串)
|
||||
+ start_date: 开始日期
|
||||
+ end_date: 结束日期
|
||||
+ data_source: 数据源(如 "tdx", "yfinance", "finnhub")
|
||||
|
||||
Returns:
|
||||
- Cache key
|
||||
- """
|
||||
+ cache_key: 缓存键
|
||||
+ """
|
||||
+ market_type = self._determine_market_type(symbol)
|
||||
+ cache_key = self._generate_cache_key("stock_data", symbol,
|
||||
+ start_date=start_date,
|
||||
+ end_date=end_date,
|
||||
+ source=data_source,
|
||||
+ market=market_type)
|
||||
+
|
||||
+ # 保存数据
|
||||
+ if isinstance(data, pd.DataFrame):
|
||||
+ cache_path = self._get_cache_path("stock_data", cache_key, "csv", symbol)
|
||||
+ data.to_csv(cache_path, index=True)
|
||||
+ else:
|
||||
+ cache_path = self._get_cache_path("stock_data", cache_key, "txt", symbol)
|
||||
+ with open(cache_path, 'w', encoding='utf-8') as f:
|
||||
+ f.write(str(data))
|
||||
+
|
||||
+ # 保存元数据
|
||||
+ metadata = {
|
||||
+ 'symbol': symbol,
|
||||
+ 'data_type': 'stock_data',
|
||||
+ 'market_type': market_type,
|
||||
+ 'start_date': start_date,
|
||||
+ 'end_date': end_date,
|
||||
+ 'data_source': data_source,
|
||||
+ 'file_path': str(cache_path),
|
||||
+ 'file_format': 'csv' if isinstance(data, pd.DataFrame) else 'txt'
|
||||
+ }
|
||||
+ self._save_metadata(cache_key, metadata)
|
||||
+
|
||||
+ # 获取描述信息
|
||||
+ cache_type = f"{market_type}_stock_data"
|
||||
+ desc = self.cache_config.get(cache_type, {}).get('description', '股票数据')
|
||||
+ print(f"💾 {desc}已缓存: {symbol} ({data_source}) -> {cache_key}")
|
||||
+ return cache_key
|
||||
+
|
||||
+ def load_stock_data(self, cache_key: str) -> Optional[Union[pd.DataFrame, str]]:
|
||||
+ """从缓存加载股票数据"""
|
||||
+ metadata = self._load_metadata(cache_key)
|
||||
+ if not metadata:
|
||||
+ return None
|
||||
+
|
||||
+ cache_path = Path(metadata['file_path'])
|
||||
+ if not cache_path.exists():
|
||||
+ return None
|
||||
+
|
||||
try:
|
||||
- # Generate cache key
|
||||
- cache_key = self._generate_cache_key(
|
||||
- "stock_data", symbol,
|
||||
- start_date=start_date,
|
||||
- end_date=end_date,
|
||||
- data_source=data_source
|
||||
- )
|
||||
-
|
||||
- # Determine file format and save data
|
||||
- if isinstance(data, pd.DataFrame):
|
||||
- # Save DataFrame as pickle for better performance
|
||||
- cache_path = self._get_cache_path("stock_data", cache_key, "pkl", symbol)
|
||||
- data.to_pickle(cache_path)
|
||||
- data_type = "dataframe"
|
||||
- else:
|
||||
- # Save string data as JSON
|
||||
- cache_path = self._get_cache_path("stock_data", cache_key, "json", symbol)
|
||||
- with open(cache_path, 'w', encoding='utf-8') as f:
|
||||
- json.dump({"data": data}, f, ensure_ascii=False, indent=2)
|
||||
- data_type = "string"
|
||||
-
|
||||
- # Save metadata
|
||||
- market_type = self._determine_market_type(symbol)
|
||||
- metadata = {
|
||||
- "symbol": symbol,
|
||||
- "data_type": data_type,
|
||||
- "start_date": start_date,
|
||||
- "end_date": end_date,
|
||||
- "data_source": data_source,
|
||||
- "market_type": market_type,
|
||||
- "cache_time": datetime.now().isoformat(),
|
||||
- "file_path": str(cache_path),
|
||||
- "cache_key": cache_key
|
||||
- }
|
||||
- self._save_metadata(cache_key, metadata)
|
||||
-
|
||||
- print(f"💾 Stock data cached: {symbol} ({market_type.upper()}) -> {cache_key}")
|
||||
- return cache_key
|
||||
-
|
||||
- except Exception as e:
|
||||
- print(f"❌ Failed to save stock data cache: {e}")
|
||||
- return None
|
||||
-
|
||||
- def load_stock_data(self, cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
|
||||
- """
|
||||
- Load stock data from cache
|
||||
-
|
||||
- Args:
|
||||
- cache_key: Cache key
|
||||
-
|
||||
- Returns:
|
||||
- Stock data or None if not found
|
||||
- """
|
||||
- try:
|
||||
- # Load metadata
|
||||
- metadata = self._load_metadata(cache_key)
|
||||
- if not metadata:
|
||||
- print(f"⚠️ Cache metadata not found: {cache_key}")
|
||||
- return None
|
||||
-
|
||||
- # Get file path
|
||||
- cache_path = Path(metadata["file_path"])
|
||||
- if not cache_path.exists():
|
||||
- print(f"⚠️ Cache file not found: {cache_path}")
|
||||
- return None
|
||||
-
|
||||
- # Load data based on type
|
||||
- if metadata["data_type"] == "dataframe":
|
||||
- data = pd.read_pickle(cache_path)
|
||||
+ if metadata['file_format'] == 'csv':
|
||||
+ return pd.read_csv(cache_path, index_col=0)
|
||||
else:
|
||||
with open(cache_path, 'r', encoding='utf-8') as f:
|
||||
- json_data = json.load(f)
|
||||
- data = json_data["data"]
|
||||
-
|
||||
- print(f"📖 Stock data loaded from cache: {metadata['symbol']} -> {cache_key}")
|
||||
- return data
|
||||
-
|
||||
+ return f.read()
|
||||
except Exception as e:
|
||||
- print(f"❌ Failed to load stock data from cache: {e}")
|
||||
- return None
|
||||
-
|
||||
- def find_cached_stock_data(self, symbol: str, start_date: str, end_date: str,
|
||||
- data_source: str = "unknown") -> Optional[str]:
|
||||
- """
|
||||
- Find cached stock data
|
||||
+ print(f"⚠️ 加载缓存数据失败: {e}")
|
||||
+ return None
|
||||
+
|
||||
+ def find_cached_stock_data(self, symbol: str, start_date: str = None,
|
||||
+ end_date: str = None, data_source: str = None,
|
||||
+ max_age_hours: int = None) -> Optional[str]:
|
||||
+ """
|
||||
+ 查找匹配的缓存数据 - 支持智能市场分类查找
|
||||
|
||||
Args:
|
||||
- symbol: Stock symbol
|
||||
- start_date: Start date
|
||||
- end_date: End date
|
||||
- data_source: Data source name
|
||||
+ symbol: 股票代码
|
||||
+ start_date: 开始日期
|
||||
+ end_date: 结束日期
|
||||
+ data_source: 数据源
|
||||
+ max_age_hours: 最大缓存时间(小时),None时使用智能配置
|
||||
|
||||
Returns:
|
||||
- Cache key if found, None otherwise
|
||||
- """
|
||||
- # Generate expected cache key
|
||||
- cache_key = self._generate_cache_key(
|
||||
- "stock_data", symbol,
|
||||
- start_date=start_date,
|
||||
- end_date=end_date,
|
||||
- data_source=data_source
|
||||
- )
|
||||
-
|
||||
- # Check if metadata exists
|
||||
+ cache_key: 如果找到有效缓存则返回缓存键,否则返回None
|
||||
+ """
|
||||
+ market_type = self._determine_market_type(symbol)
|
||||
+
|
||||
+ # 如果没有指定TTL,使用智能配置
|
||||
+ if max_age_hours is None:
|
||||
+ cache_type = f"{market_type}_stock_data"
|
||||
+ max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
|
||||
+
|
||||
+ # 生成查找键
|
||||
+ search_key = self._generate_cache_key("stock_data", symbol,
|
||||
+ start_date=start_date,
|
||||
+ end_date=end_date,
|
||||
+ source=data_source,
|
||||
+ market=market_type)
|
||||
+
|
||||
+ # 检查精确匹配
|
||||
+ if self.is_cache_valid(search_key, max_age_hours, symbol, 'stock_data'):
|
||||
+ desc = self.cache_config.get(f"{market_type}_stock_data", {}).get('description', '数据')
|
||||
+ print(f"🎯 找到精确匹配的{desc}: {symbol} -> {search_key}")
|
||||
+ return search_key
|
||||
+
|
||||
+ # 如果没有精确匹配,查找部分匹配(相同股票代码的其他缓存)
|
||||
+ for metadata_file in self.metadata_dir.glob(f"*_meta.json"):
|
||||
+ try:
|
||||
+ with open(metadata_file, 'r', encoding='utf-8') as f:
|
||||
+ metadata = json.load(f)
|
||||
+
|
||||
+ if (metadata.get('symbol') == symbol and
|
||||
+ metadata.get('data_type') == 'stock_data' and
|
||||
+ metadata.get('market_type') == market_type and
|
||||
+ (data_source is None or metadata.get('data_source') == data_source)):
|
||||
+
|
||||
+ cache_key = metadata_file.stem.replace('_meta', '')
|
||||
+ if self.is_cache_valid(cache_key, max_age_hours, symbol, 'stock_data'):
|
||||
+ desc = self.cache_config.get(f"{market_type}_stock_data", {}).get('description', '数据')
|
||||
+ print(f"📋 找到部分匹配的{desc}: {symbol} -> {cache_key}")
|
||||
+ return cache_key
|
||||
+ except Exception:
|
||||
+ continue
|
||||
+
|
||||
+ desc = self.cache_config.get(f"{market_type}_stock_data", {}).get('description', '数据')
|
||||
+ print(f"❌ 未找到有效的{desc}缓存: {symbol}")
|
||||
+ return None
|
||||
+
|
||||
+ def save_news_data(self, symbol: str, news_data: str,
|
||||
+ start_date: str = None, end_date: str = None,
|
||||
+ data_source: str = "unknown") -> str:
|
||||
+ """保存新闻数据到缓存"""
|
||||
+ cache_key = self._generate_cache_key("news", symbol,
|
||||
+ start_date=start_date,
|
||||
+ end_date=end_date,
|
||||
+ source=data_source)
|
||||
+
|
||||
+ cache_path = self._get_cache_path("news", cache_key, "txt")
|
||||
+ with open(cache_path, 'w', encoding='utf-8') as f:
|
||||
+ f.write(news_data)
|
||||
+
|
||||
+ metadata = {
|
||||
+ 'symbol': symbol,
|
||||
+ 'data_type': 'news',
|
||||
+ 'start_date': start_date,
|
||||
+ 'end_date': end_date,
|
||||
+ 'data_source': data_source,
|
||||
+ 'file_path': str(cache_path),
|
||||
+ 'file_format': 'txt'
|
||||
+ }
|
||||
+ self._save_metadata(cache_key, metadata)
|
||||
+
|
||||
+ print(f"📰 新闻数据已缓存: {symbol} ({data_source}) -> {cache_key}")
|
||||
+ return cache_key
|
||||
+
|
||||
+ def save_fundamentals_data(self, symbol: str, fundamentals_data: str,
|
||||
+ data_source: str = "unknown") -> str:
|
||||
+ """保存基本面数据到缓存"""
|
||||
+ market_type = self._determine_market_type(symbol)
|
||||
+ cache_key = self._generate_cache_key("fundamentals", symbol,
|
||||
+ source=data_source,
|
||||
+ market=market_type,
|
||||
+ date=datetime.now().strftime("%Y-%m-%d"))
|
||||
+
|
||||
+ cache_path = self._get_cache_path("fundamentals", cache_key, "txt", symbol)
|
||||
+ with open(cache_path, 'w', encoding='utf-8') as f:
|
||||
+ f.write(fundamentals_data)
|
||||
+
|
||||
+ metadata = {
|
||||
+ 'symbol': symbol,
|
||||
+ 'data_type': 'fundamentals',
|
||||
+ 'data_source': data_source,
|
||||
+ 'market_type': market_type,
|
||||
+ 'file_path': str(cache_path),
|
||||
+ 'file_format': 'txt'
|
||||
+ }
|
||||
+ self._save_metadata(cache_key, metadata)
|
||||
+
|
||||
+ desc = self.cache_config.get(f"{market_type}_fundamentals", {}).get('description', '基本面数据')
|
||||
+ print(f"💼 {desc}已缓存: {symbol} ({data_source}) -> {cache_key}")
|
||||
+ return cache_key
|
||||
+
|
||||
+ def load_fundamentals_data(self, cache_key: str) -> Optional[str]:
|
||||
+ """从缓存加载基本面数据"""
|
||||
metadata = self._load_metadata(cache_key)
|
||||
- if metadata:
|
||||
- cache_path = Path(metadata["file_path"])
|
||||
- if cache_path.exists():
|
||||
- return cache_key
|
||||
-
|
||||
+ if not metadata:
|
||||
+ return None
|
||||
+
|
||||
+ cache_path = Path(metadata['file_path'])
|
||||
+ if not cache_path.exists():
|
||||
+ return None
|
||||
+
|
||||
+ try:
|
||||
+ with open(cache_path, 'r', encoding='utf-8') as f:
|
||||
+ return f.read()
|
||||
+ except Exception as e:
|
||||
+ print(f"⚠️ 加载基本面缓存数据失败: {e}")
|
||||
+ return None
|
||||
+
|
||||
+ def find_cached_fundamentals_data(self, symbol: str, data_source: str = None,
|
||||
+ max_age_hours: int = None) -> Optional[str]:
|
||||
+ """
|
||||
+ 查找匹配的基本面缓存数据
|
||||
+
|
||||
+ Args:
|
||||
+ symbol: 股票代码
|
||||
+ data_source: 数据源(如 "openai", "finnhub")
|
||||
+ max_age_hours: 最大缓存时间(小时),None时使用智能配置
|
||||
+
|
||||
+ Returns:
|
||||
+ cache_key: 如果找到有效缓存则返回缓存键,否则返回None
|
||||
+ """
|
||||
+ market_type = self._determine_market_type(symbol)
|
||||
+
|
||||
+ # 如果没有指定TTL,使用智能配置
|
||||
+ if max_age_hours is None:
|
||||
+ cache_type = f"{market_type}_fundamentals"
|
||||
+ max_age_hours = self.cache_config.get(cache_type, {}).get('ttl_hours', 24)
|
||||
+
|
||||
+ # 查找匹配的缓存
|
||||
+ for metadata_file in self.metadata_dir.glob(f"*_meta.json"):
|
||||
+ try:
|
||||
+ with open(metadata_file, 'r', encoding='utf-8') as f:
|
||||
+ metadata = json.load(f)
|
||||
+
|
||||
+ if (metadata.get('symbol') == symbol and
|
||||
+ metadata.get('data_type') == 'fundamentals' and
|
||||
+ metadata.get('market_type') == market_type and
|
||||
+ (data_source is None or metadata.get('data_source') == data_source)):
|
||||
+
|
||||
+ cache_key = metadata_file.stem.replace('_meta', '')
|
||||
+ if self.is_cache_valid(cache_key, max_age_hours, symbol, 'fundamentals'):
|
||||
+ desc = self.cache_config.get(f"{market_type}_fundamentals", {}).get('description', '基本面数据')
|
||||
+ print(f"🎯 找到匹配的{desc}缓存: {symbol} ({data_source}) -> {cache_key}")
|
||||
+ return cache_key
|
||||
+ except Exception:
|
||||
+ continue
|
||||
+
|
||||
+ desc = self.cache_config.get(f"{market_type}_fundamentals", {}).get('description', '基本面数据')
|
||||
+ print(f"❌ 未找到有效的{desc}缓存: {symbol} ({data_source})")
|
||||
return None
|
||||
-
|
||||
- def is_cache_valid(self, cache_key: str, symbol: str = None, data_type: str = "stock_data") -> bool:
|
||||
- """
|
||||
- Check if cache is still valid
|
||||
-
|
||||
- Args:
|
||||
- cache_key: Cache key
|
||||
- symbol: Stock symbol (for market type determination)
|
||||
- data_type: Data type
|
||||
-
|
||||
- Returns:
|
||||
- True if cache is valid, False otherwise
|
||||
- """
|
||||
- try:
|
||||
- # Load metadata
|
||||
- metadata = self._load_metadata(cache_key)
|
||||
- if not metadata:
|
||||
- return False
|
||||
-
|
||||
- # Check if file exists
|
||||
- cache_path = Path(metadata["file_path"])
|
||||
- if not cache_path.exists():
|
||||
- return False
|
||||
-
|
||||
- # Determine market type and get TTL
|
||||
- if symbol:
|
||||
- market_type = self._determine_market_type(symbol)
|
||||
- else:
|
||||
- market_type = metadata.get("market_type", "us")
|
||||
-
|
||||
- cache_type_key = f"{market_type}_{data_type}"
|
||||
- if cache_type_key not in self.cache_config:
|
||||
- cache_type_key = "us_stock_data" # Default fallback
|
||||
-
|
||||
- ttl_hours = self.cache_config[cache_type_key]["ttl_hours"]
|
||||
-
|
||||
- # Check if cache has expired
|
||||
- cache_time = datetime.fromisoformat(metadata["cache_time"])
|
||||
- expiry_time = cache_time + timedelta(hours=ttl_hours)
|
||||
-
|
||||
- is_valid = datetime.now() < expiry_time
|
||||
- if not is_valid:
|
||||
- print(f"⏰ Cache expired: {cache_key} (cached at {cache_time}, TTL: {ttl_hours}h)")
|
||||
-
|
||||
- return is_valid
|
||||
-
|
||||
- except Exception as e:
|
||||
- print(f"❌ Failed to check cache validity: {e}")
|
||||
- return False
|
||||
-
|
||||
+
|
||||
+ def clear_old_cache(self, max_age_days: int = 7):
|
||||
+ """清理过期缓存"""
|
||||
+ cutoff_time = datetime.now() - timedelta(days=max_age_days)
|
||||
+ cleared_count = 0
|
||||
+
|
||||
+ for metadata_file in self.metadata_dir.glob("*_meta.json"):
|
||||
+ try:
|
||||
+ with open(metadata_file, 'r', encoding='utf-8') as f:
|
||||
+ metadata = json.load(f)
|
||||
+
|
||||
+ cached_at = datetime.fromisoformat(metadata['cached_at'])
|
||||
+ if cached_at < cutoff_time:
|
||||
+ # 删除数据文件
|
||||
+ data_file = Path(metadata['file_path'])
|
||||
+ if data_file.exists():
|
||||
+ data_file.unlink()
|
||||
+
|
||||
+ # 删除元数据文件
|
||||
+ metadata_file.unlink()
|
||||
+ cleared_count += 1
|
||||
+
|
||||
+ except Exception as e:
|
||||
+ print(f"⚠️ 清理缓存时出错: {e}")
|
||||
+
|
||||
+ print(f"🧹 已清理 {cleared_count} 个过期缓存文件")
|
||||
+
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
- """
|
||||
- Get cache statistics
|
||||
-
|
||||
- Returns:
|
||||
- Dictionary containing cache statistics
|
||||
- """
|
||||
- try:
|
||||
- stats = {
|
||||
- "cache_dir": str(self.cache_dir),
|
||||
- "total_files": 0,
|
||||
- "total_size_mb": 0,
|
||||
- "stock_data_count": 0,
|
||||
- "news_count": 0,
|
||||
- "fundamentals_count": 0,
|
||||
- "us_data_count": 0,
|
||||
- "china_data_count": 0
|
||||
- }
|
||||
-
|
||||
- # Count files in each directory
|
||||
- for dir_path in [self.us_stock_dir, self.china_stock_dir, self.us_news_dir,
|
||||
- self.china_news_dir, self.us_fundamentals_dir,
|
||||
- self.china_fundamentals_dir, self.metadata_dir]:
|
||||
- if dir_path.exists():
|
||||
- files = list(dir_path.glob("*"))
|
||||
- stats["total_files"] += len(files)
|
||||
-
|
||||
- # Calculate total size
|
||||
- for file_path in files:
|
||||
- if file_path.is_file():
|
||||
- stats["total_size_mb"] += file_path.stat().st_size / (1024 * 1024)
|
||||
-
|
||||
- # Count by data type
|
||||
- if self.us_stock_dir.exists():
|
||||
- stats["stock_data_count"] += len(list(self.us_stock_dir.glob("*")))
|
||||
- stats["us_data_count"] += len(list(self.us_stock_dir.glob("*")))
|
||||
-
|
||||
- if self.china_stock_dir.exists():
|
||||
- stats["stock_data_count"] += len(list(self.china_stock_dir.glob("*")))
|
||||
- stats["china_data_count"] += len(list(self.china_stock_dir.glob("*")))
|
||||
-
|
||||
- if self.us_news_dir.exists():
|
||||
- stats["news_count"] += len(list(self.us_news_dir.glob("*")))
|
||||
- stats["us_data_count"] += len(list(self.us_news_dir.glob("*")))
|
||||
-
|
||||
- if self.china_news_dir.exists():
|
||||
- stats["news_count"] += len(list(self.china_news_dir.glob("*")))
|
||||
- stats["china_data_count"] += len(list(self.china_news_dir.glob("*")))
|
||||
-
|
||||
- if self.us_fundamentals_dir.exists():
|
||||
- stats["fundamentals_count"] += len(list(self.us_fundamentals_dir.glob("*")))
|
||||
- stats["us_data_count"] += len(list(self.us_fundamentals_dir.glob("*")))
|
||||
-
|
||||
- if self.china_fundamentals_dir.exists():
|
||||
- stats["fundamentals_count"] += len(list(self.china_fundamentals_dir.glob("*")))
|
||||
- stats["china_data_count"] += len(list(self.china_fundamentals_dir.glob("*")))
|
||||
-
|
||||
- # Round size to 2 decimal places
|
||||
- stats["total_size_mb"] = round(stats["total_size_mb"], 2)
|
||||
-
|
||||
- return stats
|
||||
-
|
||||
- except Exception as e:
|
||||
- print(f"❌ Failed to get cache statistics: {e}")
|
||||
- return {"error": str(e)}
|
||||
-
|
||||
- def cleanup_expired_cache(self):
|
||||
- """Clean up expired cache files"""
|
||||
- try:
|
||||
- cleaned_count = 0
|
||||
-
|
||||
- # Check all metadata files
|
||||
- if self.metadata_dir.exists():
|
||||
- for metadata_file in self.metadata_dir.glob("*_meta.json"):
|
||||
- try:
|
||||
- cache_key = metadata_file.stem.replace("_meta", "")
|
||||
-
|
||||
- # Load metadata
|
||||
- with open(metadata_file, 'r', encoding='utf-8') as f:
|
||||
- metadata = json.load(f)
|
||||
-
|
||||
- # Check if cache is expired
|
||||
- if not self.is_cache_valid(cache_key, metadata.get("symbol"), "stock_data"):
|
||||
- # Remove cache file
|
||||
- cache_path = Path(metadata["file_path"])
|
||||
- if cache_path.exists():
|
||||
- cache_path.unlink()
|
||||
-
|
||||
- # Remove metadata file
|
||||
- metadata_file.unlink()
|
||||
- cleaned_count += 1
|
||||
- print(f"🗑️ Cleaned expired cache: {cache_key}")
|
||||
-
|
||||
- except Exception as e:
|
||||
- print(f"⚠️ Failed to clean cache file {metadata_file}: {e}")
|
||||
-
|
||||
- print(f"✅ Cache cleanup completed, removed {cleaned_count} expired files")
|
||||
-
|
||||
- except Exception as e:
|
||||
- print(f"❌ Failed to cleanup cache: {e}")
|
||||
-
|
||||
-
|
||||
-# Global cache instance
|
||||
-_global_cache = None
|
||||
-
|
||||
-def get_cache(cache_dir: str = None) -> StockDataCache:
|
||||
- """
|
||||
- Get global cache instance
|
||||
-
|
||||
- Args:
|
||||
- cache_dir: Cache directory path
|
||||
-
|
||||
- Returns:
|
||||
- StockDataCache instance
|
||||
- """
|
||||
- global _global_cache
|
||||
- if _global_cache is None:
|
||||
- _global_cache = StockDataCache(cache_dir)
|
||||
- return _global_cache
|
||||
-
|
||||
-
|
||||
-# Convenience functions
|
||||
-def save_stock_data(symbol: str, data: Union[str, pd.DataFrame],
|
||||
- start_date: str, end_date: str, data_source: str = "unknown") -> str:
|
||||
- """Save stock data to cache (convenience function)"""
|
||||
- cache = get_cache()
|
||||
- return cache.save_stock_data(symbol, data, start_date, end_date, data_source)
|
||||
-
|
||||
-
|
||||
-def load_stock_data(cache_key: str) -> Optional[Union[str, pd.DataFrame]]:
|
||||
- """Load stock data from cache (convenience function)"""
|
||||
- cache = get_cache()
|
||||
- return cache.load_stock_data(cache_key)
|
||||
-
|
||||
-
|
||||
-def find_cached_stock_data(symbol: str, start_date: str, end_date: str,
|
||||
- data_source: str = "unknown") -> Optional[str]:
|
||||
- """Find cached stock data (convenience function)"""
|
||||
- cache = get_cache()
|
||||
- return cache.find_cached_stock_data(symbol, start_date, end_date, data_source)
|
||||
-
|
||||
-
|
||||
-if __name__ == "__main__":
|
||||
- # Test the cache manager
|
||||
- print("🧪 Testing Stock Data Cache Manager...")
|
||||
-
|
||||
- # Initialize cache
|
||||
- cache = StockDataCache()
|
||||
-
|
||||
- # Test data
|
||||
- test_data = "Sample stock data for AAPL"
|
||||
- cache_key = cache.save_stock_data("AAPL", test_data, "2024-01-01", "2024-01-31", "test")
|
||||
-
|
||||
- # Load data
|
||||
- loaded_data = cache.load_stock_data(cache_key)
|
||||
- print(f"Loaded data: {loaded_data}")
|
||||
-
|
||||
- # Check cache validity
|
||||
- is_valid = cache.is_cache_valid(cache_key, "AAPL")
|
||||
- print(f"Cache valid: {is_valid}")
|
||||
-
|
||||
- # Get statistics
|
||||
- stats = cache.get_cache_stats()
|
||||
- print(f"Cache stats: {stats}")
|
||||
-
|
||||
- print("✅ Cache manager test completed!")
|
||||
+ """获取缓存统计信息"""
|
||||
+ stats = {
|
||||
+ 'total_files': 0,
|
||||
+ 'stock_data_count': 0,
|
||||
+ 'news_count': 0,
|
||||
+ 'fundamentals_count': 0,
|
||||
+ 'total_size_mb': 0
|
||||
+ }
|
||||
+
|
||||
+ for metadata_file in self.metadata_dir.glob("*_meta.json"):
|
||||
+ try:
|
||||
+ with open(metadata_file, 'r', encoding='utf-8') as f:
|
||||
+ metadata = json.load(f)
|
||||
+
|
||||
+ data_type = metadata.get('data_type', 'unknown')
|
||||
+ if data_type == 'stock_data':
|
||||
+ stats['stock_data_count'] += 1
|
||||
+ elif data_type == 'news':
|
||||
+ stats['news_count'] += 1
|
||||
+ elif data_type == 'fundamentals':
|
||||
+ stats['fundamentals_count'] += 1
|
||||
+
|
||||
+ # 计算文件大小
|
||||
+ data_file = Path(metadata['file_path'])
|
||||
+ if data_file.exists():
|
||||
+ stats['total_size_mb'] += data_file.stat().st_size / (1024 * 1024)
|
||||
+
|
||||
+ stats['total_files'] += 1
|
||||
+
|
||||
+ except Exception:
|
||||
+ continue
|
||||
+
|
||||
+ stats['total_size_mb'] = round(stats['total_size_mb'], 2)
|
||||
+ return stats
|
||||
+
|
||||
+
|
||||
+# 全局缓存实例
|
||||
+_cache_instance = None
|
||||
+
|
||||
+def get_cache() -> StockDataCache:
|
||||
+ """获取全局缓存实例"""
|
||||
+ global _cache_instance
|
||||
+ if _cache_instance is None:
|
||||
+ _cache_instance = StockDataCache()
|
||||
+ return _cache_instance
|
||||
|
|
@ -0,0 +1,331 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
中国财经数据聚合工具
|
||||
由于微博API申请困难且功能受限,采用多源数据聚合的方式
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Optional
|
||||
import re
|
||||
from bs4 import BeautifulSoup
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class ChineseFinanceDataAggregator:
|
||||
"""中国财经数据聚合器"""
|
||||
|
||||
def __init__(self):
|
||||
self.headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
|
||||
}
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update(self.headers)
|
||||
|
||||
def get_stock_sentiment_summary(self, ticker: str, days: int = 7) -> Dict:
|
||||
"""
|
||||
获取股票情绪分析汇总
|
||||
整合多个可获取的中国财经数据源
|
||||
"""
|
||||
try:
|
||||
# 1. 获取财经新闻情绪
|
||||
news_sentiment = self._get_finance_news_sentiment(ticker, days)
|
||||
|
||||
# 2. 获取股吧讨论热度 (如果可以获取)
|
||||
forum_sentiment = self._get_stock_forum_sentiment(ticker, days)
|
||||
|
||||
# 3. 获取财经媒体报道
|
||||
media_sentiment = self._get_media_coverage_sentiment(ticker, days)
|
||||
|
||||
# 4. 综合分析
|
||||
overall_sentiment = self._calculate_overall_sentiment(
|
||||
news_sentiment, forum_sentiment, media_sentiment
|
||||
)
|
||||
|
||||
return {
|
||||
'ticker': ticker,
|
||||
'analysis_period': f'{days} days',
|
||||
'overall_sentiment': overall_sentiment,
|
||||
'news_sentiment': news_sentiment,
|
||||
'forum_sentiment': forum_sentiment,
|
||||
'media_sentiment': media_sentiment,
|
||||
'summary': self._generate_sentiment_summary(overall_sentiment),
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
'ticker': ticker,
|
||||
'error': f'数据获取失败: {str(e)}',
|
||||
'fallback_message': '由于中国社交媒体API限制,建议使用财经新闻和基本面分析作为主要参考',
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
def _get_finance_news_sentiment(self, ticker: str, days: int) -> Dict:
|
||||
"""获取财经新闻情绪分析"""
|
||||
try:
|
||||
# 搜索相关新闻标题和内容
|
||||
company_name = self._get_company_chinese_name(ticker)
|
||||
search_terms = [ticker, company_name] if company_name else [ticker]
|
||||
|
||||
news_items = []
|
||||
for term in search_terms:
|
||||
# 这里可以集成多个新闻源
|
||||
items = self._search_finance_news(term, days)
|
||||
news_items.extend(items)
|
||||
|
||||
# 简单的情绪分析
|
||||
positive_count = 0
|
||||
negative_count = 0
|
||||
neutral_count = 0
|
||||
|
||||
for item in news_items:
|
||||
sentiment = self._analyze_text_sentiment(item.get('title', '') + ' ' + item.get('content', ''))
|
||||
if sentiment > 0.1:
|
||||
positive_count += 1
|
||||
elif sentiment < -0.1:
|
||||
negative_count += 1
|
||||
else:
|
||||
neutral_count += 1
|
||||
|
||||
total = len(news_items)
|
||||
if total == 0:
|
||||
return {'sentiment_score': 0, 'confidence': 0, 'news_count': 0}
|
||||
|
||||
sentiment_score = (positive_count - negative_count) / total
|
||||
|
||||
return {
|
||||
'sentiment_score': sentiment_score,
|
||||
'positive_ratio': positive_count / total,
|
||||
'negative_ratio': negative_count / total,
|
||||
'neutral_ratio': neutral_count / total,
|
||||
'news_count': total,
|
||||
'confidence': min(total / 10, 1.0) # 新闻数量越多,置信度越高
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {'error': str(e), 'sentiment_score': 0, 'confidence': 0}
|
||||
|
||||
def _get_stock_forum_sentiment(self, ticker: str, days: int) -> Dict:
|
||||
"""获取股票论坛讨论情绪 (模拟数据,实际需要爬虫)"""
|
||||
# 由于东方财富股吧等平台的反爬虫机制,这里返回模拟数据
|
||||
# 实际实现需要更复杂的爬虫技术
|
||||
|
||||
return {
|
||||
'sentiment_score': 0,
|
||||
'discussion_count': 0,
|
||||
'hot_topics': [],
|
||||
'note': '股票论坛数据获取受限,建议关注官方财经新闻',
|
||||
'confidence': 0
|
||||
}
|
||||
|
||||
def _get_media_coverage_sentiment(self, ticker: str, days: int) -> Dict:
|
||||
"""获取媒体报道情绪"""
|
||||
try:
|
||||
# 可以集成RSS源或公开的财经API
|
||||
coverage_items = self._get_media_coverage(ticker, days)
|
||||
|
||||
if not coverage_items:
|
||||
return {'sentiment_score': 0, 'coverage_count': 0, 'confidence': 0}
|
||||
|
||||
# 分析媒体报道的情绪倾向
|
||||
sentiment_scores = []
|
||||
for item in coverage_items:
|
||||
score = self._analyze_text_sentiment(item.get('title', '') + ' ' + item.get('summary', ''))
|
||||
sentiment_scores.append(score)
|
||||
|
||||
avg_sentiment = sum(sentiment_scores) / len(sentiment_scores) if sentiment_scores else 0
|
||||
|
||||
return {
|
||||
'sentiment_score': avg_sentiment,
|
||||
'coverage_count': len(coverage_items),
|
||||
'confidence': min(len(coverage_items) / 5, 1.0)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {'error': str(e), 'sentiment_score': 0, 'confidence': 0}
|
||||
|
||||
def _search_finance_news(self, search_term: str, days: int) -> List[Dict]:
|
||||
"""搜索财经新闻 (示例实现)"""
|
||||
# 这里可以集成多个新闻源的API或RSS
|
||||
# 例如:财联社、新浪财经、东方财富等
|
||||
|
||||
# 模拟返回数据结构
|
||||
return [
|
||||
{
|
||||
'title': f'{search_term}相关财经新闻标题',
|
||||
'content': '新闻内容摘要...',
|
||||
'source': '财联社',
|
||||
'publish_time': datetime.now().isoformat(),
|
||||
'url': 'https://example.com/news/1'
|
||||
}
|
||||
]
|
||||
|
||||
def _get_media_coverage(self, ticker: str, days: int) -> List[Dict]:
|
||||
"""获取媒体报道 (示例实现)"""
|
||||
# 可以集成Google News API或其他新闻聚合服务
|
||||
return []
|
||||
|
||||
def _analyze_text_sentiment(self, text: str) -> float:
|
||||
"""简单的中文文本情绪分析"""
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
# 简单的关键词情绪分析
|
||||
positive_words = ['上涨', '增长', '利好', '看好', '买入', '推荐', '强势', '突破', '创新高']
|
||||
negative_words = ['下跌', '下降', '利空', '看空', '卖出', '风险', '跌破', '创新低', '亏损']
|
||||
|
||||
positive_count = sum(1 for word in positive_words if word in text)
|
||||
negative_count = sum(1 for word in negative_words if word in text)
|
||||
|
||||
if positive_count + negative_count == 0:
|
||||
return 0
|
||||
|
||||
return (positive_count - negative_count) / (positive_count + negative_count)
|
||||
|
||||
def _get_company_chinese_name(self, ticker: str) -> Optional[str]:
|
||||
"""获取公司中文名称"""
|
||||
# 简单的映射表,实际可以从数据库或API获取
|
||||
name_mapping = {
|
||||
'AAPL': '苹果',
|
||||
'TSLA': '特斯拉',
|
||||
'NVDA': '英伟达',
|
||||
'MSFT': '微软',
|
||||
'GOOGL': '谷歌',
|
||||
'AMZN': '亚马逊'
|
||||
}
|
||||
return name_mapping.get(ticker.upper())
|
||||
|
||||
def _calculate_overall_sentiment(self, news_sentiment: Dict, forum_sentiment: Dict, media_sentiment: Dict) -> Dict:
|
||||
"""计算综合情绪分析"""
|
||||
# 根据各数据源的置信度加权计算
|
||||
news_weight = news_sentiment.get('confidence', 0)
|
||||
forum_weight = forum_sentiment.get('confidence', 0)
|
||||
media_weight = media_sentiment.get('confidence', 0)
|
||||
|
||||
total_weight = news_weight + forum_weight + media_weight
|
||||
|
||||
if total_weight == 0:
|
||||
return {'sentiment_score': 0, 'confidence': 0, 'level': 'neutral'}
|
||||
|
||||
weighted_sentiment = (
|
||||
news_sentiment.get('sentiment_score', 0) * news_weight +
|
||||
forum_sentiment.get('sentiment_score', 0) * forum_weight +
|
||||
media_sentiment.get('sentiment_score', 0) * media_weight
|
||||
) / total_weight
|
||||
|
||||
# 确定情绪等级
|
||||
if weighted_sentiment > 0.3:
|
||||
level = 'very_positive'
|
||||
elif weighted_sentiment > 0.1:
|
||||
level = 'positive'
|
||||
elif weighted_sentiment > -0.1:
|
||||
level = 'neutral'
|
||||
elif weighted_sentiment > -0.3:
|
||||
level = 'negative'
|
||||
else:
|
||||
level = 'very_negative'
|
||||
|
||||
return {
|
||||
'sentiment_score': weighted_sentiment,
|
||||
'confidence': total_weight / 3, # 平均置信度
|
||||
'level': level
|
||||
}
|
||||
|
||||
def _generate_sentiment_summary(self, overall_sentiment: Dict) -> str:
|
||||
"""生成情绪分析摘要"""
|
||||
level = overall_sentiment.get('level', 'neutral')
|
||||
score = overall_sentiment.get('sentiment_score', 0)
|
||||
confidence = overall_sentiment.get('confidence', 0)
|
||||
|
||||
level_descriptions = {
|
||||
'very_positive': '非常积极',
|
||||
'positive': '积极',
|
||||
'neutral': '中性',
|
||||
'negative': '消极',
|
||||
'very_negative': '非常消极'
|
||||
}
|
||||
|
||||
description = level_descriptions.get(level, '中性')
|
||||
confidence_level = '高' if confidence > 0.7 else '中' if confidence > 0.3 else '低'
|
||||
|
||||
return f"市场情绪: {description} (评分: {score:.2f}, 置信度: {confidence_level})"
|
||||
|
||||
|
||||
def get_chinese_social_sentiment(ticker: str, curr_date: str) -> str:
|
||||
"""
|
||||
获取中国社交媒体情绪分析的主要接口函数
|
||||
"""
|
||||
aggregator = ChineseFinanceDataAggregator()
|
||||
|
||||
try:
|
||||
# 获取情绪分析数据
|
||||
sentiment_data = aggregator.get_stock_sentiment_summary(ticker, days=7)
|
||||
|
||||
# 格式化输出
|
||||
if 'error' in sentiment_data:
|
||||
return f"""
|
||||
中国市场情绪分析报告 - {ticker}
|
||||
分析日期: {curr_date}
|
||||
|
||||
⚠️ 数据获取限制说明:
|
||||
{sentiment_data.get('fallback_message', '数据获取遇到技术限制')}
|
||||
|
||||
建议:
|
||||
1. 重点关注财经新闻和基本面分析
|
||||
2. 参考官方财报和业绩指导
|
||||
3. 关注行业政策和监管动态
|
||||
4. 考虑国际市场情绪对中概股的影响
|
||||
|
||||
注: 由于中国社交媒体平台API限制,当前主要依赖公开财经数据源进行分析。
|
||||
"""
|
||||
|
||||
overall = sentiment_data.get('overall_sentiment', {})
|
||||
news = sentiment_data.get('news_sentiment', {})
|
||||
|
||||
return f"""
|
||||
中国市场情绪分析报告 - {ticker}
|
||||
分析日期: {curr_date}
|
||||
分析周期: {sentiment_data.get('analysis_period', '7天')}
|
||||
|
||||
📊 综合情绪评估:
|
||||
{sentiment_data.get('summary', '数据不足')}
|
||||
|
||||
📰 财经新闻情绪:
|
||||
- 情绪评分: {news.get('sentiment_score', 0):.2f}
|
||||
- 正面新闻比例: {news.get('positive_ratio', 0):.1%}
|
||||
- 负面新闻比例: {news.get('negative_ratio', 0):.1%}
|
||||
- 新闻数量: {news.get('news_count', 0)}条
|
||||
|
||||
💡 投资建议:
|
||||
基于当前可获取的中国市场数据,建议投资者:
|
||||
1. 密切关注官方财经媒体报道
|
||||
2. 重视基本面分析和财务数据
|
||||
3. 考虑政策环境对股价的影响
|
||||
4. 关注国际市场动态
|
||||
|
||||
⚠️ 数据说明:
|
||||
由于中国社交媒体平台API获取限制,本分析主要基于公开财经新闻数据。
|
||||
建议结合其他分析维度进行综合判断。
|
||||
|
||||
生成时间: {sentiment_data.get('timestamp', datetime.now().isoformat())}
|
||||
"""
|
||||
|
||||
except Exception as e:
|
||||
return f"""
|
||||
中国市场情绪分析 - {ticker}
|
||||
分析日期: {curr_date}
|
||||
|
||||
❌ 分析失败: {str(e)}
|
||||
|
||||
💡 替代建议:
|
||||
1. 查看财经新闻网站的相关报道
|
||||
2. 关注雪球、东方财富等投资社区讨论
|
||||
3. 参考专业机构的研究报告
|
||||
4. 重点分析基本面和技术面数据
|
||||
|
||||
注: 中国社交媒体数据获取存在技术限制,建议以基本面分析为主。
|
||||
"""
|
||||
|
|
@ -0,0 +1,528 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
MongoDB + Redis 数据库缓存管理器
|
||||
提供高性能的股票数据缓存和持久化存储
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
import pandas as pd
|
||||
|
||||
# MongoDB
|
||||
try:
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
|
||||
MONGODB_AVAILABLE = True
|
||||
except ImportError:
|
||||
MONGODB_AVAILABLE = False
|
||||
print("⚠️ pymongo 未安装,MongoDB功能不可用")
|
||||
|
||||
# Redis
|
||||
try:
|
||||
import redis
|
||||
from redis.exceptions import ConnectionError as RedisConnectionError
|
||||
REDIS_AVAILABLE = True
|
||||
except ImportError:
|
||||
REDIS_AVAILABLE = False
|
||||
print("⚠️ redis 未安装,Redis功能不可用")
|
||||
|
||||
|
||||
class DatabaseCacheManager:
|
||||
"""MongoDB + Redis 数据库缓存管理器"""
|
||||
|
||||
def __init__(self,
|
||||
mongodb_url: Optional[str] = None,
|
||||
redis_url: Optional[str] = None,
|
||||
mongodb_db: str = "tradingagents",
|
||||
redis_db: int = 0):
|
||||
"""
|
||||
初始化数据库缓存管理器
|
||||
|
||||
Args:
|
||||
mongodb_url: MongoDB连接URL,默认使用配置文件端口
|
||||
redis_url: Redis连接URL,默认使用配置文件端口
|
||||
mongodb_db: MongoDB数据库名
|
||||
redis_db: Redis数据库编号
|
||||
"""
|
||||
# 从配置文件获取正确的端口
|
||||
mongodb_port = os.getenv("MONGODB_PORT", "27018")
|
||||
redis_port = os.getenv("REDIS_PORT", "6380")
|
||||
mongodb_password = os.getenv("MONGODB_PASSWORD", "tradingagents123")
|
||||
redis_password = os.getenv("REDIS_PASSWORD", "tradingagents123")
|
||||
|
||||
self.mongodb_url = mongodb_url or os.getenv("MONGODB_URL", f"mongodb://admin:{mongodb_password}@localhost:{mongodb_port}")
|
||||
self.redis_url = redis_url or os.getenv("REDIS_URL", f"redis://:{redis_password}@localhost:{redis_port}")
|
||||
self.mongodb_db_name = mongodb_db
|
||||
self.redis_db = redis_db
|
||||
|
||||
# 初始化连接
|
||||
self.mongodb_client = None
|
||||
self.mongodb_db = None
|
||||
self.redis_client = None
|
||||
|
||||
self._init_mongodb()
|
||||
self._init_redis()
|
||||
|
||||
print(f"🗄️ 数据库缓存管理器初始化完成")
|
||||
print(f" MongoDB: {'✅ 已连接' if self.mongodb_client else '❌ 未连接'}")
|
||||
print(f" Redis: {'✅ 已连接' if self.redis_client else '❌ 未连接'}")
|
||||
|
||||
def _init_mongodb(self):
|
||||
"""初始化MongoDB连接"""
|
||||
if not MONGODB_AVAILABLE:
|
||||
return
|
||||
|
||||
try:
|
||||
self.mongodb_client = MongoClient(
|
||||
self.mongodb_url,
|
||||
serverSelectionTimeoutMS=5000, # 5秒超时
|
||||
connectTimeoutMS=5000
|
||||
)
|
||||
# 测试连接
|
||||
self.mongodb_client.admin.command('ping')
|
||||
self.mongodb_db = self.mongodb_client[self.mongodb_db_name]
|
||||
|
||||
# 创建索引
|
||||
self._create_mongodb_indexes()
|
||||
|
||||
print(f"✅ MongoDB连接成功: {self.mongodb_url}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ MongoDB连接失败: {e}")
|
||||
self.mongodb_client = None
|
||||
self.mongodb_db = None
|
||||
|
||||
def _init_redis(self):
|
||||
"""初始化Redis连接"""
|
||||
if not REDIS_AVAILABLE:
|
||||
return
|
||||
|
||||
try:
|
||||
self.redis_client = redis.from_url(
|
||||
self.redis_url,
|
||||
db=self.redis_db,
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5,
|
||||
decode_responses=True
|
||||
)
|
||||
# 测试连接
|
||||
self.redis_client.ping()
|
||||
|
||||
print(f"✅ Redis连接成功: {self.redis_url}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Redis连接失败: {e}")
|
||||
self.redis_client = None
|
||||
|
||||
def _create_mongodb_indexes(self):
|
||||
"""创建MongoDB索引"""
|
||||
if self.mongodb_db is None:
|
||||
return
|
||||
|
||||
try:
|
||||
# 股票数据集合索引
|
||||
stock_collection = self.mongodb_db.stock_data
|
||||
stock_collection.create_index([
|
||||
("symbol", 1),
|
||||
("data_source", 1),
|
||||
("start_date", 1),
|
||||
("end_date", 1)
|
||||
])
|
||||
stock_collection.create_index([("created_at", 1)])
|
||||
|
||||
# 新闻数据集合索引
|
||||
news_collection = self.mongodb_db.news_data
|
||||
news_collection.create_index([
|
||||
("symbol", 1),
|
||||
("data_source", 1),
|
||||
("date_range", 1)
|
||||
])
|
||||
news_collection.create_index([("created_at", 1)])
|
||||
|
||||
# 基本面数据集合索引
|
||||
fundamentals_collection = self.mongodb_db.fundamentals_data
|
||||
fundamentals_collection.create_index([
|
||||
("symbol", 1),
|
||||
("data_source", 1),
|
||||
("analysis_date", 1)
|
||||
])
|
||||
fundamentals_collection.create_index([("created_at", 1)])
|
||||
|
||||
print("✅ MongoDB索引创建完成")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ MongoDB索引创建失败: {e}")
|
||||
|
||||
def _generate_cache_key(self, data_type: str, symbol: str, **kwargs) -> str:
|
||||
"""生成缓存键"""
|
||||
params_str = f"{data_type}_{symbol}"
|
||||
for key, value in sorted(kwargs.items()):
|
||||
params_str += f"_{key}_{value}"
|
||||
|
||||
cache_key = hashlib.md5(params_str.encode()).hexdigest()[:16]
|
||||
return f"{data_type}:{symbol}:{cache_key}"
|
||||
|
||||
def save_stock_data(self, symbol: str, data: Union[pd.DataFrame, str],
|
||||
start_date: str = None, end_date: str = None,
|
||||
data_source: str = "unknown", market_type: str = None) -> str:
|
||||
"""
|
||||
保存股票数据到MongoDB和Redis
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
data: 股票数据
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
data_source: 数据源
|
||||
market_type: 市场类型 (us/china)
|
||||
|
||||
Returns:
|
||||
cache_key: 缓存键
|
||||
"""
|
||||
cache_key = self._generate_cache_key("stock", symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
source=data_source)
|
||||
|
||||
# 自动推断市场类型
|
||||
if market_type is None:
|
||||
# 根据股票代码格式推断市场类型
|
||||
import re
|
||||
if re.match(r'^\d{6}$', symbol): # 6位数字为A股
|
||||
market_type = "china"
|
||||
else: # 其他格式为美股
|
||||
market_type = "us"
|
||||
|
||||
# 准备文档数据
|
||||
doc = {
|
||||
"_id": cache_key,
|
||||
"symbol": symbol,
|
||||
"market_type": market_type,
|
||||
"data_type": "stock_data",
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"data_source": data_source,
|
||||
"created_at": datetime.utcnow(),
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
# 处理数据格式
|
||||
if isinstance(data, pd.DataFrame):
|
||||
doc["data"] = data.to_json(orient='records', date_format='iso')
|
||||
doc["data_format"] = "dataframe_json"
|
||||
else:
|
||||
doc["data"] = str(data)
|
||||
doc["data_format"] = "text"
|
||||
|
||||
# 保存到MongoDB(持久化)
|
||||
if self.mongodb_db is not None:
|
||||
try:
|
||||
collection = self.mongodb_db.stock_data
|
||||
collection.replace_one({"_id": cache_key}, doc, upsert=True)
|
||||
print(f"💾 股票数据已保存到MongoDB: {symbol} -> {cache_key}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ MongoDB保存失败: {e}")
|
||||
|
||||
# 保存到Redis(快速缓存,6小时过期)
|
||||
if self.redis_client:
|
||||
try:
|
||||
redis_data = {
|
||||
"data": doc["data"],
|
||||
"data_format": doc["data_format"],
|
||||
"symbol": symbol,
|
||||
"data_source": data_source,
|
||||
"created_at": doc["created_at"].isoformat()
|
||||
}
|
||||
self.redis_client.setex(
|
||||
cache_key,
|
||||
6 * 3600, # 6小时过期
|
||||
json.dumps(redis_data, ensure_ascii=False)
|
||||
)
|
||||
print(f"⚡ 股票数据已缓存到Redis: {symbol} -> {cache_key}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Redis缓存失败: {e}")
|
||||
|
||||
return cache_key
|
||||
|
||||
def load_stock_data(self, cache_key: str) -> Optional[Union[pd.DataFrame, str]]:
|
||||
"""从Redis或MongoDB加载股票数据"""
|
||||
|
||||
# 首先尝试从Redis加载(更快)
|
||||
if self.redis_client:
|
||||
try:
|
||||
redis_data = self.redis_client.get(cache_key)
|
||||
if redis_data:
|
||||
data_dict = json.loads(redis_data)
|
||||
print(f"⚡ 从Redis加载数据: {cache_key}")
|
||||
|
||||
if data_dict["data_format"] == "dataframe_json":
|
||||
return pd.read_json(data_dict["data"], orient='records')
|
||||
else:
|
||||
return data_dict["data"]
|
||||
except Exception as e:
|
||||
print(f"⚠️ Redis加载失败: {e}")
|
||||
|
||||
# 如果Redis没有,从MongoDB加载
|
||||
if self.mongodb_db is not None:
|
||||
try:
|
||||
collection = self.mongodb_db.stock_data
|
||||
doc = collection.find_one({"_id": cache_key})
|
||||
|
||||
if doc:
|
||||
print(f"💾 从MongoDB加载数据: {cache_key}")
|
||||
|
||||
# 同时更新到Redis缓存
|
||||
if self.redis_client:
|
||||
try:
|
||||
redis_data = {
|
||||
"data": doc["data"],
|
||||
"data_format": doc["data_format"],
|
||||
"symbol": doc["symbol"],
|
||||
"data_source": doc["data_source"],
|
||||
"created_at": doc["created_at"].isoformat()
|
||||
}
|
||||
self.redis_client.setex(
|
||||
cache_key,
|
||||
6 * 3600,
|
||||
json.dumps(redis_data, ensure_ascii=False)
|
||||
)
|
||||
print(f"⚡ 数据已同步到Redis缓存")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Redis同步失败: {e}")
|
||||
|
||||
if doc["data_format"] == "dataframe_json":
|
||||
return pd.read_json(doc["data"], orient='records')
|
||||
else:
|
||||
return doc["data"]
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ MongoDB加载失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def find_cached_stock_data(self, symbol: str, start_date: str = None,
|
||||
end_date: str = None, data_source: str = None,
|
||||
max_age_hours: int = 6) -> Optional[str]:
|
||||
"""查找匹配的缓存数据"""
|
||||
|
||||
# 生成精确匹配的缓存键
|
||||
exact_key = self._generate_cache_key("stock", symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
source=data_source)
|
||||
|
||||
# 检查Redis中是否有精确匹配
|
||||
if self.redis_client and self.redis_client.exists(exact_key):
|
||||
print(f"⚡ Redis中找到精确匹配: {symbol} -> {exact_key}")
|
||||
return exact_key
|
||||
|
||||
# 检查MongoDB中的匹配项
|
||||
if self.mongodb_db is not None:
|
||||
try:
|
||||
collection = self.mongodb_db.stock_data
|
||||
cutoff_time = datetime.utcnow() - timedelta(hours=max_age_hours)
|
||||
|
||||
query = {
|
||||
"symbol": symbol,
|
||||
"created_at": {"$gte": cutoff_time}
|
||||
}
|
||||
|
||||
if data_source:
|
||||
query["data_source"] = data_source
|
||||
if start_date:
|
||||
query["start_date"] = start_date
|
||||
if end_date:
|
||||
query["end_date"] = end_date
|
||||
|
||||
doc = collection.find_one(query, sort=[("created_at", -1)])
|
||||
|
||||
if doc:
|
||||
cache_key = doc["_id"]
|
||||
print(f"💾 MongoDB中找到匹配: {symbol} -> {cache_key}")
|
||||
return cache_key
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ MongoDB查询失败: {e}")
|
||||
|
||||
print(f"❌ 未找到有效缓存: {symbol}")
|
||||
return None
|
||||
|
||||
def save_news_data(self, symbol: str, news_data: str,
|
||||
start_date: str = None, end_date: str = None,
|
||||
data_source: str = "unknown") -> str:
|
||||
"""保存新闻数据到MongoDB和Redis"""
|
||||
cache_key = self._generate_cache_key("news", symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
source=data_source)
|
||||
|
||||
doc = {
|
||||
"_id": cache_key,
|
||||
"symbol": symbol,
|
||||
"data_type": "news_data",
|
||||
"date_range": f"{start_date}_{end_date}",
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"data_source": data_source,
|
||||
"data": news_data,
|
||||
"created_at": datetime.utcnow(),
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
# 保存到MongoDB
|
||||
if self.mongodb_db is not None:
|
||||
try:
|
||||
collection = self.mongodb_db.news_data
|
||||
collection.replace_one({"_id": cache_key}, doc, upsert=True)
|
||||
print(f"📰 新闻数据已保存到MongoDB: {symbol} -> {cache_key}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ MongoDB保存失败: {e}")
|
||||
|
||||
# 保存到Redis(24小时过期)
|
||||
if self.redis_client:
|
||||
try:
|
||||
redis_data = {
|
||||
"data": news_data,
|
||||
"symbol": symbol,
|
||||
"data_source": data_source,
|
||||
"created_at": doc["created_at"].isoformat()
|
||||
}
|
||||
self.redis_client.setex(
|
||||
cache_key,
|
||||
24 * 3600, # 24小时过期
|
||||
json.dumps(redis_data, ensure_ascii=False)
|
||||
)
|
||||
print(f"⚡ 新闻数据已缓存到Redis: {symbol} -> {cache_key}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Redis缓存失败: {e}")
|
||||
|
||||
return cache_key
|
||||
|
||||
def save_fundamentals_data(self, symbol: str, fundamentals_data: str,
|
||||
analysis_date: str = None,
|
||||
data_source: str = "unknown") -> str:
|
||||
"""保存基本面数据到MongoDB和Redis"""
|
||||
if not analysis_date:
|
||||
analysis_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
cache_key = self._generate_cache_key("fundamentals", symbol,
|
||||
date=analysis_date,
|
||||
source=data_source)
|
||||
|
||||
doc = {
|
||||
"_id": cache_key,
|
||||
"symbol": symbol,
|
||||
"data_type": "fundamentals_data",
|
||||
"analysis_date": analysis_date,
|
||||
"data_source": data_source,
|
||||
"data": fundamentals_data,
|
||||
"created_at": datetime.utcnow(),
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
# 保存到MongoDB
|
||||
if self.mongodb_db is not None:
|
||||
try:
|
||||
collection = self.mongodb_db.fundamentals_data
|
||||
collection.replace_one({"_id": cache_key}, doc, upsert=True)
|
||||
print(f"💼 基本面数据已保存到MongoDB: {symbol} -> {cache_key}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ MongoDB保存失败: {e}")
|
||||
|
||||
# 保存到Redis(24小时过期)
|
||||
if self.redis_client:
|
||||
try:
|
||||
redis_data = {
|
||||
"data": fundamentals_data,
|
||||
"symbol": symbol,
|
||||
"data_source": data_source,
|
||||
"analysis_date": analysis_date,
|
||||
"created_at": doc["created_at"].isoformat()
|
||||
}
|
||||
self.redis_client.setex(
|
||||
cache_key,
|
||||
24 * 3600, # 24小时过期
|
||||
json.dumps(redis_data, ensure_ascii=False)
|
||||
)
|
||||
print(f"⚡ 基本面数据已缓存到Redis: {symbol} -> {cache_key}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Redis缓存失败: {e}")
|
||||
|
||||
return cache_key
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
stats = {
|
||||
"mongodb": {"available": self.mongodb_db is not None, "collections": {}},
|
||||
"redis": {"available": self.redis_client is not None, "keys": 0, "memory_usage": "N/A"}
|
||||
}
|
||||
|
||||
# MongoDB统计
|
||||
if self.mongodb_db is not None:
|
||||
try:
|
||||
for collection_name in ["stock_data", "news_data", "fundamentals_data"]:
|
||||
collection = self.mongodb_db[collection_name]
|
||||
count = collection.count_documents({})
|
||||
size = self.mongodb_db.command("collStats", collection_name).get("size", 0)
|
||||
stats["mongodb"]["collections"][collection_name] = {
|
||||
"count": count,
|
||||
"size_mb": round(size / (1024 * 1024), 2)
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"⚠️ MongoDB统计获取失败: {e}")
|
||||
|
||||
# Redis统计
|
||||
if self.redis_client:
|
||||
try:
|
||||
info = self.redis_client.info()
|
||||
stats["redis"]["keys"] = info.get("db0", {}).get("keys", 0)
|
||||
stats["redis"]["memory_usage"] = f"{info.get('used_memory_human', 'N/A')}"
|
||||
except Exception as e:
|
||||
print(f"⚠️ Redis统计获取失败: {e}")
|
||||
|
||||
return stats
|
||||
|
||||
def clear_old_cache(self, max_age_days: int = 7):
|
||||
"""清理过期缓存"""
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=max_age_days)
|
||||
cleared_count = 0
|
||||
|
||||
# 清理MongoDB
|
||||
if self.mongodb_db is not None:
|
||||
try:
|
||||
for collection_name in ["stock_data", "news_data", "fundamentals_data"]:
|
||||
collection = self.mongodb_db[collection_name]
|
||||
result = collection.delete_many({"created_at": {"$lt": cutoff_time}})
|
||||
cleared_count += result.deleted_count
|
||||
print(f"🧹 MongoDB {collection_name} 清理了 {result.deleted_count} 条记录")
|
||||
except Exception as e:
|
||||
print(f"⚠️ MongoDB清理失败: {e}")
|
||||
|
||||
# Redis会自动过期,不需要手动清理
|
||||
print(f"🧹 总共清理了 {cleared_count} 条过期记录")
|
||||
return cleared_count
|
||||
|
||||
def close(self):
|
||||
"""关闭数据库连接"""
|
||||
if self.mongodb_client:
|
||||
self.mongodb_client.close()
|
||||
print("🔒 MongoDB连接已关闭")
|
||||
|
||||
if self.redis_client:
|
||||
self.redis_client.close()
|
||||
print("🔒 Redis连接已关闭")
|
||||
|
||||
|
||||
# 全局数据库缓存实例
|
||||
_db_cache_instance = None
|
||||
|
||||
def get_db_cache() -> DatabaseCacheManager:
|
||||
"""获取全局数据库缓存实例"""
|
||||
global _db_cache_instance
|
||||
if _db_cache_instance is None:
|
||||
_db_cache_instance = DatabaseCacheManager()
|
||||
return _db_cache_instance
|
||||
|
|
@ -0,0 +1,286 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
集成缓存管理器
|
||||
结合原有缓存系统和新的自适应数据库支持
|
||||
提供向后兼容的接口
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import pandas as pd
|
||||
|
||||
# 导入原有缓存系统
|
||||
from .cache_manager import StockDataCache
|
||||
|
||||
# 导入自适应缓存系统
|
||||
try:
|
||||
from .adaptive_cache import get_cache_system
|
||||
from ..config.database_manager import get_database_manager
|
||||
ADAPTIVE_CACHE_AVAILABLE = True
|
||||
except ImportError:
|
||||
ADAPTIVE_CACHE_AVAILABLE = False
|
||||
|
||||
class IntegratedCacheManager:
|
||||
"""集成缓存管理器 - 智能选择缓存策略"""
|
||||
|
||||
def __init__(self, cache_dir: str = None):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化原有缓存系统(作为备用)
|
||||
self.legacy_cache = StockDataCache(cache_dir)
|
||||
|
||||
# 尝试初始化自适应缓存系统
|
||||
self.adaptive_cache = None
|
||||
self.use_adaptive = False
|
||||
|
||||
if ADAPTIVE_CACHE_AVAILABLE:
|
||||
try:
|
||||
self.adaptive_cache = get_cache_system()
|
||||
self.db_manager = get_database_manager()
|
||||
self.use_adaptive = True
|
||||
self.logger.info("✅ 自适应缓存系统已启用")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"自适应缓存系统初始化失败,使用传统缓存: {e}")
|
||||
self.use_adaptive = False
|
||||
else:
|
||||
self.logger.info("自适应缓存系统不可用,使用传统文件缓存")
|
||||
|
||||
# 显示当前配置
|
||||
self._log_cache_status()
|
||||
|
||||
def _log_cache_status(self):
|
||||
"""记录缓存状态"""
|
||||
if self.use_adaptive:
|
||||
backend = self.adaptive_cache.primary_backend
|
||||
mongodb_available = self.db_manager.is_mongodb_available()
|
||||
redis_available = self.db_manager.is_redis_available()
|
||||
|
||||
self.logger.info(f"📊 缓存配置:")
|
||||
self.logger.info(f" 主要后端: {backend}")
|
||||
self.logger.info(f" MongoDB: {'✅ 可用' if mongodb_available else '❌ 不可用'}")
|
||||
self.logger.info(f" Redis: {'✅ 可用' if redis_available else '❌ 不可用'}")
|
||||
self.logger.info(f" 降级支持: {'✅ 启用' if self.adaptive_cache.fallback_enabled else '❌ 禁用'}")
|
||||
else:
|
||||
self.logger.info("📁 使用传统文件缓存系统")
|
||||
|
||||
def save_stock_data(self, symbol: str, data: Any, start_date: str = None,
|
||||
end_date: str = None, data_source: str = "default") -> str:
|
||||
"""
|
||||
保存股票数据到缓存
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
data: 股票数据
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
data_source: 数据源
|
||||
|
||||
Returns:
|
||||
缓存键
|
||||
"""
|
||||
if self.use_adaptive:
|
||||
# 使用自适应缓存系统
|
||||
return self.adaptive_cache.save_data(
|
||||
symbol=symbol,
|
||||
data=data,
|
||||
start_date=start_date or "",
|
||||
end_date=end_date or "",
|
||||
data_source=data_source,
|
||||
data_type="stock_data"
|
||||
)
|
||||
else:
|
||||
# 使用传统缓存系统
|
||||
return self.legacy_cache.save_stock_data(
|
||||
symbol=symbol,
|
||||
data=data,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
data_source=data_source
|
||||
)
|
||||
|
||||
def load_stock_data(self, cache_key: str) -> Optional[Any]:
|
||||
"""
|
||||
从缓存加载股票数据
|
||||
|
||||
Args:
|
||||
cache_key: 缓存键
|
||||
|
||||
Returns:
|
||||
股票数据或None
|
||||
"""
|
||||
if self.use_adaptive:
|
||||
# 使用自适应缓存系统
|
||||
return self.adaptive_cache.load_data(cache_key)
|
||||
else:
|
||||
# 使用传统缓存系统
|
||||
return self.legacy_cache.load_stock_data(cache_key)
|
||||
|
||||
def find_cached_stock_data(self, symbol: str, start_date: str = None,
|
||||
end_date: str = None, data_source: str = "default") -> Optional[str]:
|
||||
"""
|
||||
查找缓存的股票数据
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
data_source: 数据源
|
||||
|
||||
Returns:
|
||||
缓存键或None
|
||||
"""
|
||||
if self.use_adaptive:
|
||||
# 使用自适应缓存系统
|
||||
return self.adaptive_cache.find_cached_data(
|
||||
symbol=symbol,
|
||||
start_date=start_date or "",
|
||||
end_date=end_date or "",
|
||||
data_source=data_source,
|
||||
data_type="stock_data"
|
||||
)
|
||||
else:
|
||||
# 使用传统缓存系统
|
||||
return self.legacy_cache.find_cached_stock_data(
|
||||
symbol=symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
data_source=data_source
|
||||
)
|
||||
|
||||
def save_news_data(self, symbol: str, data: Any, data_source: str = "default") -> str:
|
||||
"""保存新闻数据"""
|
||||
if self.use_adaptive:
|
||||
return self.adaptive_cache.save_data(
|
||||
symbol=symbol,
|
||||
data=data,
|
||||
data_source=data_source,
|
||||
data_type="news_data"
|
||||
)
|
||||
else:
|
||||
return self.legacy_cache.save_news_data(symbol, data, data_source)
|
||||
|
||||
def load_news_data(self, cache_key: str) -> Optional[Any]:
|
||||
"""加载新闻数据"""
|
||||
if self.use_adaptive:
|
||||
return self.adaptive_cache.load_data(cache_key)
|
||||
else:
|
||||
return self.legacy_cache.load_news_data(cache_key)
|
||||
|
||||
def save_fundamentals_data(self, symbol: str, data: Any, data_source: str = "default") -> str:
|
||||
"""保存基本面数据"""
|
||||
if self.use_adaptive:
|
||||
return self.adaptive_cache.save_data(
|
||||
symbol=symbol,
|
||||
data=data,
|
||||
data_source=data_source,
|
||||
data_type="fundamentals_data"
|
||||
)
|
||||
else:
|
||||
return self.legacy_cache.save_fundamentals_data(symbol, data, data_source)
|
||||
|
||||
def load_fundamentals_data(self, cache_key: str) -> Optional[Any]:
|
||||
"""加载基本面数据"""
|
||||
if self.use_adaptive:
|
||||
return self.adaptive_cache.load_data(cache_key)
|
||||
else:
|
||||
return self.legacy_cache.load_fundamentals_data(cache_key)
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
if self.use_adaptive:
|
||||
# 获取自适应缓存统计
|
||||
adaptive_stats = self.adaptive_cache.get_cache_stats()
|
||||
|
||||
# 添加传统缓存统计
|
||||
legacy_stats = self.legacy_cache.get_cache_stats()
|
||||
|
||||
return {
|
||||
"cache_system": "adaptive",
|
||||
"adaptive_cache": adaptive_stats,
|
||||
"legacy_cache": legacy_stats,
|
||||
"database_available": self.db_manager.is_database_available(),
|
||||
"mongodb_available": self.db_manager.is_mongodb_available(),
|
||||
"redis_available": self.db_manager.is_redis_available()
|
||||
}
|
||||
else:
|
||||
# 只返回传统缓存统计
|
||||
legacy_stats = self.legacy_cache.get_cache_stats()
|
||||
return {
|
||||
"cache_system": "legacy",
|
||||
"legacy_cache": legacy_stats,
|
||||
"database_available": False,
|
||||
"mongodb_available": False,
|
||||
"redis_available": False
|
||||
}
|
||||
|
||||
def clear_expired_cache(self):
|
||||
"""清理过期缓存"""
|
||||
if self.use_adaptive:
|
||||
self.adaptive_cache.clear_expired_cache()
|
||||
|
||||
# 总是清理传统缓存
|
||||
self.legacy_cache.clear_expired_cache()
|
||||
|
||||
def get_cache_backend_info(self) -> Dict[str, Any]:
|
||||
"""获取缓存后端信息"""
|
||||
if self.use_adaptive:
|
||||
return {
|
||||
"system": "adaptive",
|
||||
"primary_backend": self.adaptive_cache.primary_backend,
|
||||
"fallback_enabled": self.adaptive_cache.fallback_enabled,
|
||||
"mongodb_available": self.db_manager.is_mongodb_available(),
|
||||
"redis_available": self.db_manager.is_redis_available()
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"system": "legacy",
|
||||
"primary_backend": "file",
|
||||
"fallback_enabled": False,
|
||||
"mongodb_available": False,
|
||||
"redis_available": False
|
||||
}
|
||||
|
||||
def is_database_available(self) -> bool:
|
||||
"""检查数据库是否可用"""
|
||||
if self.use_adaptive:
|
||||
return self.db_manager.is_database_available()
|
||||
return False
|
||||
|
||||
def get_performance_mode(self) -> str:
|
||||
"""获取性能模式"""
|
||||
if not self.use_adaptive:
|
||||
return "基础模式 (文件缓存)"
|
||||
|
||||
mongodb_available = self.db_manager.is_mongodb_available()
|
||||
redis_available = self.db_manager.is_redis_available()
|
||||
|
||||
if redis_available and mongodb_available:
|
||||
return "高性能模式 (Redis + MongoDB + 文件)"
|
||||
elif redis_available:
|
||||
return "快速模式 (Redis + 文件)"
|
||||
elif mongodb_available:
|
||||
return "持久化模式 (MongoDB + 文件)"
|
||||
else:
|
||||
return "标准模式 (智能文件缓存)"
|
||||
|
||||
|
||||
# 全局集成缓存管理器实例
|
||||
_integrated_cache = None
|
||||
|
||||
def get_cache() -> IntegratedCacheManager:
|
||||
"""获取全局集成缓存管理器实例"""
|
||||
global _integrated_cache
|
||||
if _integrated_cache is None:
|
||||
_integrated_cache = IntegratedCacheManager()
|
||||
return _integrated_cache
|
||||
|
||||
# 向后兼容的函数
|
||||
def get_stock_cache():
|
||||
"""向后兼容:获取股票缓存"""
|
||||
return get_cache()
|
||||
|
||||
def create_cache_manager(cache_dir: str = None):
|
||||
"""向后兼容:创建缓存管理器"""
|
||||
return IntegratedCacheManager(cache_dir)
|
||||
|
|
@ -0,0 +1,807 @@
|
|||
from typing import Annotated, Dict
|
||||
from .reddit_utils import fetch_top_from_category
|
||||
from .yfin_utils import *
|
||||
from .stockstats_utils import *
|
||||
from .googlenews_utils import *
|
||||
from .finnhub_utils import get_data_in_range
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
import yfinance as yf
|
||||
from openai import OpenAI
|
||||
from .config import get_config, set_config, DATA_DIR
|
||||
|
||||
|
||||
def get_finnhub_news(
|
||||
ticker: Annotated[
|
||||
str,
|
||||
"Search query of a company's, e.g. 'AAPL, TSM, etc.",
|
||||
],
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
look_back_days: Annotated[int, "how many days to look back"],
|
||||
):
|
||||
"""
|
||||
Retrieve news about a company within a time frame
|
||||
|
||||
Args
|
||||
ticker (str): ticker for the company you are interested in
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns
|
||||
str: dataframe containing the news of the company in the time frame
|
||||
|
||||
"""
|
||||
|
||||
start_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
before = start_date - relativedelta(days=look_back_days)
|
||||
before = before.strftime("%Y-%m-%d")
|
||||
|
||||
result = get_data_in_range(ticker, before, curr_date, "news_data", DATA_DIR)
|
||||
|
||||
if len(result) == 0:
|
||||
return ""
|
||||
|
||||
combined_result = ""
|
||||
for day, data in result.items():
|
||||
if len(data) == 0:
|
||||
continue
|
||||
for entry in data:
|
||||
current_news = (
|
||||
"### " + entry["headline"] + f" ({day})" + "\n" + entry["summary"]
|
||||
)
|
||||
combined_result += current_news + "\n\n"
|
||||
|
||||
return f"## {ticker} News, from {before} to {curr_date}:\n" + str(combined_result)
|
||||
|
||||
|
||||
def get_finnhub_company_insider_sentiment(
|
||||
ticker: Annotated[str, "ticker symbol for the company"],
|
||||
curr_date: Annotated[
|
||||
str,
|
||||
"current date of you are trading at, yyyy-mm-dd",
|
||||
],
|
||||
look_back_days: Annotated[int, "number of days to look back"],
|
||||
):
|
||||
"""
|
||||
Retrieve insider sentiment about a company (retrieved from public SEC information) for the past 15 days
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
curr_date (str): current date you are trading on, yyyy-mm-dd
|
||||
Returns:
|
||||
str: a report of the sentiment in the past 15 days starting at curr_date
|
||||
"""
|
||||
|
||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
before = date_obj - relativedelta(days=look_back_days)
|
||||
before = before.strftime("%Y-%m-%d")
|
||||
|
||||
data = get_data_in_range(ticker, before, curr_date, "insider_senti", DATA_DIR)
|
||||
|
||||
if len(data) == 0:
|
||||
return ""
|
||||
|
||||
result_str = ""
|
||||
seen_dicts = []
|
||||
for date, senti_list in data.items():
|
||||
for entry in senti_list:
|
||||
if entry not in seen_dicts:
|
||||
result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n"
|
||||
seen_dicts.append(entry)
|
||||
|
||||
return (
|
||||
f"## {ticker} Insider Sentiment Data for {before} to {curr_date}:\n"
|
||||
+ result_str
|
||||
+ "The change field refers to the net buying/selling from all insiders' transactions. The mspr field refers to monthly share purchase ratio."
|
||||
)
|
||||
|
||||
|
||||
def get_finnhub_company_insider_transactions(
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
curr_date: Annotated[
|
||||
str,
|
||||
"current date you are trading at, yyyy-mm-dd",
|
||||
],
|
||||
look_back_days: Annotated[int, "how many days to look back"],
|
||||
):
|
||||
"""
|
||||
Retrieve insider transcaction information about a company (retrieved from public SEC information) for the past 15 days
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
str: a report of the company's insider transaction/trading informtaion in the past 15 days
|
||||
"""
|
||||
|
||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
before = date_obj - relativedelta(days=look_back_days)
|
||||
before = before.strftime("%Y-%m-%d")
|
||||
|
||||
data = get_data_in_range(ticker, before, curr_date, "insider_trans", DATA_DIR)
|
||||
|
||||
if len(data) == 0:
|
||||
return ""
|
||||
|
||||
result_str = ""
|
||||
|
||||
seen_dicts = []
|
||||
for date, senti_list in data.items():
|
||||
for entry in senti_list:
|
||||
if entry not in seen_dicts:
|
||||
result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n"
|
||||
seen_dicts.append(entry)
|
||||
|
||||
return (
|
||||
f"## {ticker} insider transactions from {before} to {curr_date}:\n"
|
||||
+ result_str
|
||||
+ "The change field reflects the variation in share count—here a negative number indicates a reduction in holdings—while share specifies the total number of shares involved. The transactionPrice denotes the per-share price at which the trade was executed, and transactionDate marks when the transaction occurred. The name field identifies the insider making the trade, and transactionCode (e.g., S for sale) clarifies the nature of the transaction. FilingDate records when the transaction was officially reported, and the unique id links to the specific SEC filing, as indicated by the source. Additionally, the symbol ties the transaction to a particular company, isDerivative flags whether the trade involves derivative securities, and currency notes the currency context of the transaction."
|
||||
)
|
||||
|
||||
|
||||
def get_simfin_balance_sheet(
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
freq: Annotated[
|
||||
str,
|
||||
"reporting frequency of the company's financial history: annual / quarterly",
|
||||
],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
):
|
||||
data_path = os.path.join(
|
||||
DATA_DIR,
|
||||
"fundamental_data",
|
||||
"simfin_data_all",
|
||||
"balance_sheet",
|
||||
"companies",
|
||||
"us",
|
||||
f"us-balance-{freq}.csv",
|
||||
)
|
||||
df = pd.read_csv(data_path, sep=";")
|
||||
|
||||
# Convert date strings to datetime objects and remove any time components
|
||||
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
|
||||
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
|
||||
|
||||
# Convert the current date to datetime and normalize
|
||||
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
|
||||
|
||||
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
|
||||
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
|
||||
|
||||
# Check if there are any available reports; if not, return a notification
|
||||
if filtered_df.empty:
|
||||
print("No balance sheet available before the given current date.")
|
||||
return ""
|
||||
|
||||
# Get the most recent balance sheet by selecting the row with the latest Publish Date
|
||||
latest_balance_sheet = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
|
||||
|
||||
# drop the SimFinID column
|
||||
latest_balance_sheet = latest_balance_sheet.drop("SimFinId")
|
||||
|
||||
return (
|
||||
f"## {freq} balance sheet for {ticker} released on {str(latest_balance_sheet['Publish Date'])[0:10]}: \n"
|
||||
+ str(latest_balance_sheet)
|
||||
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of assets, liabilities, and equity. Assets are grouped as current (liquid items like cash and receivables) and noncurrent (long-term investments and property). Liabilities are split between short-term obligations and long-term debts, while equity reflects shareholder funds such as paid-in capital and retained earnings. Together, these components ensure that total assets equal the sum of liabilities and equity."
|
||||
)
|
||||
|
||||
|
||||
def get_simfin_cashflow(
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
freq: Annotated[
|
||||
str,
|
||||
"reporting frequency of the company's financial history: annual / quarterly",
|
||||
],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
):
|
||||
data_path = os.path.join(
|
||||
DATA_DIR,
|
||||
"fundamental_data",
|
||||
"simfin_data_all",
|
||||
"cash_flow",
|
||||
"companies",
|
||||
"us",
|
||||
f"us-cashflow-{freq}.csv",
|
||||
)
|
||||
df = pd.read_csv(data_path, sep=";")
|
||||
|
||||
# Convert date strings to datetime objects and remove any time components
|
||||
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
|
||||
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
|
||||
|
||||
# Convert the current date to datetime and normalize
|
||||
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
|
||||
|
||||
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
|
||||
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
|
||||
|
||||
# Check if there are any available reports; if not, return a notification
|
||||
if filtered_df.empty:
|
||||
print("No cash flow statement available before the given current date.")
|
||||
return ""
|
||||
|
||||
# Get the most recent cash flow statement by selecting the row with the latest Publish Date
|
||||
latest_cash_flow = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
|
||||
|
||||
# drop the SimFinID column
|
||||
latest_cash_flow = latest_cash_flow.drop("SimFinId")
|
||||
|
||||
return (
|
||||
f"## {freq} cash flow statement for {ticker} released on {str(latest_cash_flow['Publish Date'])[0:10]}: \n"
|
||||
+ str(latest_cash_flow)
|
||||
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of cash movements. Operating activities show cash generated from core business operations, including net income adjustments for non-cash items and working capital changes. Investing activities cover asset acquisitions/disposals and investments. Financing activities include debt transactions, equity issuances/repurchases, and dividend payments. The net change in cash represents the overall increase or decrease in the company's cash position during the reporting period."
|
||||
)
|
||||
|
||||
|
||||
def get_simfin_income_statements(
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
freq: Annotated[
|
||||
str,
|
||||
"reporting frequency of the company's financial history: annual / quarterly",
|
||||
],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
):
|
||||
data_path = os.path.join(
|
||||
DATA_DIR,
|
||||
"fundamental_data",
|
||||
"simfin_data_all",
|
||||
"income_statements",
|
||||
"companies",
|
||||
"us",
|
||||
f"us-income-{freq}.csv",
|
||||
)
|
||||
df = pd.read_csv(data_path, sep=";")
|
||||
|
||||
# Convert date strings to datetime objects and remove any time components
|
||||
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
|
||||
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
|
||||
|
||||
# Convert the current date to datetime and normalize
|
||||
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
|
||||
|
||||
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
|
||||
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
|
||||
|
||||
# Check if there are any available reports; if not, return a notification
|
||||
if filtered_df.empty:
|
||||
print("No income statement available before the given current date.")
|
||||
return ""
|
||||
|
||||
# Get the most recent income statement by selecting the row with the latest Publish Date
|
||||
latest_income = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
|
||||
|
||||
# drop the SimFinID column
|
||||
latest_income = latest_income.drop("SimFinId")
|
||||
|
||||
return (
|
||||
f"## {freq} income statement for {ticker} released on {str(latest_income['Publish Date'])[0:10]}: \n"
|
||||
+ str(latest_income)
|
||||
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a comprehensive breakdown of the company's financial performance. Starting with Revenue, it shows Cost of Revenue and resulting Gross Profit. Operating Expenses are detailed, including SG&A, R&D, and Depreciation. The statement then shows Operating Income, followed by non-operating items and Interest Expense, leading to Pretax Income. After accounting for Income Tax and any Extraordinary items, it concludes with Net Income, representing the company's bottom-line profit or loss for the period."
|
||||
)
|
||||
|
||||
|
||||
def get_google_news(
|
||||
query: Annotated[str, "Query to search with"],
|
||||
curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"],
|
||||
look_back_days: Annotated[int, "how many days to look back"],
|
||||
) -> str:
|
||||
query = query.replace(" ", "+")
|
||||
|
||||
start_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
before = start_date - relativedelta(days=look_back_days)
|
||||
before = before.strftime("%Y-%m-%d")
|
||||
|
||||
news_results = getNewsData(query, before, curr_date)
|
||||
|
||||
news_str = ""
|
||||
|
||||
for news in news_results:
|
||||
news_str += (
|
||||
f"### {news['title']} (source: {news['source']}) \n\n{news['snippet']}\n\n"
|
||||
)
|
||||
|
||||
if len(news_results) == 0:
|
||||
return ""
|
||||
|
||||
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
|
||||
|
||||
|
||||
def get_reddit_global_news(
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
look_back_days: Annotated[int, "how many days to look back"],
|
||||
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the latest top reddit news
|
||||
Args:
|
||||
start_date: Start date in yyyy-mm-dd format
|
||||
end_date: End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
|
||||
"""
|
||||
|
||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
before = start_date - relativedelta(days=look_back_days)
|
||||
before = before.strftime("%Y-%m-%d")
|
||||
|
||||
posts = []
|
||||
# iterate from start_date to end_date
|
||||
curr_date = datetime.strptime(before, "%Y-%m-%d")
|
||||
|
||||
total_iterations = (start_date - curr_date).days + 1
|
||||
pbar = tqdm(desc=f"Getting Global News on {start_date}", total=total_iterations)
|
||||
|
||||
while curr_date <= start_date:
|
||||
curr_date_str = curr_date.strftime("%Y-%m-%d")
|
||||
fetch_result = fetch_top_from_category(
|
||||
"global_news",
|
||||
curr_date_str,
|
||||
max_limit_per_day,
|
||||
data_path=os.path.join(DATA_DIR, "reddit_data"),
|
||||
)
|
||||
posts.extend(fetch_result)
|
||||
curr_date += relativedelta(days=1)
|
||||
pbar.update(1)
|
||||
|
||||
pbar.close()
|
||||
|
||||
if len(posts) == 0:
|
||||
return ""
|
||||
|
||||
news_str = ""
|
||||
for post in posts:
|
||||
if post["content"] == "":
|
||||
news_str += f"### {post['title']}\n\n"
|
||||
else:
|
||||
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
|
||||
|
||||
return f"## Global News Reddit, from {before} to {curr_date}:\n{news_str}"
|
||||
|
||||
|
||||
def get_reddit_company_news(
|
||||
ticker: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
look_back_days: Annotated[int, "how many days to look back"],
|
||||
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the latest top reddit news
|
||||
Args:
|
||||
ticker: ticker symbol of the company
|
||||
start_date: Start date in yyyy-mm-dd format
|
||||
end_date: End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
|
||||
"""
|
||||
|
||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
before = start_date - relativedelta(days=look_back_days)
|
||||
before = before.strftime("%Y-%m-%d")
|
||||
|
||||
posts = []
|
||||
# iterate from start_date to end_date
|
||||
curr_date = datetime.strptime(before, "%Y-%m-%d")
|
||||
|
||||
total_iterations = (start_date - curr_date).days + 1
|
||||
pbar = tqdm(
|
||||
desc=f"Getting Company News for {ticker} on {start_date}",
|
||||
total=total_iterations,
|
||||
)
|
||||
|
||||
while curr_date <= start_date:
|
||||
curr_date_str = curr_date.strftime("%Y-%m-%d")
|
||||
fetch_result = fetch_top_from_category(
|
||||
"company_news",
|
||||
curr_date_str,
|
||||
max_limit_per_day,
|
||||
ticker,
|
||||
data_path=os.path.join(DATA_DIR, "reddit_data"),
|
||||
)
|
||||
posts.extend(fetch_result)
|
||||
curr_date += relativedelta(days=1)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
pbar.close()
|
||||
|
||||
if len(posts) == 0:
|
||||
return ""
|
||||
|
||||
news_str = ""
|
||||
for post in posts:
|
||||
if post["content"] == "":
|
||||
news_str += f"### {post['title']}\n\n"
|
||||
else:
|
||||
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
|
||||
|
||||
return f"##{ticker} News Reddit, from {before} to {curr_date}:\n\n{news_str}"
|
||||
|
||||
|
||||
def get_stock_stats_indicators_window(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
||||
curr_date: Annotated[
|
||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
||||
],
|
||||
look_back_days: Annotated[int, "how many days to look back"],
|
||||
online: Annotated[bool, "to fetch data online or offline"],
|
||||
) -> str:
|
||||
|
||||
best_ind_params = {
|
||||
# Moving Averages
|
||||
"close_50_sma": (
|
||||
"50 SMA: A medium-term trend indicator. "
|
||||
"Usage: Identify trend direction and serve as dynamic support/resistance. "
|
||||
"Tips: It lags price; combine with faster indicators for timely signals."
|
||||
),
|
||||
"close_200_sma": (
|
||||
"200 SMA: A long-term trend benchmark. "
|
||||
"Usage: Confirm overall market trend and identify golden/death cross setups. "
|
||||
"Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries."
|
||||
),
|
||||
"close_10_ema": (
|
||||
"10 EMA: A responsive short-term average. "
|
||||
"Usage: Capture quick shifts in momentum and potential entry points. "
|
||||
"Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals."
|
||||
),
|
||||
# MACD Related
|
||||
"macd": (
|
||||
"MACD: Computes momentum via differences of EMAs. "
|
||||
"Usage: Look for crossovers and divergence as signals of trend changes. "
|
||||
"Tips: Confirm with other indicators in low-volatility or sideways markets."
|
||||
),
|
||||
"macds": (
|
||||
"MACD Signal: An EMA smoothing of the MACD line. "
|
||||
"Usage: Use crossovers with the MACD line to trigger trades. "
|
||||
"Tips: Should be part of a broader strategy to avoid false positives."
|
||||
),
|
||||
"macdh": (
|
||||
"MACD Histogram: Shows the gap between the MACD line and its signal. "
|
||||
"Usage: Visualize momentum strength and spot divergence early. "
|
||||
"Tips: Can be volatile; complement with additional filters in fast-moving markets."
|
||||
),
|
||||
# Momentum Indicators
|
||||
"rsi": (
|
||||
"RSI: Measures momentum to flag overbought/oversold conditions. "
|
||||
"Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. "
|
||||
"Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis."
|
||||
),
|
||||
# Volatility Indicators
|
||||
"boll": (
|
||||
"Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. "
|
||||
"Usage: Acts as a dynamic benchmark for price movement. "
|
||||
"Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals."
|
||||
),
|
||||
"boll_ub": (
|
||||
"Bollinger Upper Band: Typically 2 standard deviations above the middle line. "
|
||||
"Usage: Signals potential overbought conditions and breakout zones. "
|
||||
"Tips: Confirm signals with other tools; prices may ride the band in strong trends."
|
||||
),
|
||||
"boll_lb": (
|
||||
"Bollinger Lower Band: Typically 2 standard deviations below the middle line. "
|
||||
"Usage: Indicates potential oversold conditions. "
|
||||
"Tips: Use additional analysis to avoid false reversal signals."
|
||||
),
|
||||
"atr": (
|
||||
"ATR: Averages true range to measure volatility. "
|
||||
"Usage: Set stop-loss levels and adjust position sizes based on current market volatility. "
|
||||
"Tips: It's a reactive measure, so use it as part of a broader risk management strategy."
|
||||
),
|
||||
# Volume-Based Indicators
|
||||
"vwma": (
|
||||
"VWMA: A moving average weighted by volume. "
|
||||
"Usage: Confirm trends by integrating price action with volume data. "
|
||||
"Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses."
|
||||
),
|
||||
"mfi": (
|
||||
"MFI: The Money Flow Index is a momentum indicator that uses both price and volume to measure buying and selling pressure. "
|
||||
"Usage: Identify overbought (>80) or oversold (<20) conditions and confirm the strength of trends or reversals. "
|
||||
"Tips: Use alongside RSI or MACD to confirm signals; divergence between price and MFI can indicate potential reversals."
|
||||
),
|
||||
}
|
||||
|
||||
if indicator not in best_ind_params:
|
||||
raise ValueError(
|
||||
f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}"
|
||||
)
|
||||
|
||||
end_date = curr_date
|
||||
curr_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
before = curr_date - relativedelta(days=look_back_days)
|
||||
|
||||
if not online:
|
||||
# read from YFin data
|
||||
data = pd.read_csv(
|
||||
os.path.join(
|
||||
DATA_DIR,
|
||||
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
||||
)
|
||||
)
|
||||
data["Date"] = pd.to_datetime(data["Date"], utc=True)
|
||||
dates_in_df = data["Date"].astype(str).str[:10]
|
||||
|
||||
ind_string = ""
|
||||
while curr_date >= before:
|
||||
# only do the trading dates
|
||||
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
|
||||
indicator_value = get_stockstats_indicator(
|
||||
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
|
||||
)
|
||||
|
||||
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
||||
|
||||
curr_date = curr_date - relativedelta(days=1)
|
||||
else:
|
||||
# online gathering
|
||||
ind_string = ""
|
||||
while curr_date >= before:
|
||||
indicator_value = get_stockstats_indicator(
|
||||
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
|
||||
)
|
||||
|
||||
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
||||
|
||||
curr_date = curr_date - relativedelta(days=1)
|
||||
|
||||
result_str = (
|
||||
f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n"
|
||||
+ ind_string
|
||||
+ "\n\n"
|
||||
+ best_ind_params.get(indicator, "No description available.")
|
||||
)
|
||||
|
||||
return result_str
|
||||
|
||||
|
||||
def get_stockstats_indicator(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
|
||||
curr_date: Annotated[
|
||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
||||
],
|
||||
online: Annotated[bool, "to fetch data online or offline"],
|
||||
) -> str:
|
||||
|
||||
curr_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
curr_date = curr_date.strftime("%Y-%m-%d")
|
||||
|
||||
try:
|
||||
indicator_value = StockstatsUtils.get_stock_stats(
|
||||
symbol,
|
||||
indicator,
|
||||
curr_date,
|
||||
os.path.join(DATA_DIR, "market_data", "price_data"),
|
||||
online=online,
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
|
||||
)
|
||||
return ""
|
||||
|
||||
return str(indicator_value)
|
||||
|
||||
|
||||
def get_YFin_data_window(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
curr_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
look_back_days: Annotated[int, "how many days to look back"],
|
||||
) -> str:
|
||||
# calculate past days
|
||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
before = date_obj - relativedelta(days=look_back_days)
|
||||
start_date = before.strftime("%Y-%m-%d")
|
||||
|
||||
# read in data
|
||||
data = pd.read_csv(
|
||||
os.path.join(
|
||||
DATA_DIR,
|
||||
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
||||
)
|
||||
)
|
||||
|
||||
# Extract just the date part for comparison
|
||||
data["DateOnly"] = data["Date"].str[:10]
|
||||
|
||||
# Filter data between the start and end dates (inclusive)
|
||||
filtered_data = data[
|
||||
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date)
|
||||
]
|
||||
|
||||
# Drop the temporary column we created
|
||||
filtered_data = filtered_data.drop("DateOnly", axis=1)
|
||||
|
||||
# Set pandas display options to show the full DataFrame
|
||||
with pd.option_context(
|
||||
"display.max_rows", None, "display.max_columns", None, "display.width", None
|
||||
):
|
||||
df_string = filtered_data.to_string()
|
||||
|
||||
return (
|
||||
f"## Raw Market Data for {symbol} from {start_date} to {curr_date}:\n\n"
|
||||
+ df_string
|
||||
)
|
||||
|
||||
|
||||
def get_YFin_data_online(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
):
|
||||
|
||||
datetime.strptime(start_date, "%Y-%m-%d")
|
||||
datetime.strptime(end_date, "%Y-%m-%d")
|
||||
|
||||
# Create ticker object
|
||||
ticker = yf.Ticker(symbol.upper())
|
||||
|
||||
# Fetch historical data for the specified date range
|
||||
data = ticker.history(start=start_date, end=end_date)
|
||||
|
||||
# Check if data is empty
|
||||
if data.empty:
|
||||
return (
|
||||
f"No data found for symbol '{symbol}' between {start_date} and {end_date}"
|
||||
)
|
||||
|
||||
# Remove timezone info from index for cleaner output
|
||||
if data.index.tz is not None:
|
||||
data.index = data.index.tz_localize(None)
|
||||
|
||||
# Round numerical values to 2 decimal places for cleaner display
|
||||
numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"]
|
||||
for col in numeric_columns:
|
||||
if col in data.columns:
|
||||
data[col] = data[col].round(2)
|
||||
|
||||
# Convert DataFrame to CSV string
|
||||
csv_string = data.to_csv()
|
||||
|
||||
# Add header information
|
||||
header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
|
||||
header += f"# Total records: {len(data)}\n"
|
||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||
|
||||
return header + csv_string
|
||||
|
||||
|
||||
def get_YFin_data(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
# read in data
|
||||
data = pd.read_csv(
|
||||
os.path.join(
|
||||
DATA_DIR,
|
||||
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
||||
)
|
||||
)
|
||||
|
||||
if end_date > "2025-03-25":
|
||||
raise Exception(
|
||||
f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25"
|
||||
)
|
||||
|
||||
# Extract just the date part for comparison
|
||||
data["DateOnly"] = data["Date"].str[:10]
|
||||
|
||||
# Filter data between the start and end dates (inclusive)
|
||||
filtered_data = data[
|
||||
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date)
|
||||
]
|
||||
|
||||
# Drop the temporary column we created
|
||||
filtered_data = filtered_data.drop("DateOnly", axis=1)
|
||||
|
||||
# remove the index from the dataframe
|
||||
filtered_data = filtered_data.reset_index(drop=True)
|
||||
|
||||
return filtered_data
|
||||
|
||||
|
||||
def get_stock_news_openai(ticker, curr_date):
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["backend_url"])
|
||||
|
||||
response = client.responses.create(
|
||||
model=config["quick_think_llm"],
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.",
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
text={"format": {"type": "text"}},
|
||||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
|
||||
|
||||
def get_global_news_openai(curr_date):
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["backend_url"])
|
||||
|
||||
response = client.responses.create(
|
||||
model=config["quick_think_llm"],
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.",
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
text={"format": {"type": "text"}},
|
||||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
|
||||
|
||||
def get_fundamentals_openai(ticker, curr_date):
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["backend_url"])
|
||||
|
||||
response = client.responses.create(
|
||||
model=config["quick_think_llm"],
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
text={"format": {"type": "text"}},
|
||||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
|
|
@ -0,0 +1,275 @@
|
|||
# 文件差异报告
|
||||
# 当前文件: tradingagents\dataflows\interface.py
|
||||
# 中文版文件: TradingAgentsCN\tradingagents\dataflows\interface.py
|
||||
# 生成时间: 周日 2025/07/06
|
||||
|
||||
--- current/interface.py+++ chinese_version/interface.py@@ -1,5 +1,6 @@ from typing import Annotated, Dict
|
||||
from .reddit_utils import fetch_top_from_category
|
||||
+from .chinese_finance_utils import get_chinese_social_sentiment
|
||||
from .yfin_utils import *
|
||||
from .stockstats_utils import *
|
||||
from .googlenews_utils import *
|
||||
@@ -43,7 +44,14 @@ result = get_data_in_range(ticker, before, curr_date, "news_data", DATA_DIR)
|
||||
|
||||
if len(result) == 0:
|
||||
- return ""
|
||||
+ error_msg = f"⚠️ 无法获取{ticker}的新闻数据 ({before} 到 {curr_date})\n"
|
||||
+ error_msg += f"可能的原因:\n"
|
||||
+ error_msg += f"1. 数据文件不存在或路径配置错误\n"
|
||||
+ error_msg += f"2. 指定日期范围内没有新闻数据\n"
|
||||
+ error_msg += f"3. 需要先下载或更新Finnhub新闻数据\n"
|
||||
+ error_msg += f"建议:检查数据目录配置或重新获取新闻数据"
|
||||
+ print(f"📰 [DEBUG] {error_msg}")
|
||||
+ return error_msg
|
||||
|
||||
combined_result = ""
|
||||
for day, data in result.items():
|
||||
@@ -772,36 +780,217 @@ return response.output[1].content[0].text
|
||||
|
||||
|
||||
+def get_fundamentals_finnhub(ticker, curr_date):
|
||||
+ """
|
||||
+ 使用Finnhub API获取股票基本面数据作为OpenAI的备选方案
|
||||
+ Args:
|
||||
+ ticker (str): 股票代码
|
||||
+ curr_date (str): 当前日期,格式为yyyy-mm-dd
|
||||
+ Returns:
|
||||
+ str: 格式化的基本面数据报告
|
||||
+ """
|
||||
+ try:
|
||||
+ import finnhub
|
||||
+ import os
|
||||
+ from .cache_manager import get_cache
|
||||
+
|
||||
+ # 检查缓存
|
||||
+ cache = get_cache()
|
||||
+ cached_key = cache.find_cached_fundamentals_data(ticker, data_source="finnhub")
|
||||
+ if cached_key:
|
||||
+ cached_data = cache.load_fundamentals_data(cached_key)
|
||||
+ if cached_data:
|
||||
+ print(f"💾 [DEBUG] 从缓存加载Finnhub基本面数据: {ticker}")
|
||||
+ return cached_data
|
||||
+
|
||||
+ # 获取Finnhub API密钥
|
||||
+ api_key = os.getenv('FINNHUB_API_KEY')
|
||||
+ if not api_key:
|
||||
+ return "错误:未配置FINNHUB_API_KEY环境变量"
|
||||
+
|
||||
+ # 初始化Finnhub客户端
|
||||
+ finnhub_client = finnhub.Client(api_key=api_key)
|
||||
+
|
||||
+ print(f"📊 [DEBUG] 使用Finnhub API获取 {ticker} 的基本面数据...")
|
||||
+
|
||||
+ # 获取基本财务数据
|
||||
+ try:
|
||||
+ basic_financials = finnhub_client.company_basic_financials(ticker, 'all')
|
||||
+ except Exception as e:
|
||||
+ print(f"❌ [DEBUG] Finnhub基本财务数据获取失败: {str(e)}")
|
||||
+ basic_financials = None
|
||||
+
|
||||
+ # 获取公司概况
|
||||
+ try:
|
||||
+ company_profile = finnhub_client.company_profile2(symbol=ticker)
|
||||
+ except Exception as e:
|
||||
+ print(f"❌ [DEBUG] Finnhub公司概况获取失败: {str(e)}")
|
||||
+ company_profile = None
|
||||
+
|
||||
+ # 获取收益数据
|
||||
+ try:
|
||||
+ earnings = finnhub_client.company_earnings(ticker, limit=4)
|
||||
+ except Exception as e:
|
||||
+ print(f"❌ [DEBUG] Finnhub收益数据获取失败: {str(e)}")
|
||||
+ earnings = None
|
||||
+
|
||||
+ # 格式化报告
|
||||
+ report = f"# {ticker} 基本面分析报告(Finnhub数据源)\n\n"
|
||||
+ report += f"**数据获取时间**: {curr_date}\n"
|
||||
+ report += f"**数据来源**: Finnhub API\n\n"
|
||||
+
|
||||
+ # 公司概况部分
|
||||
+ if company_profile:
|
||||
+ report += "## 公司概况\n"
|
||||
+ report += f"- **公司名称**: {company_profile.get('name', 'N/A')}\n"
|
||||
+ report += f"- **行业**: {company_profile.get('finnhubIndustry', 'N/A')}\n"
|
||||
+ report += f"- **国家**: {company_profile.get('country', 'N/A')}\n"
|
||||
+ report += f"- **货币**: {company_profile.get('currency', 'N/A')}\n"
|
||||
+ report += f"- **市值**: {company_profile.get('marketCapitalization', 'N/A')} 百万美元\n"
|
||||
+ report += f"- **流通股数**: {company_profile.get('shareOutstanding', 'N/A')} 百万股\n\n"
|
||||
+
|
||||
+ # 基本财务指标
|
||||
+ if basic_financials and 'metric' in basic_financials:
|
||||
+ metrics = basic_financials['metric']
|
||||
+ report += "## 关键财务指标\n"
|
||||
+ report += "| 指标 | 数值 |\n"
|
||||
+ report += "|------|------|\n"
|
||||
+
|
||||
+ # 估值指标
|
||||
+ if 'peBasicExclExtraTTM' in metrics:
|
||||
+ report += f"| 市盈率 (PE) | {metrics['peBasicExclExtraTTM']:.2f} |\n"
|
||||
+ if 'psAnnual' in metrics:
|
||||
+ report += f"| 市销率 (PS) | {metrics['psAnnual']:.2f} |\n"
|
||||
+ if 'pbAnnual' in metrics:
|
||||
+ report += f"| 市净率 (PB) | {metrics['pbAnnual']:.2f} |\n"
|
||||
+
|
||||
+ # 盈利能力指标
|
||||
+ if 'roeTTM' in metrics:
|
||||
+ report += f"| 净资产收益率 (ROE) | {metrics['roeTTM']:.2f}% |\n"
|
||||
+ if 'roaTTM' in metrics:
|
||||
+ report += f"| 总资产收益率 (ROA) | {metrics['roaTTM']:.2f}% |\n"
|
||||
+ if 'netProfitMarginTTM' in metrics:
|
||||
+ report += f"| 净利润率 | {metrics['netProfitMarginTTM']:.2f}% |\n"
|
||||
+
|
||||
+ # 财务健康指标
|
||||
+ if 'currentRatioAnnual' in metrics:
|
||||
+ report += f"| 流动比率 | {metrics['currentRatioAnnual']:.2f} |\n"
|
||||
+ if 'totalDebt/totalEquityAnnual' in metrics:
|
||||
+ report += f"| 负债权益比 | {metrics['totalDebt/totalEquityAnnual']:.2f} |\n"
|
||||
+
|
||||
+ report += "\n"
|
||||
+
|
||||
+ # 收益历史
|
||||
+ if earnings:
|
||||
+ report += "## 收益历史\n"
|
||||
+ report += "| 季度 | 实际EPS | 预期EPS | 差异 |\n"
|
||||
+ report += "|------|---------|---------|------|\n"
|
||||
+ for earning in earnings[:4]: # 显示最近4个季度
|
||||
+ actual = earning.get('actual', 'N/A')
|
||||
+ estimate = earning.get('estimate', 'N/A')
|
||||
+ period = earning.get('period', 'N/A')
|
||||
+ surprise = earning.get('surprise', 'N/A')
|
||||
+ report += f"| {period} | {actual} | {estimate} | {surprise} |\n"
|
||||
+ report += "\n"
|
||||
+
|
||||
+ # 数据可用性说明
|
||||
+ report += "## 数据说明\n"
|
||||
+ report += "- 本报告使用Finnhub API提供的官方财务数据\n"
|
||||
+ report += "- 数据来源于公司财报和SEC文件\n"
|
||||
+ report += "- TTM表示过去12个月数据\n"
|
||||
+ report += "- Annual表示年度数据\n\n"
|
||||
+
|
||||
+ if not basic_financials and not company_profile and not earnings:
|
||||
+ report += "⚠️ **警告**: 无法获取该股票的基本面数据,可能原因:\n"
|
||||
+ report += "- 股票代码不正确\n"
|
||||
+ report += "- Finnhub API限制\n"
|
||||
+ report += "- 该股票暂无基本面数据\n"
|
||||
+
|
||||
+ # 保存到缓存
|
||||
+ if report and len(report) > 100: # 只有当报告有实际内容时才缓存
|
||||
+ cache.save_fundamentals_data(ticker, report, data_source="finnhub")
|
||||
+
|
||||
+ print(f"📊 [DEBUG] Finnhub基本面数据获取完成,报告长度: {len(report)}")
|
||||
+ return report
|
||||
+
|
||||
+ except ImportError:
|
||||
+ return "错误:未安装finnhub-python库,请运行: pip install finnhub-python"
|
||||
+ except Exception as e:
|
||||
+ print(f"❌ [DEBUG] Finnhub基本面数据获取失败: {str(e)}")
|
||||
+ return f"Finnhub基本面数据获取失败: {str(e)}"
|
||||
+
|
||||
+
|
||||
def get_fundamentals_openai(ticker, curr_date):
|
||||
- config = get_config()
|
||||
- client = OpenAI(base_url=config["backend_url"])
|
||||
-
|
||||
- response = client.responses.create(
|
||||
- model=config["quick_think_llm"],
|
||||
- input=[
|
||||
- {
|
||||
- "role": "system",
|
||||
- "content": [
|
||||
- {
|
||||
- "type": "input_text",
|
||||
- "text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
|
||||
- }
|
||||
- ],
|
||||
- }
|
||||
- ],
|
||||
- text={"format": {"type": "text"}},
|
||||
- reasoning={},
|
||||
- tools=[
|
||||
- {
|
||||
- "type": "web_search_preview",
|
||||
- "user_location": {"type": "approximate"},
|
||||
- "search_context_size": "low",
|
||||
- }
|
||||
- ],
|
||||
- temperature=1,
|
||||
- max_output_tokens=4096,
|
||||
- top_p=1,
|
||||
- store=True,
|
||||
- )
|
||||
-
|
||||
- return response.output[1].content[0].text
|
||||
+ """
|
||||
+ 获取股票基本面数据,优先使用OpenAI,失败时回退到Finnhub API
|
||||
+ 支持缓存机制以提高性能
|
||||
+ Args:
|
||||
+ ticker (str): 股票代码
|
||||
+ curr_date (str): 当前日期,格式为yyyy-mm-dd
|
||||
+ Returns:
|
||||
+ str: 基本面数据报告
|
||||
+ """
|
||||
+ try:
|
||||
+ from .cache_manager import get_cache
|
||||
+
|
||||
+ # 检查缓存 - 优先检查OpenAI缓存
|
||||
+ cache = get_cache()
|
||||
+ cached_key = cache.find_cached_fundamentals_data(ticker, data_source="openai")
|
||||
+ if cached_key:
|
||||
+ cached_data = cache.load_fundamentals_data(cached_key)
|
||||
+ if cached_data:
|
||||
+ print(f"💾 [DEBUG] 从缓存加载OpenAI基本面数据: {ticker}")
|
||||
+ return cached_data
|
||||
+
|
||||
+ config = get_config()
|
||||
+
|
||||
+ # 检查是否配置了OpenAI相关设置
|
||||
+ if not config.get("backend_url") or not config.get("quick_think_llm"):
|
||||
+ print(f"📊 [DEBUG] OpenAI配置不完整,直接使用Finnhub API")
|
||||
+ return get_fundamentals_finnhub(ticker, curr_date)
|
||||
+
|
||||
+ print(f"📊 [DEBUG] 尝试使用OpenAI获取 {ticker} 的基本面数据...")
|
||||
+
|
||||
+ client = OpenAI(base_url=config["backend_url"])
|
||||
+
|
||||
+ response = client.responses.create(
|
||||
+ model=config["quick_think_llm"],
|
||||
+ input=[
|
||||
+ {
|
||||
+ "role": "system",
|
||||
+ "content": [
|
||||
+ {
|
||||
+ "type": "input_text",
|
||||
+ "text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
|
||||
+ }
|
||||
+ ],
|
||||
+ }
|
||||
+ ],
|
||||
+ text={"format": {"type": "text"}},
|
||||
+ reasoning={},
|
||||
+ tools=[
|
||||
+ {
|
||||
+ "type": "web_search_preview",
|
||||
+ "user_location": {"type": "approximate"},
|
||||
+ "search_context_size": "low",
|
||||
+ }
|
||||
+ ],
|
||||
+ temperature=1,
|
||||
+ max_output_tokens=4096,
|
||||
+ top_p=1,
|
||||
+ store=True,
|
||||
+ )
|
||||
+
|
||||
+ result = response.output[1].content[0].text
|
||||
+
|
||||
+ # 保存到缓存
|
||||
+ if result and len(result) > 100: # 只有当结果有实际内容时才缓存
|
||||
+ cache.save_fundamentals_data(ticker, result, data_source="openai")
|
||||
+
|
||||
+ print(f"📊 [DEBUG] OpenAI基本面数据获取成功,长度: {len(result)}")
|
||||
+ return result
|
||||
+
|
||||
+ except Exception as e:
|
||||
+ print(f"❌ [DEBUG] OpenAI基本面数据获取失败: {str(e)}")
|
||||
+ print(f"📊 [DEBUG] 回退到Finnhub API...")
|
||||
+ return get_fundamentals_finnhub(ticker, curr_date)
|
||||
|
|
@ -0,0 +1,398 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
优化的A股数据获取工具
|
||||
集成缓存策略和通达信API,提高数据获取效率
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from .cache_manager import get_cache
|
||||
from .config import get_config
|
||||
|
||||
|
||||
class OptimizedChinaDataProvider:
|
||||
"""优化的A股数据提供器 - 集成缓存和通达信API"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = get_cache()
|
||||
self.config = get_config()
|
||||
self.last_api_call = 0
|
||||
self.min_api_interval = 0.5 # 通达信API调用间隔较短
|
||||
|
||||
print("📊 优化A股数据提供器初始化完成")
|
||||
|
||||
def _wait_for_rate_limit(self):
|
||||
"""等待API限制"""
|
||||
current_time = time.time()
|
||||
time_since_last_call = current_time - self.last_api_call
|
||||
|
||||
if time_since_last_call < self.min_api_interval:
|
||||
wait_time = self.min_api_interval - time_since_last_call
|
||||
time.sleep(wait_time)
|
||||
|
||||
self.last_api_call = time.time()
|
||||
|
||||
def get_stock_data(self, symbol: str, start_date: str, end_date: str,
|
||||
force_refresh: bool = False) -> str:
|
||||
"""
|
||||
获取A股数据 - 优先使用缓存
|
||||
|
||||
Args:
|
||||
symbol: 股票代码(6位数字)
|
||||
start_date: 开始日期 (YYYY-MM-DD)
|
||||
end_date: 结束日期 (YYYY-MM-DD)
|
||||
force_refresh: 是否强制刷新缓存
|
||||
|
||||
Returns:
|
||||
格式化的股票数据字符串
|
||||
"""
|
||||
print(f"📈 获取A股数据: {symbol} ({start_date} 到 {end_date})")
|
||||
|
||||
# 检查缓存(除非强制刷新)
|
||||
if not force_refresh:
|
||||
cache_key = self.cache.find_cached_stock_data(
|
||||
symbol=symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
data_source="tdx"
|
||||
)
|
||||
|
||||
if cache_key:
|
||||
cached_data = self.cache.load_stock_data(cache_key)
|
||||
if cached_data:
|
||||
print(f"⚡ 从缓存加载A股数据: {symbol}")
|
||||
return cached_data
|
||||
|
||||
# 缓存未命中,从通达信API获取
|
||||
print(f"🌐 从通达信API获取数据: {symbol}")
|
||||
|
||||
try:
|
||||
# API限制处理
|
||||
self._wait_for_rate_limit()
|
||||
|
||||
# 调用通达信API
|
||||
from .tdx_utils import get_china_stock_data
|
||||
|
||||
formatted_data = get_china_stock_data(
|
||||
stock_code=symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
# 检查是否获取成功
|
||||
if "❌" in formatted_data or "错误" in formatted_data:
|
||||
print(f"❌ 通达信API调用失败: {symbol}")
|
||||
# 尝试从旧缓存获取数据
|
||||
old_cache = self._try_get_old_cache(symbol, start_date, end_date)
|
||||
if old_cache:
|
||||
print(f"📁 使用过期缓存数据: {symbol}")
|
||||
return old_cache
|
||||
|
||||
# 生成备用数据
|
||||
return self._generate_fallback_data(symbol, start_date, end_date, "通达信API调用失败")
|
||||
|
||||
# 保存到缓存
|
||||
self.cache.save_stock_data(
|
||||
symbol=symbol,
|
||||
data=formatted_data,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
data_source="tdx"
|
||||
)
|
||||
|
||||
print(f"✅ A股数据获取成功: {symbol}")
|
||||
return formatted_data
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"通达信API调用异常: {str(e)}"
|
||||
print(f"❌ {error_msg}")
|
||||
|
||||
# 尝试从旧缓存获取数据
|
||||
old_cache = self._try_get_old_cache(symbol, start_date, end_date)
|
||||
if old_cache:
|
||||
print(f"📁 使用过期缓存数据: {symbol}")
|
||||
return old_cache
|
||||
|
||||
# 生成备用数据
|
||||
return self._generate_fallback_data(symbol, start_date, end_date, error_msg)
|
||||
|
||||
def get_fundamentals_data(self, symbol: str, force_refresh: bool = False) -> str:
|
||||
"""
|
||||
获取A股基本面数据 - 优先使用缓存
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
force_refresh: 是否强制刷新缓存
|
||||
|
||||
Returns:
|
||||
格式化的基本面数据字符串
|
||||
"""
|
||||
print(f"📊 获取A股基本面数据: {symbol}")
|
||||
|
||||
# 检查缓存(除非强制刷新)
|
||||
if not force_refresh:
|
||||
# 查找基本面数据缓存
|
||||
for metadata_file in self.cache.metadata_dir.glob(f"*_meta.json"):
|
||||
try:
|
||||
import json
|
||||
with open(metadata_file, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
if (metadata.get('symbol') == symbol and
|
||||
metadata.get('data_type') == 'fundamentals' and
|
||||
metadata.get('market_type') == 'china'):
|
||||
|
||||
cache_key = metadata_file.stem.replace('_meta', '')
|
||||
if self.cache.is_cache_valid(cache_key, symbol=symbol, data_type='fundamentals'):
|
||||
cached_data = self.cache.load_stock_data(cache_key)
|
||||
if cached_data:
|
||||
print(f"⚡ 从缓存加载A股基本面数据: {symbol}")
|
||||
return cached_data
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 缓存未命中,生成基本面分析
|
||||
print(f"🔍 生成A股基本面分析: {symbol}")
|
||||
|
||||
try:
|
||||
# 先获取股票数据
|
||||
current_date = datetime.now().strftime('%Y-%m-%d')
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d')
|
||||
|
||||
stock_data = self.get_stock_data(symbol, start_date, current_date)
|
||||
|
||||
# 生成基本面分析报告
|
||||
fundamentals_data = self._generate_fundamentals_report(symbol, stock_data)
|
||||
|
||||
# 保存到缓存
|
||||
self.cache.save_fundamentals_data(
|
||||
symbol=symbol,
|
||||
fundamentals_data=fundamentals_data,
|
||||
data_source="tdx_analysis"
|
||||
)
|
||||
|
||||
print(f"✅ A股基本面数据生成成功: {symbol}")
|
||||
return fundamentals_data
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"基本面数据生成失败: {str(e)}"
|
||||
print(f"❌ {error_msg}")
|
||||
return self._generate_fallback_fundamentals(symbol, error_msg)
|
||||
|
||||
def _generate_fundamentals_report(self, symbol: str, stock_data: str) -> str:
|
||||
"""基于股票数据生成基本面分析报告"""
|
||||
|
||||
# 从股票数据中提取信息
|
||||
company_name = "未知公司"
|
||||
current_price = "N/A"
|
||||
|
||||
if "股票名称:" in stock_data:
|
||||
lines = stock_data.split('\n')
|
||||
for line in lines:
|
||||
if "股票名称:" in line:
|
||||
company_name = line.split(':')[1].strip()
|
||||
elif "当前价格:" in line:
|
||||
current_price = line.split(':')[1].strip()
|
||||
|
||||
report = f"""# 中国A股基本面分析报告 - {symbol}({company_name})
|
||||
|
||||
## 公司基本信息
|
||||
- 股票代码:{symbol}
|
||||
- 股票名称:{company_name}
|
||||
- 行业分类:根据股票代码判断所属行业
|
||||
- 所属市场:深圳证券交易所/上海证券交易所
|
||||
- 最新股价:{current_price}
|
||||
- 分析日期:{datetime.now().strftime('%Y年%m月%d日')}
|
||||
|
||||
## 财务状况分析
|
||||
基于最新的市场数据和技术指标分析:
|
||||
|
||||
### 资产负债表分析
|
||||
- **总资产规模**:作为A股上市公司,具备一定的资产规模
|
||||
- **负债结构**:需要关注资产负债率和流动比率
|
||||
- **股东权益**:关注净资产收益率和每股净资产
|
||||
|
||||
### 现金流分析
|
||||
- **经营现金流**:关注主营业务现金流入情况
|
||||
- **投资现金流**:分析公司投资扩张策略
|
||||
- **筹资现金流**:关注融资结构和偿债能力
|
||||
|
||||
## 盈利能力分析
|
||||
### 收入分析
|
||||
- **营业收入增长率**:关注收入增长趋势
|
||||
- **主营业务收入占比**:分析业务集中度
|
||||
- **收入季节性**:识别业务周期性特征
|
||||
|
||||
### 利润分析
|
||||
- **毛利率水平**:反映产品竞争力
|
||||
- **净利润率**:体现整体盈利能力
|
||||
- **ROE(净资产收益率)**:衡量股东回报水平
|
||||
|
||||
## 成长性分析
|
||||
### 历史成长性
|
||||
- **营收复合增长率**:过去3-5年的收入增长情况
|
||||
- **净利润增长率**:盈利增长的可持续性
|
||||
- **市场份额变化**:在行业中的竞争地位
|
||||
|
||||
### 未来成长潜力
|
||||
- **行业发展前景**:所处行业的成长空间
|
||||
- **公司战略规划**:未来发展方向和投资计划
|
||||
- **创新能力**:研发投入和技术优势
|
||||
|
||||
## 估值分析
|
||||
### 相对估值
|
||||
- **市盈率(PE)**:与同行业公司对比
|
||||
- **市净率(PB)**:相对于净资产的估值水平
|
||||
- **市销率(PS)**:相对于营业收入的估值
|
||||
|
||||
### 绝对估值
|
||||
- **DCF估值**:基于现金流贴现的内在价值
|
||||
- **资产价值**:净资产重估价值
|
||||
- **分红收益率**:股息回报分析
|
||||
|
||||
## 风险分析
|
||||
### 系统性风险
|
||||
- **宏观经济风险**:经济周期对公司的影响
|
||||
- **政策风险**:行业政策变化的影响
|
||||
- **市场风险**:股市波动对估值的影响
|
||||
|
||||
### 非系统性风险
|
||||
- **经营风险**:公司特有的经营风险
|
||||
- **财务风险**:债务结构和偿债能力风险
|
||||
- **管理风险**:管理层变动和决策风险
|
||||
|
||||
## 投资建议
|
||||
### 综合评价
|
||||
基于以上分析,该股票的投资价值评估:
|
||||
|
||||
**优势:**
|
||||
- A股市场上市公司,监管相对完善
|
||||
- 具备一定的市场地位和品牌价值
|
||||
- 财务信息透明度较高
|
||||
|
||||
**风险:**
|
||||
- 需要关注宏观经济环境变化
|
||||
- 行业竞争加剧的影响
|
||||
- 政策调整对业务的潜在影响
|
||||
|
||||
### 操作建议
|
||||
- **投资策略**:建议采用价值投资策略,关注长期基本面
|
||||
- **仓位建议**:根据风险承受能力合理配置仓位
|
||||
- **关注指标**:重点关注ROE、PE、现金流等核心指标
|
||||
|
||||
---
|
||||
*注:本报告基于公开信息和技术分析生成,仅供参考,不构成投资建议。投资有风险,入市需谨慎。*
|
||||
|
||||
数据来源:通达信API + 基本面分析
|
||||
生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
"""
|
||||
|
||||
return report
|
||||
|
||||
def _try_get_old_cache(self, symbol: str, start_date: str, end_date: str) -> Optional[str]:
|
||||
"""尝试获取过期的缓存数据作为备用"""
|
||||
try:
|
||||
# 查找任何相关的缓存,不考虑TTL
|
||||
for metadata_file in self.cache.metadata_dir.glob(f"*_meta.json"):
|
||||
try:
|
||||
import json
|
||||
with open(metadata_file, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
if (metadata.get('symbol') == symbol and
|
||||
metadata.get('data_type') == 'stock_data' and
|
||||
metadata.get('market_type') == 'china'):
|
||||
|
||||
cache_key = metadata_file.stem.replace('_meta', '')
|
||||
cached_data = self.cache.load_stock_data(cache_key)
|
||||
if cached_data:
|
||||
return cached_data + "\n\n⚠️ 注意: 使用的是过期缓存数据"
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _generate_fallback_data(self, symbol: str, start_date: str, end_date: str, error_msg: str) -> str:
|
||||
"""生成备用数据"""
|
||||
return f"""# {symbol} A股数据获取失败
|
||||
|
||||
## ❌ 错误信息
|
||||
{error_msg}
|
||||
|
||||
## 📊 模拟数据(仅供演示)
|
||||
- 股票代码: {symbol}
|
||||
- 股票名称: 模拟公司
|
||||
- 数据期间: {start_date} 至 {end_date}
|
||||
- 模拟价格: ¥{random.uniform(10, 50):.2f}
|
||||
- 模拟涨跌: {random.uniform(-5, 5):+.2f}%
|
||||
|
||||
## ⚠️ 重要提示
|
||||
由于通达信API限制或网络问题,无法获取实时数据。
|
||||
建议稍后重试或检查网络连接。
|
||||
|
||||
生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
"""
|
||||
|
||||
def _generate_fallback_fundamentals(self, symbol: str, error_msg: str) -> str:
|
||||
"""生成备用基本面数据"""
|
||||
return f"""# {symbol} A股基本面分析失败
|
||||
|
||||
## ❌ 错误信息
|
||||
{error_msg}
|
||||
|
||||
## 📊 基本信息
|
||||
- 股票代码: {symbol}
|
||||
- 分析状态: 数据获取失败
|
||||
- 建议: 稍后重试或检查网络连接
|
||||
|
||||
生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
"""
|
||||
|
||||
|
||||
# 全局实例
|
||||
_china_data_provider = None
|
||||
|
||||
def get_optimized_china_data_provider() -> OptimizedChinaDataProvider:
|
||||
"""获取全局A股数据提供器实例"""
|
||||
global _china_data_provider
|
||||
if _china_data_provider is None:
|
||||
_china_data_provider = OptimizedChinaDataProvider()
|
||||
return _china_data_provider
|
||||
|
||||
|
||||
def get_china_stock_data_cached(symbol: str, start_date: str, end_date: str,
|
||||
force_refresh: bool = False) -> str:
|
||||
"""
|
||||
获取A股数据的便捷函数
|
||||
|
||||
Args:
|
||||
symbol: 股票代码(6位数字)
|
||||
start_date: 开始日期 (YYYY-MM-DD)
|
||||
end_date: 结束日期 (YYYY-MM-DD)
|
||||
force_refresh: 是否强制刷新缓存
|
||||
|
||||
Returns:
|
||||
格式化的股票数据字符串
|
||||
"""
|
||||
provider = get_optimized_china_data_provider()
|
||||
return provider.get_stock_data(symbol, start_date, end_date, force_refresh)
|
||||
|
||||
|
||||
def get_china_fundamentals_cached(symbol: str, force_refresh: bool = False) -> str:
|
||||
"""
|
||||
获取A股基本面数据的便捷函数
|
||||
|
||||
Args:
|
||||
symbol: 股票代码(6位数字)
|
||||
force_refresh: 是否强制刷新缓存
|
||||
|
||||
Returns:
|
||||
格式化的基本面数据字符串
|
||||
"""
|
||||
provider = get_optimized_china_data_provider()
|
||||
return provider.get_fundamentals_data(symbol, force_refresh)
|
||||
|
|
@ -0,0 +1,404 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Optimized US Stock Data Fetcher
|
||||
Integrates caching strategy to reduce API calls and improve response speed
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import yfinance as yf
|
||||
import pandas as pd
|
||||
from .cache_manager import get_cache
|
||||
from .config import get_config
|
||||
|
||||
|
||||
class OptimizedUSDataProvider:
|
||||
"""Optimized US Stock Data Provider - Integrates caching and API rate limiting"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = get_cache()
|
||||
self.config = get_config()
|
||||
self.last_api_call = 0
|
||||
self.min_api_interval = 1.0 # Minimum API call interval (seconds)
|
||||
|
||||
print("📊 Optimized US stock data provider initialized")
|
||||
|
||||
def _wait_for_rate_limit(self):
|
||||
"""Wait for API rate limit"""
|
||||
current_time = time.time()
|
||||
time_since_last_call = current_time - self.last_api_call
|
||||
|
||||
if time_since_last_call < self.min_api_interval:
|
||||
wait_time = self.min_api_interval - time_since_last_call
|
||||
print(f"⏳ API rate limit wait {wait_time:.1f}s...")
|
||||
time.sleep(wait_time)
|
||||
|
||||
self.last_api_call = time.time()
|
||||
|
||||
def get_stock_data(self, symbol: str, start_date: str, end_date: str,
|
||||
force_refresh: bool = False) -> str:
|
||||
"""
|
||||
Get US stock data - prioritize cache usage
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol
|
||||
start_date: Start date (YYYY-MM-DD)
|
||||
end_date: End date (YYYY-MM-DD)
|
||||
force_refresh: Whether to force refresh cache
|
||||
|
||||
Returns:
|
||||
Formatted stock data string
|
||||
"""
|
||||
try:
|
||||
# Check cache first (unless force refresh)
|
||||
if not force_refresh:
|
||||
cache_key = self.cache.find_cached_stock_data(
|
||||
symbol, start_date, end_date, "optimized_yfinance"
|
||||
)
|
||||
|
||||
if cache_key and self.cache.is_cache_valid(cache_key, symbol):
|
||||
cached_data = self.cache.load_stock_data(cache_key)
|
||||
if cached_data:
|
||||
print(f"📖 Using cached data for {symbol}")
|
||||
if isinstance(cached_data, pd.DataFrame):
|
||||
return self._format_stock_data(cached_data, symbol)
|
||||
else:
|
||||
return cached_data
|
||||
|
||||
# Fetch new data from API
|
||||
print(f"🌐 Fetching new data for {symbol} from {start_date} to {end_date}")
|
||||
|
||||
# Wait for rate limit
|
||||
self._wait_for_rate_limit()
|
||||
|
||||
# Try Yahoo Finance first
|
||||
try:
|
||||
data = self._fetch_from_yfinance(symbol, start_date, end_date)
|
||||
if data is not None and not data.empty:
|
||||
# Cache the DataFrame
|
||||
cache_key = self.cache.save_stock_data(
|
||||
symbol, data, start_date, end_date, "optimized_yfinance"
|
||||
)
|
||||
|
||||
# Format and return
|
||||
formatted_data = self._format_stock_data(data, symbol)
|
||||
print(f"✅ Successfully fetched and cached data for {symbol}")
|
||||
return formatted_data
|
||||
else:
|
||||
print(f"⚠️ No data returned from Yahoo Finance for {symbol}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Yahoo Finance error for {symbol}: {e}")
|
||||
|
||||
# Fallback: Try FINNHUB (if API key available)
|
||||
try:
|
||||
finnhub_data = self._fetch_from_finnhub(symbol, start_date, end_date)
|
||||
if finnhub_data:
|
||||
# Cache the string data
|
||||
cache_key = self.cache.save_stock_data(
|
||||
symbol, finnhub_data, start_date, end_date, "optimized_finnhub"
|
||||
)
|
||||
print(f"✅ Successfully fetched data from FINNHUB for {symbol}")
|
||||
return finnhub_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ FINNHUB error for {symbol}: {e}")
|
||||
|
||||
# If all fails, return error message
|
||||
error_msg = f"❌ Failed to fetch data for {symbol} from {start_date} to {end_date}"
|
||||
print(error_msg)
|
||||
return error_msg
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"❌ Unexpected error fetching data for {symbol}: {e}"
|
||||
print(error_msg)
|
||||
return error_msg
|
||||
|
||||
def _fetch_from_yfinance(self, symbol: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
"""Fetch data from Yahoo Finance"""
|
||||
try:
|
||||
ticker = yf.Ticker(symbol)
|
||||
data = ticker.history(start=start_date, end=end_date)
|
||||
|
||||
if data.empty:
|
||||
print(f"⚠️ No data available for {symbol} in the specified date range")
|
||||
return None
|
||||
|
||||
# Reset index to make Date a column
|
||||
data = data.reset_index()
|
||||
|
||||
# Ensure we have the required columns
|
||||
required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
|
||||
missing_columns = [col for col in required_columns if col not in data.columns]
|
||||
|
||||
if missing_columns:
|
||||
print(f"⚠️ Missing columns for {symbol}: {missing_columns}")
|
||||
return None
|
||||
|
||||
return data[required_columns]
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Yahoo Finance fetch error for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _fetch_from_finnhub(self, symbol: str, start_date: str, end_date: str) -> Optional[str]:
|
||||
"""Fetch data from FINNHUB API"""
|
||||
try:
|
||||
# Check if FINNHUB API key is available
|
||||
finnhub_api_key = os.getenv('FINNHUB_API_KEY')
|
||||
if not finnhub_api_key:
|
||||
print("⚠️ FINNHUB API key not found, skipping FINNHUB data fetch")
|
||||
return None
|
||||
|
||||
import finnhub
|
||||
|
||||
# Initialize FINNHUB client
|
||||
finnhub_client = finnhub.Client(api_key=finnhub_api_key)
|
||||
|
||||
# Convert dates to timestamps
|
||||
start_timestamp = int(datetime.strptime(start_date, '%Y-%m-%d').timestamp())
|
||||
end_timestamp = int(datetime.strptime(end_date, '%Y-%m-%d').timestamp())
|
||||
|
||||
# Fetch candle data
|
||||
candle_data = finnhub_client.stock_candles(symbol, 'D', start_timestamp, end_timestamp)
|
||||
|
||||
if candle_data['s'] != 'ok':
|
||||
print(f"⚠️ FINNHUB returned status: {candle_data['s']} for {symbol}")
|
||||
return None
|
||||
|
||||
# Format data
|
||||
formatted_data = self._format_finnhub_data(candle_data, symbol)
|
||||
return formatted_data
|
||||
|
||||
except ImportError:
|
||||
print("⚠️ finnhub-python package not installed, skipping FINNHUB data fetch")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"❌ FINNHUB fetch error for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _format_stock_data(self, data: pd.DataFrame, symbol: str) -> str:
|
||||
"""Format DataFrame stock data into string"""
|
||||
try:
|
||||
# Ensure Date column is properly formatted
|
||||
if 'Date' in data.columns:
|
||||
data['Date'] = pd.to_datetime(data['Date']).dt.strftime('%Y-%m-%d')
|
||||
|
||||
# Round numerical columns to 2 decimal places
|
||||
numeric_columns = ['Open', 'High', 'Low', 'Close']
|
||||
for col in numeric_columns:
|
||||
if col in data.columns:
|
||||
data[col] = data[col].round(2)
|
||||
|
||||
# Format volume as integer
|
||||
if 'Volume' in data.columns:
|
||||
data['Volume'] = data['Volume'].astype(int)
|
||||
|
||||
# Create formatted string
|
||||
formatted_lines = [f"Stock Data for {symbol}:"]
|
||||
formatted_lines.append("Date,Open,High,Low,Close,Volume")
|
||||
|
||||
for _, row in data.iterrows():
|
||||
line = f"{row['Date']},{row['Open']},{row['High']},{row['Low']},{row['Close']},{row['Volume']}"
|
||||
formatted_lines.append(line)
|
||||
|
||||
# Add summary statistics
|
||||
if len(data) > 0:
|
||||
formatted_lines.append(f"\nSummary for {symbol}:")
|
||||
formatted_lines.append(f"Period: {data['Date'].iloc[0]} to {data['Date'].iloc[-1]}")
|
||||
formatted_lines.append(f"Total trading days: {len(data)}")
|
||||
formatted_lines.append(f"Average volume: {data['Volume'].mean():,.0f}")
|
||||
formatted_lines.append(f"Price range: ${data['Low'].min():.2f} - ${data['High'].max():.2f}")
|
||||
|
||||
# Calculate basic statistics
|
||||
start_price = data['Open'].iloc[0]
|
||||
end_price = data['Close'].iloc[-1]
|
||||
price_change = end_price - start_price
|
||||
price_change_pct = (price_change / start_price) * 100
|
||||
|
||||
formatted_lines.append(f"Period return: {price_change_pct:+.2f}% (${price_change:+.2f})")
|
||||
|
||||
return "\n".join(formatted_lines)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error formatting stock data for {symbol}: {e}")
|
||||
return f"Error formatting data for {symbol}: {str(e)}"
|
||||
|
||||
def _format_finnhub_data(self, candle_data: Dict, symbol: str) -> str:
|
||||
"""Format FINNHUB candle data into string"""
|
||||
try:
|
||||
# Extract data arrays
|
||||
timestamps = candle_data['t']
|
||||
opens = candle_data['o']
|
||||
highs = candle_data['h']
|
||||
lows = candle_data['l']
|
||||
closes = candle_data['c']
|
||||
volumes = candle_data['v']
|
||||
|
||||
# Create formatted string
|
||||
formatted_lines = [f"Stock Data for {symbol} (FINNHUB):"]
|
||||
formatted_lines.append("Date,Open,High,Low,Close,Volume")
|
||||
|
||||
for i in range(len(timestamps)):
|
||||
date = datetime.fromtimestamp(timestamps[i]).strftime('%Y-%m-%d')
|
||||
line = f"{date},{opens[i]:.2f},{highs[i]:.2f},{lows[i]:.2f},{closes[i]:.2f},{int(volumes[i])}"
|
||||
formatted_lines.append(line)
|
||||
|
||||
# Add summary
|
||||
if len(timestamps) > 0:
|
||||
start_date = datetime.fromtimestamp(timestamps[0]).strftime('%Y-%m-%d')
|
||||
end_date = datetime.fromtimestamp(timestamps[-1]).strftime('%Y-%m-%d')
|
||||
|
||||
formatted_lines.append(f"\nSummary for {symbol}:")
|
||||
formatted_lines.append(f"Period: {start_date} to {end_date}")
|
||||
formatted_lines.append(f"Total trading days: {len(timestamps)}")
|
||||
formatted_lines.append(f"Average volume: {sum(volumes)/len(volumes):,.0f}")
|
||||
formatted_lines.append(f"Price range: ${min(lows):.2f} - ${max(highs):.2f}")
|
||||
|
||||
# Calculate return
|
||||
price_change = closes[-1] - opens[0]
|
||||
price_change_pct = (price_change / opens[0]) * 100
|
||||
formatted_lines.append(f"Period return: {price_change_pct:+.2f}% (${price_change:+.2f})")
|
||||
|
||||
return "\n".join(formatted_lines)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error formatting FINNHUB data for {symbol}: {e}")
|
||||
return f"Error formatting FINNHUB data for {symbol}: {str(e)}"
|
||||
|
||||
def get_stock_with_indicators(self, symbol: str, start_date: str, end_date: str,
|
||||
indicators: list = None) -> str:
|
||||
"""
|
||||
Get stock data with technical indicators
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol
|
||||
start_date: Start date (YYYY-MM-DD)
|
||||
end_date: End date (YYYY-MM-DD)
|
||||
indicators: List of indicators to calculate ['sma_20', 'rsi', 'macd']
|
||||
|
||||
Returns:
|
||||
Formatted stock data with indicators
|
||||
"""
|
||||
try:
|
||||
# Get basic stock data
|
||||
basic_data = self.get_stock_data(symbol, start_date, end_date)
|
||||
|
||||
if basic_data.startswith("❌"):
|
||||
return basic_data
|
||||
|
||||
# If no indicators requested, return basic data
|
||||
if not indicators:
|
||||
return basic_data
|
||||
|
||||
# Fetch DataFrame for indicator calculation
|
||||
data_df = self._fetch_from_yfinance(symbol, start_date, end_date)
|
||||
if data_df is None or data_df.empty:
|
||||
return basic_data
|
||||
|
||||
# Calculate indicators
|
||||
indicator_data = self._calculate_indicators(data_df, indicators)
|
||||
|
||||
# Combine basic data with indicators
|
||||
combined_data = basic_data + "\n\nTechnical Indicators:\n" + indicator_data
|
||||
|
||||
return combined_data
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"❌ Error getting stock data with indicators for {symbol}: {e}"
|
||||
print(error_msg)
|
||||
return error_msg
|
||||
|
||||
def _calculate_indicators(self, data: pd.DataFrame, indicators: list) -> str:
|
||||
"""Calculate technical indicators"""
|
||||
try:
|
||||
indicator_lines = []
|
||||
|
||||
for indicator in indicators:
|
||||
if indicator == 'sma_20':
|
||||
data['SMA_20'] = data['Close'].rolling(window=20).mean()
|
||||
latest_sma = data['SMA_20'].iloc[-1]
|
||||
indicator_lines.append(f"SMA(20): ${latest_sma:.2f}")
|
||||
|
||||
elif indicator == 'rsi':
|
||||
# Simple RSI calculation
|
||||
delta = data['Close'].diff()
|
||||
gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
|
||||
loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
|
||||
rs = gain / loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
latest_rsi = rsi.iloc[-1]
|
||||
indicator_lines.append(f"RSI(14): {latest_rsi:.2f}")
|
||||
|
||||
elif indicator == 'macd':
|
||||
# Simple MACD calculation
|
||||
ema_12 = data['Close'].ewm(span=12).mean()
|
||||
ema_26 = data['Close'].ewm(span=26).mean()
|
||||
macd_line = ema_12 - ema_26
|
||||
signal_line = macd_line.ewm(span=9).mean()
|
||||
latest_macd = macd_line.iloc[-1]
|
||||
latest_signal = signal_line.iloc[-1]
|
||||
indicator_lines.append(f"MACD: {latest_macd:.4f}, Signal: {latest_signal:.4f}")
|
||||
|
||||
return "\n".join(indicator_lines)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error calculating indicators: {e}")
|
||||
return f"Error calculating indicators: {str(e)}"
|
||||
|
||||
|
||||
# Global provider instance
|
||||
_global_provider = None
|
||||
|
||||
def get_optimized_us_data_provider() -> OptimizedUSDataProvider:
|
||||
"""
|
||||
Get global optimized US data provider instance
|
||||
|
||||
Returns:
|
||||
OptimizedUSDataProvider instance
|
||||
"""
|
||||
global _global_provider
|
||||
if _global_provider is None:
|
||||
_global_provider = OptimizedUSDataProvider()
|
||||
return _global_provider
|
||||
|
||||
|
||||
# Convenience functions
|
||||
def get_optimized_stock_data(symbol: str, start_date: str, end_date: str,
|
||||
force_refresh: bool = False) -> str:
|
||||
"""Get optimized stock data (convenience function)"""
|
||||
provider = get_optimized_us_data_provider()
|
||||
return provider.get_stock_data(symbol, start_date, end_date, force_refresh)
|
||||
|
||||
|
||||
def get_stock_with_indicators(symbol: str, start_date: str, end_date: str,
|
||||
indicators: list = None) -> str:
|
||||
"""Get stock data with technical indicators (convenience function)"""
|
||||
provider = get_optimized_us_data_provider()
|
||||
return provider.get_stock_with_indicators(symbol, start_date, end_date, indicators)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test the optimized data provider
|
||||
print("🧪 Testing Optimized US Data Provider...")
|
||||
|
||||
# Initialize provider
|
||||
provider = OptimizedUSDataProvider()
|
||||
|
||||
# Test data fetch
|
||||
data = provider.get_stock_data("AAPL", "2024-01-01", "2024-01-31")
|
||||
print("Sample data:")
|
||||
print(data[:500] + "..." if len(data) > 500 else data)
|
||||
|
||||
# Test with indicators
|
||||
data_with_indicators = provider.get_stock_with_indicators(
|
||||
"AAPL", "2024-01-01", "2024-01-31",
|
||||
indicators=['sma_20', 'rsi', 'macd']
|
||||
)
|
||||
print("\nData with indicators:")
|
||||
print(data_with_indicators[-500:] if len(data_with_indicators) > 500 else data_with_indicators)
|
||||
|
||||
print("✅ Optimized data provider test completed!")
|
||||
|
|
@ -0,0 +1,404 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Optimized US Stock Data Fetcher
|
||||
Integrates caching strategy to reduce API calls and improve response speed
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import yfinance as yf
|
||||
import pandas as pd
|
||||
from .cache_manager import get_cache
|
||||
from .config import get_config
|
||||
|
||||
|
||||
class OptimizedUSDataProvider:
|
||||
"""Optimized US Stock Data Provider - Integrates caching and API rate limiting"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = get_cache()
|
||||
self.config = get_config()
|
||||
self.last_api_call = 0
|
||||
self.min_api_interval = 1.0 # Minimum API call interval (seconds)
|
||||
|
||||
print("📊 Optimized US stock data provider initialized")
|
||||
|
||||
def _wait_for_rate_limit(self):
|
||||
"""Wait for API rate limit"""
|
||||
current_time = time.time()
|
||||
time_since_last_call = current_time - self.last_api_call
|
||||
|
||||
if time_since_last_call < self.min_api_interval:
|
||||
wait_time = self.min_api_interval - time_since_last_call
|
||||
print(f"⏳ API rate limit wait {wait_time:.1f}s...")
|
||||
time.sleep(wait_time)
|
||||
|
||||
self.last_api_call = time.time()
|
||||
|
||||
def get_stock_data(self, symbol: str, start_date: str, end_date: str,
|
||||
force_refresh: bool = False) -> str:
|
||||
"""
|
||||
Get US stock data - prioritize cache usage
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol
|
||||
start_date: Start date (YYYY-MM-DD)
|
||||
end_date: End date (YYYY-MM-DD)
|
||||
force_refresh: Whether to force refresh cache
|
||||
|
||||
Returns:
|
||||
Formatted stock data string
|
||||
"""
|
||||
try:
|
||||
# Check cache first (unless force refresh)
|
||||
if not force_refresh:
|
||||
cache_key = self.cache.find_cached_stock_data(
|
||||
symbol, start_date, end_date, "optimized_yfinance"
|
||||
)
|
||||
|
||||
if cache_key and self.cache.is_cache_valid(cache_key, symbol):
|
||||
cached_data = self.cache.load_stock_data(cache_key)
|
||||
if cached_data:
|
||||
print(f"📖 Using cached data for {symbol}")
|
||||
if isinstance(cached_data, pd.DataFrame):
|
||||
return self._format_stock_data(cached_data, symbol)
|
||||
else:
|
||||
return cached_data
|
||||
|
||||
# Fetch new data from API
|
||||
print(f"🌐 Fetching new data for {symbol} from {start_date} to {end_date}")
|
||||
|
||||
# Wait for rate limit
|
||||
self._wait_for_rate_limit()
|
||||
|
||||
# Try Yahoo Finance first
|
||||
try:
|
||||
data = self._fetch_from_yfinance(symbol, start_date, end_date)
|
||||
if data is not None and not data.empty:
|
||||
# Cache the DataFrame
|
||||
cache_key = self.cache.save_stock_data(
|
||||
symbol, data, start_date, end_date, "optimized_yfinance"
|
||||
)
|
||||
|
||||
# Format and return
|
||||
formatted_data = self._format_stock_data(data, symbol)
|
||||
print(f"✅ Successfully fetched and cached data for {symbol}")
|
||||
return formatted_data
|
||||
else:
|
||||
print(f"⚠️ No data returned from Yahoo Finance for {symbol}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Yahoo Finance error for {symbol}: {e}")
|
||||
|
||||
# Fallback: Try FINNHUB (if API key available)
|
||||
try:
|
||||
finnhub_data = self._fetch_from_finnhub(symbol, start_date, end_date)
|
||||
if finnhub_data:
|
||||
# Cache the string data
|
||||
cache_key = self.cache.save_stock_data(
|
||||
symbol, finnhub_data, start_date, end_date, "optimized_finnhub"
|
||||
)
|
||||
print(f"✅ Successfully fetched data from FINNHUB for {symbol}")
|
||||
return finnhub_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ FINNHUB error for {symbol}: {e}")
|
||||
|
||||
# If all fails, return error message
|
||||
error_msg = f"❌ Failed to fetch data for {symbol} from {start_date} to {end_date}"
|
||||
print(error_msg)
|
||||
return error_msg
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"❌ Unexpected error fetching data for {symbol}: {e}"
|
||||
print(error_msg)
|
||||
return error_msg
|
||||
|
||||
def _fetch_from_yfinance(self, symbol: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
"""Fetch data from Yahoo Finance"""
|
||||
try:
|
||||
ticker = yf.Ticker(symbol)
|
||||
data = ticker.history(start=start_date, end=end_date)
|
||||
|
||||
if data.empty:
|
||||
print(f"⚠️ No data available for {symbol} in the specified date range")
|
||||
return None
|
||||
|
||||
# Reset index to make Date a column
|
||||
data = data.reset_index()
|
||||
|
||||
# Ensure we have the required columns
|
||||
required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
|
||||
missing_columns = [col for col in required_columns if col not in data.columns]
|
||||
|
||||
if missing_columns:
|
||||
print(f"⚠️ Missing columns for {symbol}: {missing_columns}")
|
||||
return None
|
||||
|
||||
return data[required_columns]
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Yahoo Finance fetch error for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _fetch_from_finnhub(self, symbol: str, start_date: str, end_date: str) -> Optional[str]:
|
||||
"""Fetch data from FINNHUB API"""
|
||||
try:
|
||||
# Check if FINNHUB API key is available
|
||||
finnhub_api_key = os.getenv('FINNHUB_API_KEY')
|
||||
if not finnhub_api_key:
|
||||
print("⚠️ FINNHUB API key not found, skipping FINNHUB data fetch")
|
||||
return None
|
||||
|
||||
import finnhub
|
||||
|
||||
# Initialize FINNHUB client
|
||||
finnhub_client = finnhub.Client(api_key=finnhub_api_key)
|
||||
|
||||
# Convert dates to timestamps
|
||||
start_timestamp = int(datetime.strptime(start_date, '%Y-%m-%d').timestamp())
|
||||
end_timestamp = int(datetime.strptime(end_date, '%Y-%m-%d').timestamp())
|
||||
|
||||
# Fetch candle data
|
||||
candle_data = finnhub_client.stock_candles(symbol, 'D', start_timestamp, end_timestamp)
|
||||
|
||||
if candle_data['s'] != 'ok':
|
||||
print(f"⚠️ FINNHUB returned status: {candle_data['s']} for {symbol}")
|
||||
return None
|
||||
|
||||
# Format data
|
||||
formatted_data = self._format_finnhub_data(candle_data, symbol)
|
||||
return formatted_data
|
||||
|
||||
except ImportError:
|
||||
print("⚠️ finnhub-python package not installed, skipping FINNHUB data fetch")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"❌ FINNHUB fetch error for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _format_stock_data(self, data: pd.DataFrame, symbol: str) -> str:
|
||||
"""Format DataFrame stock data into string"""
|
||||
try:
|
||||
# Ensure Date column is properly formatted
|
||||
if 'Date' in data.columns:
|
||||
data['Date'] = pd.to_datetime(data['Date']).dt.strftime('%Y-%m-%d')
|
||||
|
||||
# Round numerical columns to 2 decimal places
|
||||
numeric_columns = ['Open', 'High', 'Low', 'Close']
|
||||
for col in numeric_columns:
|
||||
if col in data.columns:
|
||||
data[col] = data[col].round(2)
|
||||
|
||||
# Format volume as integer
|
||||
if 'Volume' in data.columns:
|
||||
data['Volume'] = data['Volume'].astype(int)
|
||||
|
||||
# Create formatted string
|
||||
formatted_lines = [f"Stock Data for {symbol}:"]
|
||||
formatted_lines.append("Date,Open,High,Low,Close,Volume")
|
||||
|
||||
for _, row in data.iterrows():
|
||||
line = f"{row['Date']},{row['Open']},{row['High']},{row['Low']},{row['Close']},{row['Volume']}"
|
||||
formatted_lines.append(line)
|
||||
|
||||
# Add summary statistics
|
||||
if len(data) > 0:
|
||||
formatted_lines.append(f"\nSummary for {symbol}:")
|
||||
formatted_lines.append(f"Period: {data['Date'].iloc[0]} to {data['Date'].iloc[-1]}")
|
||||
formatted_lines.append(f"Total trading days: {len(data)}")
|
||||
formatted_lines.append(f"Average volume: {data['Volume'].mean():,.0f}")
|
||||
formatted_lines.append(f"Price range: ${data['Low'].min():.2f} - ${data['High'].max():.2f}")
|
||||
|
||||
# Calculate basic statistics
|
||||
start_price = data['Open'].iloc[0]
|
||||
end_price = data['Close'].iloc[-1]
|
||||
price_change = end_price - start_price
|
||||
price_change_pct = (price_change / start_price) * 100
|
||||
|
||||
formatted_lines.append(f"Period return: {price_change_pct:+.2f}% (${price_change:+.2f})")
|
||||
|
||||
return "\n".join(formatted_lines)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error formatting stock data for {symbol}: {e}")
|
||||
return f"Error formatting data for {symbol}: {str(e)}"
|
||||
|
||||
def _format_finnhub_data(self, candle_data: Dict, symbol: str) -> str:
|
||||
"""Format FINNHUB candle data into string"""
|
||||
try:
|
||||
# Extract data arrays
|
||||
timestamps = candle_data['t']
|
||||
opens = candle_data['o']
|
||||
highs = candle_data['h']
|
||||
lows = candle_data['l']
|
||||
closes = candle_data['c']
|
||||
volumes = candle_data['v']
|
||||
|
||||
# Create formatted string
|
||||
formatted_lines = [f"Stock Data for {symbol} (FINNHUB):"]
|
||||
formatted_lines.append("Date,Open,High,Low,Close,Volume")
|
||||
|
||||
for i in range(len(timestamps)):
|
||||
date = datetime.fromtimestamp(timestamps[i]).strftime('%Y-%m-%d')
|
||||
line = f"{date},{opens[i]:.2f},{highs[i]:.2f},{lows[i]:.2f},{closes[i]:.2f},{int(volumes[i])}"
|
||||
formatted_lines.append(line)
|
||||
|
||||
# Add summary
|
||||
if len(timestamps) > 0:
|
||||
start_date = datetime.fromtimestamp(timestamps[0]).strftime('%Y-%m-%d')
|
||||
end_date = datetime.fromtimestamp(timestamps[-1]).strftime('%Y-%m-%d')
|
||||
|
||||
formatted_lines.append(f"\nSummary for {symbol}:")
|
||||
formatted_lines.append(f"Period: {start_date} to {end_date}")
|
||||
formatted_lines.append(f"Total trading days: {len(timestamps)}")
|
||||
formatted_lines.append(f"Average volume: {sum(volumes)/len(volumes):,.0f}")
|
||||
formatted_lines.append(f"Price range: ${min(lows):.2f} - ${max(highs):.2f}")
|
||||
|
||||
# Calculate return
|
||||
price_change = closes[-1] - opens[0]
|
||||
price_change_pct = (price_change / opens[0]) * 100
|
||||
formatted_lines.append(f"Period return: {price_change_pct:+.2f}% (${price_change:+.2f})")
|
||||
|
||||
return "\n".join(formatted_lines)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error formatting FINNHUB data for {symbol}: {e}")
|
||||
return f"Error formatting FINNHUB data for {symbol}: {str(e)}"
|
||||
|
||||
def get_stock_with_indicators(self, symbol: str, start_date: str, end_date: str,
|
||||
indicators: list = None) -> str:
|
||||
"""
|
||||
Get stock data with technical indicators
|
||||
|
||||
Args:
|
||||
symbol: Stock symbol
|
||||
start_date: Start date (YYYY-MM-DD)
|
||||
end_date: End date (YYYY-MM-DD)
|
||||
indicators: List of indicators to calculate ['sma_20', 'rsi', 'macd']
|
||||
|
||||
Returns:
|
||||
Formatted stock data with indicators
|
||||
"""
|
||||
try:
|
||||
# Get basic stock data
|
||||
basic_data = self.get_stock_data(symbol, start_date, end_date)
|
||||
|
||||
if basic_data.startswith("❌"):
|
||||
return basic_data
|
||||
|
||||
# If no indicators requested, return basic data
|
||||
if not indicators:
|
||||
return basic_data
|
||||
|
||||
# Fetch DataFrame for indicator calculation
|
||||
data_df = self._fetch_from_yfinance(symbol, start_date, end_date)
|
||||
if data_df is None or data_df.empty:
|
||||
return basic_data
|
||||
|
||||
# Calculate indicators
|
||||
indicator_data = self._calculate_indicators(data_df, indicators)
|
||||
|
||||
# Combine basic data with indicators
|
||||
combined_data = basic_data + "\n\nTechnical Indicators:\n" + indicator_data
|
||||
|
||||
return combined_data
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"❌ Error getting stock data with indicators for {symbol}: {e}"
|
||||
print(error_msg)
|
||||
return error_msg
|
||||
|
||||
def _calculate_indicators(self, data: pd.DataFrame, indicators: list) -> str:
|
||||
"""Calculate technical indicators"""
|
||||
try:
|
||||
indicator_lines = []
|
||||
|
||||
for indicator in indicators:
|
||||
if indicator == 'sma_20':
|
||||
data['SMA_20'] = data['Close'].rolling(window=20).mean()
|
||||
latest_sma = data['SMA_20'].iloc[-1]
|
||||
indicator_lines.append(f"SMA(20): ${latest_sma:.2f}")
|
||||
|
||||
elif indicator == 'rsi':
|
||||
# Simple RSI calculation
|
||||
delta = data['Close'].diff()
|
||||
gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
|
||||
loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
|
||||
rs = gain / loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
latest_rsi = rsi.iloc[-1]
|
||||
indicator_lines.append(f"RSI(14): {latest_rsi:.2f}")
|
||||
|
||||
elif indicator == 'macd':
|
||||
# Simple MACD calculation
|
||||
ema_12 = data['Close'].ewm(span=12).mean()
|
||||
ema_26 = data['Close'].ewm(span=26).mean()
|
||||
macd_line = ema_12 - ema_26
|
||||
signal_line = macd_line.ewm(span=9).mean()
|
||||
latest_macd = macd_line.iloc[-1]
|
||||
latest_signal = signal_line.iloc[-1]
|
||||
indicator_lines.append(f"MACD: {latest_macd:.4f}, Signal: {latest_signal:.4f}")
|
||||
|
||||
return "\n".join(indicator_lines)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error calculating indicators: {e}")
|
||||
return f"Error calculating indicators: {str(e)}"
|
||||
|
||||
|
||||
# Global provider instance
|
||||
_global_provider = None
|
||||
|
||||
def get_optimized_us_data_provider() -> OptimizedUSDataProvider:
|
||||
"""
|
||||
Get global optimized US data provider instance
|
||||
|
||||
Returns:
|
||||
OptimizedUSDataProvider instance
|
||||
"""
|
||||
global _global_provider
|
||||
if _global_provider is None:
|
||||
_global_provider = OptimizedUSDataProvider()
|
||||
return _global_provider
|
||||
|
||||
|
||||
# Convenience functions
|
||||
def get_optimized_stock_data(symbol: str, start_date: str, end_date: str,
|
||||
force_refresh: bool = False) -> str:
|
||||
"""Get optimized stock data (convenience function)"""
|
||||
provider = get_optimized_us_data_provider()
|
||||
return provider.get_stock_data(symbol, start_date, end_date, force_refresh)
|
||||
|
||||
|
||||
def get_stock_with_indicators(symbol: str, start_date: str, end_date: str,
|
||||
indicators: list = None) -> str:
|
||||
"""Get stock data with technical indicators (convenience function)"""
|
||||
provider = get_optimized_us_data_provider()
|
||||
return provider.get_stock_with_indicators(symbol, start_date, end_date, indicators)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test the optimized data provider
|
||||
print("🧪 Testing Optimized US Data Provider...")
|
||||
|
||||
# Initialize provider
|
||||
provider = OptimizedUSDataProvider()
|
||||
|
||||
# Test data fetch
|
||||
data = provider.get_stock_data("AAPL", "2024-01-01", "2024-01-31")
|
||||
print("Sample data:")
|
||||
print(data[:500] + "..." if len(data) > 500 else data)
|
||||
|
||||
# Test with indicators
|
||||
data_with_indicators = provider.get_stock_with_indicators(
|
||||
"AAPL", "2024-01-01", "2024-01-31",
|
||||
indicators=['sma_20', 'rsi', 'macd']
|
||||
)
|
||||
print("\nData with indicators:")
|
||||
print(data_with_indicators[-500:] if len(data_with_indicators) > 500 else data_with_indicators)
|
||||
|
||||
print("✅ Optimized data provider test completed!")
|
||||
|
|
@ -0,0 +1,679 @@
|
|||
# 文件差异报告
|
||||
# 当前文件: tradingagents\dataflows\optimized_us_data.py
|
||||
# 中文版文件: TradingAgentsCN\tradingagents\dataflows\optimized_us_data.py
|
||||
# 生成时间: 周日 2025/07/06
|
||||
|
||||
--- current/optimized_us_data.py+++ chinese_version/optimized_us_data.py@@ -1,7 +1,7 @@ #!/usr/bin/env python3
|
||||
"""
|
||||
-Optimized US Stock Data Fetcher
|
||||
-Integrates caching strategy to reduce API calls and improve response speed
|
||||
+优化的美股数据获取工具
|
||||
+集成缓存策略,减少API调用,提高响应速度
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -16,24 +16,24 @@
|
||||
|
||||
class OptimizedUSDataProvider:
|
||||
- """Optimized US Stock Data Provider - Integrates caching and API rate limiting"""
|
||||
+ """优化的美股数据提供器 - 集成缓存和API限制处理"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = get_cache()
|
||||
self.config = get_config()
|
||||
self.last_api_call = 0
|
||||
- self.min_api_interval = 1.0 # Minimum API call interval (seconds)
|
||||
-
|
||||
- print("📊 Optimized US stock data provider initialized")
|
||||
+ self.min_api_interval = 1.0 # 最小API调用间隔(秒)
|
||||
+
|
||||
+ print("📊 优化美股数据提供器初始化完成")
|
||||
|
||||
def _wait_for_rate_limit(self):
|
||||
- """Wait for API rate limit"""
|
||||
+ """等待API限制"""
|
||||
current_time = time.time()
|
||||
time_since_last_call = current_time - self.last_api_call
|
||||
|
||||
if time_since_last_call < self.min_api_interval:
|
||||
wait_time = self.min_api_interval - time_since_last_call
|
||||
- print(f"⏳ API rate limit wait {wait_time:.1f}s...")
|
||||
+ print(f"⏳ API限制等待 {wait_time:.1f}s...")
|
||||
time.sleep(wait_time)
|
||||
|
||||
self.last_api_call = time.time()
|
||||
@@ -41,364 +41,292 @@ def get_stock_data(self, symbol: str, start_date: str, end_date: str,
|
||||
force_refresh: bool = False) -> str:
|
||||
"""
|
||||
- Get US stock data - prioritize cache usage
|
||||
+ 获取美股数据 - 优先使用缓存
|
||||
|
||||
Args:
|
||||
- symbol: Stock symbol
|
||||
- start_date: Start date (YYYY-MM-DD)
|
||||
- end_date: End date (YYYY-MM-DD)
|
||||
- force_refresh: Whether to force refresh cache
|
||||
-
|
||||
+ symbol: 股票代码
|
||||
+ start_date: 开始日期 (YYYY-MM-DD)
|
||||
+ end_date: 结束日期 (YYYY-MM-DD)
|
||||
+ force_refresh: 是否强制刷新缓存
|
||||
+
|
||||
Returns:
|
||||
- Formatted stock data string
|
||||
+ 格式化的股票数据字符串
|
||||
"""
|
||||
+ print(f"📈 获取美股数据: {symbol} ({start_date} 到 {end_date})")
|
||||
+
|
||||
+ # 检查缓存(除非强制刷新)
|
||||
+ if not force_refresh:
|
||||
+ # 优先查找FINNHUB缓存
|
||||
+ cache_key = self.cache.find_cached_stock_data(
|
||||
+ symbol=symbol,
|
||||
+ start_date=start_date,
|
||||
+ end_date=end_date,
|
||||
+ data_source="finnhub"
|
||||
+ )
|
||||
+
|
||||
+ # 如果没有FINNHUB缓存,查找Yahoo Finance缓存
|
||||
+ if not cache_key:
|
||||
+ cache_key = self.cache.find_cached_stock_data(
|
||||
+ symbol=symbol,
|
||||
+ start_date=start_date,
|
||||
+ end_date=end_date,
|
||||
+ data_source="yfinance"
|
||||
+ )
|
||||
+
|
||||
+ if cache_key:
|
||||
+ cached_data = self.cache.load_stock_data(cache_key)
|
||||
+ if cached_data:
|
||||
+ print(f"⚡ 从缓存加载美股数据: {symbol}")
|
||||
+ return cached_data
|
||||
+
|
||||
+ # 缓存未命中,从API获取 - 优先使用FINNHUB
|
||||
+ formatted_data = None
|
||||
+ data_source = None
|
||||
+
|
||||
+ # 尝试FINNHUB API(优先)
|
||||
try:
|
||||
- # Check cache first (unless force refresh)
|
||||
- if not force_refresh:
|
||||
- cache_key = self.cache.find_cached_stock_data(
|
||||
- symbol, start_date, end_date, "optimized_yfinance"
|
||||
- )
|
||||
-
|
||||
- if cache_key and self.cache.is_cache_valid(cache_key, symbol):
|
||||
- cached_data = self.cache.load_stock_data(cache_key)
|
||||
- if cached_data:
|
||||
- print(f"📖 Using cached data for {symbol}")
|
||||
- if isinstance(cached_data, pd.DataFrame):
|
||||
- return self._format_stock_data(cached_data, symbol)
|
||||
- else:
|
||||
- return cached_data
|
||||
-
|
||||
- # Fetch new data from API
|
||||
- print(f"🌐 Fetching new data for {symbol} from {start_date} to {end_date}")
|
||||
-
|
||||
- # Wait for rate limit
|
||||
+ print(f"🌐 从FINNHUB API获取数据: {symbol}")
|
||||
self._wait_for_rate_limit()
|
||||
-
|
||||
- # Try Yahoo Finance first
|
||||
+
|
||||
+ formatted_data = self._get_data_from_finnhub(symbol, start_date, end_date)
|
||||
+ if formatted_data and "❌" not in formatted_data:
|
||||
+ data_source = "finnhub"
|
||||
+ print(f"✅ FINNHUB数据获取成功: {symbol}")
|
||||
+ else:
|
||||
+ print(f"⚠️ FINNHUB数据获取失败,尝试备用方案")
|
||||
+ formatted_data = None
|
||||
+
|
||||
+ except Exception as e:
|
||||
+ print(f"❌ FINNHUB API调用失败: {e}")
|
||||
+ formatted_data = None
|
||||
+
|
||||
+ # 备用方案:Yahoo Finance API
|
||||
+ if not formatted_data:
|
||||
try:
|
||||
- data = self._fetch_from_yfinance(symbol, start_date, end_date)
|
||||
- if data is not None and not data.empty:
|
||||
- # Cache the DataFrame
|
||||
- cache_key = self.cache.save_stock_data(
|
||||
- symbol, data, start_date, end_date, "optimized_yfinance"
|
||||
- )
|
||||
+ print(f"🌐 从Yahoo Finance API获取数据: {symbol}")
|
||||
+ self._wait_for_rate_limit()
|
||||
+
|
||||
+ # 获取数据
|
||||
+ ticker = yf.Ticker(symbol.upper())
|
||||
+ data = ticker.history(start=start_date, end=end_date)
|
||||
+
|
||||
+ if data.empty:
|
||||
+ error_msg = f"未找到股票 '{symbol}' 在 {start_date} 到 {end_date} 期间的数据"
|
||||
+ print(f"❌ {error_msg}")
|
||||
+ else:
|
||||
+ # 格式化数据
|
||||
+ formatted_data = self._format_stock_data(symbol, data, start_date, end_date)
|
||||
+ data_source = "yfinance"
|
||||
+ print(f"✅ Yahoo Finance数据获取成功: {symbol}")
|
||||
+
|
||||
+ except Exception as e:
|
||||
+ print(f"❌ Yahoo Finance API调用失败: {e}")
|
||||
+ formatted_data = None
|
||||
+
|
||||
+ # 如果所有API都失败,生成备用数据
|
||||
+ if not formatted_data:
|
||||
+ error_msg = "所有美股数据源都不可用"
|
||||
+ print(f"❌ {error_msg}")
|
||||
+ return self._generate_fallback_data(symbol, start_date, end_date, error_msg)
|
||||
+
|
||||
+ # 保存到缓存
|
||||
+ self.cache.save_stock_data(
|
||||
+ symbol=symbol,
|
||||
+ data=formatted_data,
|
||||
+ start_date=start_date,
|
||||
+ end_date=end_date,
|
||||
+ data_source=data_source
|
||||
+ )
|
||||
+
|
||||
+ return formatted_data
|
||||
+
|
||||
+ def _format_stock_data(self, symbol: str, data: pd.DataFrame,
|
||||
+ start_date: str, end_date: str) -> str:
|
||||
+ """格式化股票数据为字符串"""
|
||||
+
|
||||
+ # 移除时区信息
|
||||
+ if data.index.tz is not None:
|
||||
+ data.index = data.index.tz_localize(None)
|
||||
+
|
||||
+ # 四舍五入数值
|
||||
+ numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"]
|
||||
+ for col in numeric_columns:
|
||||
+ if col in data.columns:
|
||||
+ data[col] = data[col].round(2)
|
||||
+
|
||||
+ # 获取最新价格和统计信息
|
||||
+ latest_price = data['Close'].iloc[-1]
|
||||
+ price_change = data['Close'].iloc[-1] - data['Close'].iloc[0]
|
||||
+ price_change_pct = (price_change / data['Close'].iloc[0]) * 100
|
||||
+
|
||||
+ # 计算技术指标
|
||||
+ data['MA5'] = data['Close'].rolling(window=5).mean()
|
||||
+ data['MA10'] = data['Close'].rolling(window=10).mean()
|
||||
+ data['MA20'] = data['Close'].rolling(window=20).mean()
|
||||
+
|
||||
+ # 计算RSI
|
||||
+ delta = data['Close'].diff()
|
||||
+ gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
|
||||
+ loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
|
||||
+ rs = gain / loss
|
||||
+ rsi = 100 - (100 / (1 + rs))
|
||||
+
|
||||
+ # 格式化输出
|
||||
+ result = f"""# {symbol} 美股数据分析
|
||||
+
|
||||
+## 📊 基本信息
|
||||
+- 股票代码: {symbol}
|
||||
+- 数据期间: {start_date} 至 {end_date}
|
||||
+- 数据条数: {len(data)}条
|
||||
+- 最新价格: ${latest_price:.2f}
|
||||
+- 期间涨跌: ${price_change:+.2f} ({price_change_pct:+.2f}%)
|
||||
+
|
||||
+## 📈 价格统计
|
||||
+- 期间最高: ${data['High'].max():.2f}
|
||||
+- 期间最低: ${data['Low'].min():.2f}
|
||||
+- 平均成交量: {data['Volume'].mean():,.0f}
|
||||
+
|
||||
+## 🔍 技术指标
|
||||
+- MA5: ${data['MA5'].iloc[-1]:.2f}
|
||||
+- MA10: ${data['MA10'].iloc[-1]:.2f}
|
||||
+- MA20: ${data['MA20'].iloc[-1]:.2f}
|
||||
+- RSI: {rsi.iloc[-1]:.2f}
|
||||
+
|
||||
+## 📋 最近5日数据
|
||||
+{data.tail().to_string()}
|
||||
+
|
||||
+数据来源: Yahoo Finance API
|
||||
+更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
+"""
|
||||
+
|
||||
+ return result
|
||||
+
|
||||
+ def _try_get_old_cache(self, symbol: str, start_date: str, end_date: str) -> Optional[str]:
|
||||
+ """尝试获取过期的缓存数据作为备用"""
|
||||
+ try:
|
||||
+ # 查找任何相关的缓存,不考虑TTL
|
||||
+ for metadata_file in self.cache.metadata_dir.glob(f"*_meta.json"):
|
||||
+ try:
|
||||
+ import json
|
||||
+ with open(metadata_file, 'r', encoding='utf-8') as f:
|
||||
+ metadata = json.load(f)
|
||||
|
||||
- # Format and return
|
||||
- formatted_data = self._format_stock_data(data, symbol)
|
||||
- print(f"✅ Successfully fetched and cached data for {symbol}")
|
||||
- return formatted_data
|
||||
- else:
|
||||
- print(f"⚠️ No data returned from Yahoo Finance for {symbol}")
|
||||
-
|
||||
- except Exception as e:
|
||||
- print(f"❌ Yahoo Finance error for {symbol}: {e}")
|
||||
-
|
||||
- # Fallback: Try FINNHUB (if API key available)
|
||||
- try:
|
||||
- finnhub_data = self._fetch_from_finnhub(symbol, start_date, end_date)
|
||||
- if finnhub_data:
|
||||
- # Cache the string data
|
||||
- cache_key = self.cache.save_stock_data(
|
||||
- symbol, finnhub_data, start_date, end_date, "optimized_finnhub"
|
||||
- )
|
||||
- print(f"✅ Successfully fetched data from FINNHUB for {symbol}")
|
||||
- return finnhub_data
|
||||
-
|
||||
- except Exception as e:
|
||||
- print(f"❌ FINNHUB error for {symbol}: {e}")
|
||||
-
|
||||
- # If all fails, return error message
|
||||
- error_msg = f"❌ Failed to fetch data for {symbol} from {start_date} to {end_date}"
|
||||
- print(error_msg)
|
||||
- return error_msg
|
||||
-
|
||||
+ if (metadata.get('symbol') == symbol and
|
||||
+ metadata.get('data_type') == 'stock_data' and
|
||||
+ metadata.get('market_type') == 'us'):
|
||||
+
|
||||
+ cache_key = metadata_file.stem.replace('_meta', '')
|
||||
+ cached_data = self.cache.load_stock_data(cache_key)
|
||||
+ if cached_data:
|
||||
+ return cached_data + "\n\n⚠️ 注意: 使用的是过期缓存数据"
|
||||
+ except Exception:
|
||||
+ continue
|
||||
+ except Exception:
|
||||
+ pass
|
||||
+
|
||||
+ return None
|
||||
+
|
||||
+ def _get_data_from_finnhub(self, symbol: str, start_date: str, end_date: str) -> str:
|
||||
+ """从FINNHUB API获取股票数据"""
|
||||
+ try:
|
||||
+ import finnhub
|
||||
+ import os
|
||||
+ from datetime import datetime, timedelta
|
||||
+
|
||||
+ # 获取API密钥
|
||||
+ api_key = os.getenv('FINNHUB_API_KEY')
|
||||
+ if not api_key:
|
||||
+ return None
|
||||
+
|
||||
+ client = finnhub.Client(api_key=api_key)
|
||||
+
|
||||
+ # 获取实时报价
|
||||
+ quote = client.quote(symbol.upper())
|
||||
+ if not quote or 'c' not in quote:
|
||||
+ return None
|
||||
+
|
||||
+ # 获取公司信息
|
||||
+ profile = client.company_profile2(symbol=symbol.upper())
|
||||
+ company_name = profile.get('name', symbol.upper()) if profile else symbol.upper()
|
||||
+
|
||||
+ # 格式化数据
|
||||
+ current_price = quote.get('c', 0)
|
||||
+ change = quote.get('d', 0)
|
||||
+ change_percent = quote.get('dp', 0)
|
||||
+
|
||||
+ formatted_data = f"""# {symbol.upper()} 美股数据分析
|
||||
+
|
||||
+## 📊 实时行情
|
||||
+- 股票名称: {company_name}
|
||||
+- 当前价格: ${current_price:.2f}
|
||||
+- 涨跌额: ${change:+.2f}
|
||||
+- 涨跌幅: {change_percent:+.2f}%
|
||||
+- 开盘价: ${quote.get('o', 0):.2f}
|
||||
+- 最高价: ${quote.get('h', 0):.2f}
|
||||
+- 最低价: ${quote.get('l', 0):.2f}
|
||||
+- 前收盘: ${quote.get('pc', 0):.2f}
|
||||
+- 更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
+
|
||||
+## 📈 数据概览
|
||||
+- 数据期间: {start_date} 至 {end_date}
|
||||
+- 数据来源: FINNHUB API (实时数据)
|
||||
+- 当前价位相对位置: {((current_price - quote.get('l', current_price)) / max(quote.get('h', current_price) - quote.get('l', current_price), 0.01) * 100):.1f}%
|
||||
+- 日内振幅: {((quote.get('h', 0) - quote.get('l', 0)) / max(quote.get('pc', 1), 0.01) * 100):.2f}%
|
||||
+
|
||||
+生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
+"""
|
||||
+
|
||||
+ return formatted_data
|
||||
+
|
||||
except Exception as e:
|
||||
- error_msg = f"❌ Unexpected error fetching data for {symbol}: {e}"
|
||||
- print(error_msg)
|
||||
- return error_msg
|
||||
-
|
||||
- def _fetch_from_yfinance(self, symbol: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
- """Fetch data from Yahoo Finance"""
|
||||
- try:
|
||||
- ticker = yf.Ticker(symbol)
|
||||
- data = ticker.history(start=start_date, end=end_date)
|
||||
-
|
||||
- if data.empty:
|
||||
- print(f"⚠️ No data available for {symbol} in the specified date range")
|
||||
- return None
|
||||
-
|
||||
- # Reset index to make Date a column
|
||||
- data = data.reset_index()
|
||||
-
|
||||
- # Ensure we have the required columns
|
||||
- required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
|
||||
- missing_columns = [col for col in required_columns if col not in data.columns]
|
||||
-
|
||||
- if missing_columns:
|
||||
- print(f"⚠️ Missing columns for {symbol}: {missing_columns}")
|
||||
- return None
|
||||
-
|
||||
- return data[required_columns]
|
||||
-
|
||||
- except Exception as e:
|
||||
- print(f"❌ Yahoo Finance fetch error for {symbol}: {e}")
|
||||
+ print(f"❌ FINNHUB数据获取失败: {e}")
|
||||
return None
|
||||
-
|
||||
- def _fetch_from_finnhub(self, symbol: str, start_date: str, end_date: str) -> Optional[str]:
|
||||
- """Fetch data from FINNHUB API"""
|
||||
- try:
|
||||
- # Check if FINNHUB API key is available
|
||||
- finnhub_api_key = os.getenv('FINNHUB_API_KEY')
|
||||
- if not finnhub_api_key:
|
||||
- print("⚠️ FINNHUB API key not found, skipping FINNHUB data fetch")
|
||||
- return None
|
||||
-
|
||||
- import finnhub
|
||||
-
|
||||
- # Initialize FINNHUB client
|
||||
- finnhub_client = finnhub.Client(api_key=finnhub_api_key)
|
||||
-
|
||||
- # Convert dates to timestamps
|
||||
- start_timestamp = int(datetime.strptime(start_date, '%Y-%m-%d').timestamp())
|
||||
- end_timestamp = int(datetime.strptime(end_date, '%Y-%m-%d').timestamp())
|
||||
-
|
||||
- # Fetch candle data
|
||||
- candle_data = finnhub_client.stock_candles(symbol, 'D', start_timestamp, end_timestamp)
|
||||
-
|
||||
- if candle_data['s'] != 'ok':
|
||||
- print(f"⚠️ FINNHUB returned status: {candle_data['s']} for {symbol}")
|
||||
- return None
|
||||
-
|
||||
- # Format data
|
||||
- formatted_data = self._format_finnhub_data(candle_data, symbol)
|
||||
- return formatted_data
|
||||
-
|
||||
- except ImportError:
|
||||
- print("⚠️ finnhub-python package not installed, skipping FINNHUB data fetch")
|
||||
- return None
|
||||
- except Exception as e:
|
||||
- print(f"❌ FINNHUB fetch error for {symbol}: {e}")
|
||||
- return None
|
||||
-
|
||||
- def _format_stock_data(self, data: pd.DataFrame, symbol: str) -> str:
|
||||
- """Format DataFrame stock data into string"""
|
||||
- try:
|
||||
- # Ensure Date column is properly formatted
|
||||
- if 'Date' in data.columns:
|
||||
- data['Date'] = pd.to_datetime(data['Date']).dt.strftime('%Y-%m-%d')
|
||||
-
|
||||
- # Round numerical columns to 2 decimal places
|
||||
- numeric_columns = ['Open', 'High', 'Low', 'Close']
|
||||
- for col in numeric_columns:
|
||||
- if col in data.columns:
|
||||
- data[col] = data[col].round(2)
|
||||
-
|
||||
- # Format volume as integer
|
||||
- if 'Volume' in data.columns:
|
||||
- data['Volume'] = data['Volume'].astype(int)
|
||||
-
|
||||
- # Create formatted string
|
||||
- formatted_lines = [f"Stock Data for {symbol}:"]
|
||||
- formatted_lines.append("Date,Open,High,Low,Close,Volume")
|
||||
-
|
||||
- for _, row in data.iterrows():
|
||||
- line = f"{row['Date']},{row['Open']},{row['High']},{row['Low']},{row['Close']},{row['Volume']}"
|
||||
- formatted_lines.append(line)
|
||||
-
|
||||
- # Add summary statistics
|
||||
- if len(data) > 0:
|
||||
- formatted_lines.append(f"\nSummary for {symbol}:")
|
||||
- formatted_lines.append(f"Period: {data['Date'].iloc[0]} to {data['Date'].iloc[-1]}")
|
||||
- formatted_lines.append(f"Total trading days: {len(data)}")
|
||||
- formatted_lines.append(f"Average volume: {data['Volume'].mean():,.0f}")
|
||||
- formatted_lines.append(f"Price range: ${data['Low'].min():.2f} - ${data['High'].max():.2f}")
|
||||
-
|
||||
- # Calculate basic statistics
|
||||
- start_price = data['Open'].iloc[0]
|
||||
- end_price = data['Close'].iloc[-1]
|
||||
- price_change = end_price - start_price
|
||||
- price_change_pct = (price_change / start_price) * 100
|
||||
-
|
||||
- formatted_lines.append(f"Period return: {price_change_pct:+.2f}% (${price_change:+.2f})")
|
||||
-
|
||||
- return "\n".join(formatted_lines)
|
||||
-
|
||||
- except Exception as e:
|
||||
- print(f"❌ Error formatting stock data for {symbol}: {e}")
|
||||
- return f"Error formatting data for {symbol}: {str(e)}"
|
||||
-
|
||||
- def _format_finnhub_data(self, candle_data: Dict, symbol: str) -> str:
|
||||
- """Format FINNHUB candle data into string"""
|
||||
- try:
|
||||
- # Extract data arrays
|
||||
- timestamps = candle_data['t']
|
||||
- opens = candle_data['o']
|
||||
- highs = candle_data['h']
|
||||
- lows = candle_data['l']
|
||||
- closes = candle_data['c']
|
||||
- volumes = candle_data['v']
|
||||
-
|
||||
- # Create formatted string
|
||||
- formatted_lines = [f"Stock Data for {symbol} (FINNHUB):"]
|
||||
- formatted_lines.append("Date,Open,High,Low,Close,Volume")
|
||||
-
|
||||
- for i in range(len(timestamps)):
|
||||
- date = datetime.fromtimestamp(timestamps[i]).strftime('%Y-%m-%d')
|
||||
- line = f"{date},{opens[i]:.2f},{highs[i]:.2f},{lows[i]:.2f},{closes[i]:.2f},{int(volumes[i])}"
|
||||
- formatted_lines.append(line)
|
||||
-
|
||||
- # Add summary
|
||||
- if len(timestamps) > 0:
|
||||
- start_date = datetime.fromtimestamp(timestamps[0]).strftime('%Y-%m-%d')
|
||||
- end_date = datetime.fromtimestamp(timestamps[-1]).strftime('%Y-%m-%d')
|
||||
-
|
||||
- formatted_lines.append(f"\nSummary for {symbol}:")
|
||||
- formatted_lines.append(f"Period: {start_date} to {end_date}")
|
||||
- formatted_lines.append(f"Total trading days: {len(timestamps)}")
|
||||
- formatted_lines.append(f"Average volume: {sum(volumes)/len(volumes):,.0f}")
|
||||
- formatted_lines.append(f"Price range: ${min(lows):.2f} - ${max(highs):.2f}")
|
||||
-
|
||||
- # Calculate return
|
||||
- price_change = closes[-1] - opens[0]
|
||||
- price_change_pct = (price_change / opens[0]) * 100
|
||||
- formatted_lines.append(f"Period return: {price_change_pct:+.2f}% (${price_change:+.2f})")
|
||||
-
|
||||
- return "\n".join(formatted_lines)
|
||||
-
|
||||
- except Exception as e:
|
||||
- print(f"❌ Error formatting FINNHUB data for {symbol}: {e}")
|
||||
- return f"Error formatting FINNHUB data for {symbol}: {str(e)}"
|
||||
-
|
||||
- def get_stock_with_indicators(self, symbol: str, start_date: str, end_date: str,
|
||||
- indicators: list = None) -> str:
|
||||
- """
|
||||
- Get stock data with technical indicators
|
||||
-
|
||||
- Args:
|
||||
- symbol: Stock symbol
|
||||
- start_date: Start date (YYYY-MM-DD)
|
||||
- end_date: End date (YYYY-MM-DD)
|
||||
- indicators: List of indicators to calculate ['sma_20', 'rsi', 'macd']
|
||||
-
|
||||
- Returns:
|
||||
- Formatted stock data with indicators
|
||||
- """
|
||||
- try:
|
||||
- # Get basic stock data
|
||||
- basic_data = self.get_stock_data(symbol, start_date, end_date)
|
||||
-
|
||||
- if basic_data.startswith("❌"):
|
||||
- return basic_data
|
||||
-
|
||||
- # If no indicators requested, return basic data
|
||||
- if not indicators:
|
||||
- return basic_data
|
||||
-
|
||||
- # Fetch DataFrame for indicator calculation
|
||||
- data_df = self._fetch_from_yfinance(symbol, start_date, end_date)
|
||||
- if data_df is None or data_df.empty:
|
||||
- return basic_data
|
||||
-
|
||||
- # Calculate indicators
|
||||
- indicator_data = self._calculate_indicators(data_df, indicators)
|
||||
-
|
||||
- # Combine basic data with indicators
|
||||
- combined_data = basic_data + "\n\nTechnical Indicators:\n" + indicator_data
|
||||
-
|
||||
- return combined_data
|
||||
-
|
||||
- except Exception as e:
|
||||
- error_msg = f"❌ Error getting stock data with indicators for {symbol}: {e}"
|
||||
- print(error_msg)
|
||||
- return error_msg
|
||||
-
|
||||
- def _calculate_indicators(self, data: pd.DataFrame, indicators: list) -> str:
|
||||
- """Calculate technical indicators"""
|
||||
- try:
|
||||
- indicator_lines = []
|
||||
-
|
||||
- for indicator in indicators:
|
||||
- if indicator == 'sma_20':
|
||||
- data['SMA_20'] = data['Close'].rolling(window=20).mean()
|
||||
- latest_sma = data['SMA_20'].iloc[-1]
|
||||
- indicator_lines.append(f"SMA(20): ${latest_sma:.2f}")
|
||||
-
|
||||
- elif indicator == 'rsi':
|
||||
- # Simple RSI calculation
|
||||
- delta = data['Close'].diff()
|
||||
- gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
|
||||
- loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
|
||||
- rs = gain / loss
|
||||
- rsi = 100 - (100 / (1 + rs))
|
||||
- latest_rsi = rsi.iloc[-1]
|
||||
- indicator_lines.append(f"RSI(14): {latest_rsi:.2f}")
|
||||
-
|
||||
- elif indicator == 'macd':
|
||||
- # Simple MACD calculation
|
||||
- ema_12 = data['Close'].ewm(span=12).mean()
|
||||
- ema_26 = data['Close'].ewm(span=26).mean()
|
||||
- macd_line = ema_12 - ema_26
|
||||
- signal_line = macd_line.ewm(span=9).mean()
|
||||
- latest_macd = macd_line.iloc[-1]
|
||||
- latest_signal = signal_line.iloc[-1]
|
||||
- indicator_lines.append(f"MACD: {latest_macd:.4f}, Signal: {latest_signal:.4f}")
|
||||
-
|
||||
- return "\n".join(indicator_lines)
|
||||
-
|
||||
- except Exception as e:
|
||||
- print(f"❌ Error calculating indicators: {e}")
|
||||
- return f"Error calculating indicators: {str(e)}"
|
||||
-
|
||||
-
|
||||
-# Global provider instance
|
||||
-_global_provider = None
|
||||
+
|
||||
+ def _generate_fallback_data(self, symbol: str, start_date: str, end_date: str, error_msg: str) -> str:
|
||||
+ """生成备用数据"""
|
||||
+ return f"""# {symbol} 美股数据获取失败
|
||||
+
|
||||
+## ❌ 错误信息
|
||||
+{error_msg}
|
||||
+
|
||||
+## 📊 模拟数据(仅供演示)
|
||||
+- 股票代码: {symbol}
|
||||
+- 数据期间: {start_date} 至 {end_date}
|
||||
+- 最新价格: ${random.uniform(100, 300):.2f}
|
||||
+- 模拟涨跌: {random.uniform(-5, 5):+.2f}%
|
||||
+
|
||||
+## ⚠️ 重要提示
|
||||
+由于API限制或网络问题,无法获取实时数据。
|
||||
+建议稍后重试或检查网络连接。
|
||||
+
|
||||
+生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
+"""
|
||||
+
|
||||
+
|
||||
+# 全局实例
|
||||
+_us_data_provider = None
|
||||
|
||||
def get_optimized_us_data_provider() -> OptimizedUSDataProvider:
|
||||
+ """获取全局美股数据提供器实例"""
|
||||
+ global _us_data_provider
|
||||
+ if _us_data_provider is None:
|
||||
+ _us_data_provider = OptimizedUSDataProvider()
|
||||
+ return _us_data_provider
|
||||
+
|
||||
+
|
||||
+def get_us_stock_data_cached(symbol: str, start_date: str, end_date: str,
|
||||
+ force_refresh: bool = False) -> str:
|
||||
"""
|
||||
- Get global optimized US data provider instance
|
||||
+ 获取美股数据的便捷函数
|
||||
+
|
||||
+ Args:
|
||||
+ symbol: 股票代码
|
||||
+ start_date: 开始日期 (YYYY-MM-DD)
|
||||
+ end_date: 结束日期 (YYYY-MM-DD)
|
||||
+ force_refresh: 是否强制刷新缓存
|
||||
|
||||
Returns:
|
||||
- OptimizedUSDataProvider instance
|
||||
+ 格式化的股票数据字符串
|
||||
"""
|
||||
- global _global_provider
|
||||
- if _global_provider is None:
|
||||
- _global_provider = OptimizedUSDataProvider()
|
||||
- return _global_provider
|
||||
-
|
||||
-
|
||||
-# Convenience functions
|
||||
-def get_optimized_stock_data(symbol: str, start_date: str, end_date: str,
|
||||
- force_refresh: bool = False) -> str:
|
||||
- """Get optimized stock data (convenience function)"""
|
||||
provider = get_optimized_us_data_provider()
|
||||
return provider.get_stock_data(symbol, start_date, end_date, force_refresh)
|
||||
-
|
||||
-
|
||||
-def get_stock_with_indicators(symbol: str, start_date: str, end_date: str,
|
||||
- indicators: list = None) -> str:
|
||||
- """Get stock data with technical indicators (convenience function)"""
|
||||
- provider = get_optimized_us_data_provider()
|
||||
- return provider.get_stock_with_indicators(symbol, start_date, end_date, indicators)
|
||||
-
|
||||
-
|
||||
-if __name__ == "__main__":
|
||||
- # Test the optimized data provider
|
||||
- print("🧪 Testing Optimized US Data Provider...")
|
||||
-
|
||||
- # Initialize provider
|
||||
- provider = OptimizedUSDataProvider()
|
||||
-
|
||||
- # Test data fetch
|
||||
- data = provider.get_stock_data("AAPL", "2024-01-01", "2024-01-31")
|
||||
- print("Sample data:")
|
||||
- print(data[:500] + "..." if len(data) > 500 else data)
|
||||
-
|
||||
- # Test with indicators
|
||||
- data_with_indicators = provider.get_stock_with_indicators(
|
||||
- "AAPL", "2024-01-01", "2024-01-31",
|
||||
- indicators=['sma_20', 'rsi', 'macd']
|
||||
- )
|
||||
- print("\nData with indicators:")
|
||||
- print(data_with_indicators[-500:] if len(data_with_indicators) > 500 else data_with_indicators)
|
||||
-
|
||||
- print("✅ Optimized data provider test completed!")
|
||||
|
|
@ -0,0 +1,395 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
实时新闻数据获取工具
|
||||
解决新闻滞后性问题
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Optional
|
||||
import time
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewsItem:
|
||||
"""新闻项目数据结构"""
|
||||
title: str
|
||||
content: str
|
||||
source: str
|
||||
publish_time: datetime
|
||||
url: str
|
||||
urgency: str # high, medium, low
|
||||
relevance_score: float
|
||||
|
||||
|
||||
class RealtimeNewsAggregator:
|
||||
"""实时新闻聚合器"""
|
||||
|
||||
def __init__(self):
|
||||
self.headers = {
|
||||
'User-Agent': 'TradingAgents-CN/1.0'
|
||||
}
|
||||
|
||||
# API密钥配置
|
||||
self.finnhub_key = os.getenv('FINNHUB_API_KEY')
|
||||
self.alpha_vantage_key = os.getenv('ALPHA_VANTAGE_API_KEY')
|
||||
self.newsapi_key = os.getenv('NEWSAPI_KEY')
|
||||
|
||||
def get_realtime_stock_news(self, ticker: str, hours_back: int = 6) -> List[NewsItem]:
|
||||
"""
|
||||
获取实时股票新闻
|
||||
优先级:专业API > 新闻API > 搜索引擎
|
||||
"""
|
||||
all_news = []
|
||||
|
||||
# 1. FinnHub实时新闻 (最高优先级)
|
||||
finnhub_news = self._get_finnhub_realtime_news(ticker, hours_back)
|
||||
all_news.extend(finnhub_news)
|
||||
|
||||
# 2. Alpha Vantage新闻
|
||||
av_news = self._get_alpha_vantage_news(ticker, hours_back)
|
||||
all_news.extend(av_news)
|
||||
|
||||
# 3. NewsAPI (如果配置了)
|
||||
if self.newsapi_key:
|
||||
newsapi_news = self._get_newsapi_news(ticker, hours_back)
|
||||
all_news.extend(newsapi_news)
|
||||
|
||||
# 4. 中文财经新闻源
|
||||
chinese_news = self._get_chinese_finance_news(ticker, hours_back)
|
||||
all_news.extend(chinese_news)
|
||||
|
||||
# 去重和排序
|
||||
unique_news = self._deduplicate_news(all_news)
|
||||
return sorted(unique_news, key=lambda x: x.publish_time, reverse=True)
|
||||
|
||||
def _get_finnhub_realtime_news(self, ticker: str, hours_back: int) -> List[NewsItem]:
|
||||
"""获取FinnHub实时新闻"""
|
||||
if not self.finnhub_key:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 计算时间范围
|
||||
end_time = datetime.now()
|
||||
start_time = end_time - timedelta(hours=hours_back)
|
||||
|
||||
# FinnHub API调用
|
||||
url = "https://finnhub.io/api/v1/company-news"
|
||||
params = {
|
||||
'symbol': ticker,
|
||||
'from': start_time.strftime('%Y-%m-%d'),
|
||||
'to': end_time.strftime('%Y-%m-%d'),
|
||||
'token': self.finnhub_key
|
||||
}
|
||||
|
||||
response = requests.get(url, params=params, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
|
||||
news_data = response.json()
|
||||
news_items = []
|
||||
|
||||
for item in news_data:
|
||||
# 检查新闻时效性
|
||||
publish_time = datetime.fromtimestamp(item.get('datetime', 0))
|
||||
if publish_time < start_time:
|
||||
continue
|
||||
|
||||
# 评估紧急程度
|
||||
urgency = self._assess_news_urgency(item.get('headline', ''), item.get('summary', ''))
|
||||
|
||||
news_items.append(NewsItem(
|
||||
title=item.get('headline', ''),
|
||||
content=item.get('summary', ''),
|
||||
source=item.get('source', 'FinnHub'),
|
||||
publish_time=publish_time,
|
||||
url=item.get('url', ''),
|
||||
urgency=urgency,
|
||||
relevance_score=self._calculate_relevance(item.get('headline', ''), ticker)
|
||||
))
|
||||
|
||||
return news_items
|
||||
|
||||
except Exception as e:
|
||||
print(f"FinnHub新闻获取失败: {e}")
|
||||
return []
|
||||
|
||||
def _get_alpha_vantage_news(self, ticker: str, hours_back: int) -> List[NewsItem]:
|
||||
"""获取Alpha Vantage新闻"""
|
||||
if not self.alpha_vantage_key:
|
||||
return []
|
||||
|
||||
try:
|
||||
url = "https://www.alphavantage.co/query"
|
||||
params = {
|
||||
'function': 'NEWS_SENTIMENT',
|
||||
'tickers': ticker,
|
||||
'apikey': self.alpha_vantage_key,
|
||||
'limit': 50
|
||||
}
|
||||
|
||||
response = requests.get(url, params=params, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
news_items = []
|
||||
|
||||
if 'feed' in data:
|
||||
for item in data['feed']:
|
||||
# 解析时间
|
||||
time_str = item.get('time_published', '')
|
||||
try:
|
||||
publish_time = datetime.strptime(time_str, '%Y%m%dT%H%M%S')
|
||||
except:
|
||||
continue
|
||||
|
||||
# 检查时效性
|
||||
if publish_time < datetime.now() - timedelta(hours=hours_back):
|
||||
continue
|
||||
|
||||
urgency = self._assess_news_urgency(item.get('title', ''), item.get('summary', ''))
|
||||
|
||||
news_items.append(NewsItem(
|
||||
title=item.get('title', ''),
|
||||
content=item.get('summary', ''),
|
||||
source=item.get('source', 'Alpha Vantage'),
|
||||
publish_time=publish_time,
|
||||
url=item.get('url', ''),
|
||||
urgency=urgency,
|
||||
relevance_score=self._calculate_relevance(item.get('title', ''), ticker)
|
||||
))
|
||||
|
||||
return news_items
|
||||
|
||||
except Exception as e:
|
||||
print(f"Alpha Vantage新闻获取失败: {e}")
|
||||
return []
|
||||
|
||||
def _get_newsapi_news(self, ticker: str, hours_back: int) -> List[NewsItem]:
|
||||
"""获取NewsAPI新闻"""
|
||||
try:
|
||||
# 构建搜索查询
|
||||
company_names = {
|
||||
'AAPL': 'Apple',
|
||||
'TSLA': 'Tesla',
|
||||
'NVDA': 'NVIDIA',
|
||||
'MSFT': 'Microsoft',
|
||||
'GOOGL': 'Google'
|
||||
}
|
||||
|
||||
query = f"{ticker} OR {company_names.get(ticker, ticker)}"
|
||||
|
||||
url = "https://newsapi.org/v2/everything"
|
||||
params = {
|
||||
'q': query,
|
||||
'language': 'en',
|
||||
'sortBy': 'publishedAt',
|
||||
'from': (datetime.now() - timedelta(hours=hours_back)).isoformat(),
|
||||
'apiKey': self.newsapi_key
|
||||
}
|
||||
|
||||
response = requests.get(url, params=params, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
news_items = []
|
||||
|
||||
for item in data.get('articles', []):
|
||||
# 解析时间
|
||||
time_str = item.get('publishedAt', '')
|
||||
try:
|
||||
publish_time = datetime.fromisoformat(time_str.replace('Z', '+00:00'))
|
||||
except:
|
||||
continue
|
||||
|
||||
urgency = self._assess_news_urgency(item.get('title', ''), item.get('description', ''))
|
||||
|
||||
news_items.append(NewsItem(
|
||||
title=item.get('title', ''),
|
||||
content=item.get('description', ''),
|
||||
source=item.get('source', {}).get('name', 'NewsAPI'),
|
||||
publish_time=publish_time,
|
||||
url=item.get('url', ''),
|
||||
urgency=urgency,
|
||||
relevance_score=self._calculate_relevance(item.get('title', ''), ticker)
|
||||
))
|
||||
|
||||
return news_items
|
||||
|
||||
except Exception as e:
|
||||
print(f"NewsAPI新闻获取失败: {e}")
|
||||
return []
|
||||
|
||||
def _get_chinese_finance_news(self, ticker: str, hours_back: int) -> List[NewsItem]:
|
||||
"""获取中文财经新闻"""
|
||||
# 这里可以集成中文财经新闻API
|
||||
# 例如:财联社、新浪财经、东方财富等
|
||||
|
||||
try:
|
||||
# 示例:集成财联社API (需要申请)
|
||||
# 或者使用RSS源
|
||||
news_items = []
|
||||
|
||||
# 财联社RSS (如果可用)
|
||||
rss_sources = [
|
||||
"https://www.cls.cn/api/sw?app=CailianpressWeb&os=web&sv=7.7.5",
|
||||
# 可以添加更多RSS源
|
||||
]
|
||||
|
||||
for rss_url in rss_sources:
|
||||
try:
|
||||
items = self._parse_rss_feed(rss_url, ticker, hours_back)
|
||||
news_items.extend(items)
|
||||
except:
|
||||
continue
|
||||
|
||||
return news_items
|
||||
|
||||
except Exception as e:
|
||||
print(f"中文财经新闻获取失败: {e}")
|
||||
return []
|
||||
|
||||
def _parse_rss_feed(self, rss_url: str, ticker: str, hours_back: int) -> List[NewsItem]:
|
||||
"""解析RSS源"""
|
||||
# 简化实现,实际需要使用feedparser库
|
||||
return []
|
||||
|
||||
def _assess_news_urgency(self, title: str, content: str) -> str:
|
||||
"""评估新闻紧急程度"""
|
||||
text = (title + ' ' + content).lower()
|
||||
|
||||
# 高紧急度关键词
|
||||
high_urgency_keywords = [
|
||||
'breaking', 'urgent', 'alert', 'emergency', 'halt', 'suspend',
|
||||
'突发', '紧急', '暂停', '停牌', '重大'
|
||||
]
|
||||
|
||||
# 中等紧急度关键词
|
||||
medium_urgency_keywords = [
|
||||
'earnings', 'report', 'announce', 'launch', 'merger', 'acquisition',
|
||||
'财报', '发布', '宣布', '并购', '收购'
|
||||
]
|
||||
|
||||
if any(keyword in text for keyword in high_urgency_keywords):
|
||||
return 'high'
|
||||
elif any(keyword in text for keyword in medium_urgency_keywords):
|
||||
return 'medium'
|
||||
else:
|
||||
return 'low'
|
||||
|
||||
def _calculate_relevance(self, title: str, ticker: str) -> float:
|
||||
"""计算新闻相关性分数"""
|
||||
text = title.lower()
|
||||
ticker_lower = ticker.lower()
|
||||
|
||||
# 基础相关性
|
||||
if ticker_lower in text:
|
||||
return 1.0
|
||||
|
||||
# 公司名称匹配
|
||||
company_names = {
|
||||
'aapl': ['apple', 'iphone', 'ipad', 'mac'],
|
||||
'tsla': ['tesla', 'elon musk', 'electric vehicle'],
|
||||
'nvda': ['nvidia', 'gpu', 'ai chip'],
|
||||
'msft': ['microsoft', 'windows', 'azure'],
|
||||
'googl': ['google', 'alphabet', 'search']
|
||||
}
|
||||
|
||||
if ticker_lower in company_names:
|
||||
for name in company_names[ticker_lower]:
|
||||
if name in text:
|
||||
return 0.8
|
||||
|
||||
return 0.3 # 默认相关性
|
||||
|
||||
def _deduplicate_news(self, news_items: List[NewsItem]) -> List[NewsItem]:
|
||||
"""去重新闻"""
|
||||
seen_titles = set()
|
||||
unique_news = []
|
||||
|
||||
for item in news_items:
|
||||
# 简单的标题去重
|
||||
title_key = item.title.lower().strip()
|
||||
if title_key not in seen_titles and len(title_key) > 10:
|
||||
seen_titles.add(title_key)
|
||||
unique_news.append(item)
|
||||
|
||||
return unique_news
|
||||
|
||||
def format_news_report(self, news_items: List[NewsItem], ticker: str) -> str:
|
||||
"""格式化新闻报告"""
|
||||
if not news_items:
|
||||
return f"未获取到{ticker}的实时新闻数据。"
|
||||
|
||||
# 按紧急程度分组
|
||||
high_urgency = [n for n in news_items if n.urgency == 'high']
|
||||
medium_urgency = [n for n in news_items if n.urgency == 'medium']
|
||||
low_urgency = [n for n in news_items if n.urgency == 'low']
|
||||
|
||||
report = f"# {ticker} 实时新闻分析报告\n\n"
|
||||
report += f"📅 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
|
||||
report += f"📊 新闻总数: {len(news_items)}条\n\n"
|
||||
|
||||
if high_urgency:
|
||||
report += "## 🚨 紧急新闻\n\n"
|
||||
for news in high_urgency[:3]: # 最多显示3条
|
||||
report += f"### {news.title}\n"
|
||||
report += f"**来源**: {news.source} | **时间**: {news.publish_time.strftime('%H:%M')}\n"
|
||||
report += f"{news.content}\n\n"
|
||||
|
||||
if medium_urgency:
|
||||
report += "## 📢 重要新闻\n\n"
|
||||
for news in medium_urgency[:5]: # 最多显示5条
|
||||
report += f"### {news.title}\n"
|
||||
report += f"**来源**: {news.source} | **时间**: {news.publish_time.strftime('%H:%M')}\n"
|
||||
report += f"{news.content}\n\n"
|
||||
|
||||
# 添加时效性说明
|
||||
latest_news = max(news_items, key=lambda x: x.publish_time)
|
||||
time_diff = datetime.now() - latest_news.publish_time
|
||||
|
||||
report += f"\n## ⏰ 数据时效性\n"
|
||||
report += f"最新新闻发布于: {time_diff.total_seconds() / 60:.0f}分钟前\n"
|
||||
|
||||
if time_diff.total_seconds() < 1800: # 30分钟内
|
||||
report += "🟢 数据时效性: 优秀 (30分钟内)\n"
|
||||
elif time_diff.total_seconds() < 3600: # 1小时内
|
||||
report += "🟡 数据时效性: 良好 (1小时内)\n"
|
||||
else:
|
||||
report += "🔴 数据时效性: 一般 (超过1小时)\n"
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def get_realtime_stock_news(ticker: str, curr_date: str, hours_back: int = 6) -> str:
|
||||
"""
|
||||
获取实时股票新闻的主要接口函数
|
||||
"""
|
||||
aggregator = RealtimeNewsAggregator()
|
||||
|
||||
try:
|
||||
# 获取实时新闻
|
||||
news_items = aggregator.get_realtime_stock_news(ticker, hours_back)
|
||||
|
||||
# 格式化报告
|
||||
report = aggregator.format_news_report(news_items, ticker)
|
||||
|
||||
return report
|
||||
|
||||
except Exception as e:
|
||||
return f"""
|
||||
实时新闻获取失败 - {ticker}
|
||||
分析日期: {curr_date}
|
||||
|
||||
❌ 错误信息: {str(e)}
|
||||
|
||||
💡 备用建议:
|
||||
1. 检查API密钥配置 (FINNHUB_API_KEY, NEWSAPI_KEY)
|
||||
2. 使用基础新闻分析作为备选
|
||||
3. 关注官方财经媒体的最新报道
|
||||
4. 考虑使用专业金融终端获取实时新闻
|
||||
|
||||
注: 实时新闻获取依赖外部API服务的可用性。
|
||||
"""
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
股票数据API接口
|
||||
提供简单易用的股票数据获取接口,内置完整的降级机制
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
from .stock_data_service import get_stock_data_service
|
||||
|
||||
def get_stock_info(stock_code: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取单个股票的基础信息
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码(如 '000001')
|
||||
|
||||
Returns:
|
||||
Dict: 股票信息,包含code, name, market, category等字段
|
||||
如果获取失败,返回包含error字段的字典
|
||||
|
||||
Example:
|
||||
>>> info = get_stock_info('000001')
|
||||
>>> print(info['name']) # 输出: 平安银行
|
||||
"""
|
||||
service = get_stock_data_service()
|
||||
return service.get_stock_basic_info(stock_code)
|
||||
|
||||
def get_all_stocks() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取所有股票列表
|
||||
|
||||
Returns:
|
||||
List[Dict]: 股票列表,每个元素包含股票基础信息
|
||||
如果获取失败,返回包含error字段的字典
|
||||
|
||||
Example:
|
||||
>>> stocks = get_all_stocks()
|
||||
>>> print(f"共有{len(stocks)}只股票")
|
||||
"""
|
||||
service = get_stock_data_service()
|
||||
result = service.get_stock_basic_info()
|
||||
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
elif isinstance(result, dict) and 'error' in result:
|
||||
return [result] # 返回错误信息
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_stock_data(stock_code: str, start_date: str, end_date: str) -> str:
|
||||
"""
|
||||
获取股票历史数据(带降级机制)
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
start_date: 开始日期 'YYYY-MM-DD'
|
||||
end_date: 结束日期 'YYYY-MM-DD'
|
||||
|
||||
Returns:
|
||||
str: 格式化的股票数据报告
|
||||
|
||||
Example:
|
||||
>>> data = get_stock_data('000001', '2024-01-01', '2024-01-31')
|
||||
>>> print(data)
|
||||
"""
|
||||
service = get_stock_data_service()
|
||||
return service.get_stock_data_with_fallback(stock_code, start_date, end_date)
|
||||
|
||||
def search_stocks_by_name(name: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
根据股票名称搜索股票(需要MongoDB支持)
|
||||
|
||||
Args:
|
||||
name: 股票名称关键词
|
||||
|
||||
Returns:
|
||||
List[Dict]: 匹配的股票列表
|
||||
|
||||
Example:
|
||||
>>> results = search_stocks_by_name('银行')
|
||||
>>> for stock in results:
|
||||
... print(f"{stock['code']}: {stock['name']}")
|
||||
"""
|
||||
# 这个功能需要MongoDB支持,暂时通过原有方式实现
|
||||
try:
|
||||
from ..examples.stock_query_examples import EnhancedStockQueryService
|
||||
service = EnhancedStockQueryService()
|
||||
return service.query_stocks_by_name(name)
|
||||
except Exception as e:
|
||||
return [{'error': f'名称搜索功能不可用: {str(e)}'}]
|
||||
|
||||
def check_data_sources() -> Dict[str, Any]:
|
||||
"""
|
||||
检查数据源状态
|
||||
|
||||
Returns:
|
||||
Dict: 各数据源的可用状态
|
||||
|
||||
Example:
|
||||
>>> status = check_data_sources()
|
||||
>>> print(f"MongoDB可用: {status['mongodb_available']}")
|
||||
>>> print(f"通达信API可用: {status['tdx_api_available']}")
|
||||
"""
|
||||
service = get_stock_data_service()
|
||||
|
||||
return {
|
||||
'mongodb_available': service.db_manager is not None and service.db_manager.mongodb_db is not None,
|
||||
'tdx_api_available': service.tdx_provider is not None,
|
||||
'enhanced_fetcher_available': True, # 这个通常都可用
|
||||
'fallback_mode': service.db_manager is None or service.db_manager.mongodb_db is None,
|
||||
'recommendation': (
|
||||
"所有数据源正常" if service.db_manager and service.db_manager.mongodb_db
|
||||
else "建议配置MongoDB以获得最佳性能,当前使用通达信API降级模式"
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,279 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
统一的股票数据获取服务
|
||||
实现MongoDB -> 通达信API的完整降级机制
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
try:
|
||||
from tradingagents.config.database_manager import get_database_manager
|
||||
DATABASE_MANAGER_AVAILABLE = True
|
||||
except ImportError:
|
||||
DATABASE_MANAGER_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from .tdx_utils import get_tdx_provider, TongDaXinDataProvider
|
||||
TDX_AVAILABLE = True
|
||||
except ImportError:
|
||||
TDX_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import sys
|
||||
import os
|
||||
# 添加utils目录到路径
|
||||
utils_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'utils')
|
||||
if utils_path not in sys.path:
|
||||
sys.path.append(utils_path)
|
||||
from enhanced_stock_list_fetcher import enhanced_fetch_stock_list
|
||||
ENHANCED_FETCHER_AVAILABLE = True
|
||||
except ImportError:
|
||||
ENHANCED_FETCHER_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class StockDataService:
|
||||
"""
|
||||
统一的股票数据获取服务
|
||||
实现完整的降级机制:MongoDB -> 通达信API -> 缓存 -> 错误处理
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.db_manager = None
|
||||
self.tdx_provider = None
|
||||
self._init_services()
|
||||
|
||||
def _init_services(self):
|
||||
"""初始化服务"""
|
||||
# 尝试初始化数据库管理器
|
||||
if DATABASE_MANAGER_AVAILABLE:
|
||||
try:
|
||||
self.db_manager = get_database_manager()
|
||||
if self.db_manager.is_mongodb_available():
|
||||
print("✅ MongoDB连接成功")
|
||||
else:
|
||||
print("⚠️ MongoDB连接失败,将使用通达信API")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 数据库管理器初始化失败: {e}")
|
||||
self.db_manager = None
|
||||
|
||||
# 尝试初始化通达信提供器
|
||||
if TDX_AVAILABLE:
|
||||
try:
|
||||
self.tdx_provider = get_tdx_provider()
|
||||
print("✅ 通达信API初始化成功")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 通达信API初始化失败: {e}")
|
||||
self.tdx_provider = None
|
||||
|
||||
def get_stock_basic_info(self, stock_code: str = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取股票基础信息(单个股票或全部股票)
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码,如果为None则返回所有股票
|
||||
|
||||
Returns:
|
||||
Dict: 股票基础信息
|
||||
"""
|
||||
print(f"📊 获取股票基础信息: {stock_code or '全部股票'}")
|
||||
|
||||
# 1. 优先从MongoDB获取
|
||||
if self.db_manager and self.db_manager.is_mongodb_available():
|
||||
try:
|
||||
result = self._get_from_mongodb(stock_code)
|
||||
if result:
|
||||
print(f"✅ 从MongoDB获取成功: {len(result) if isinstance(result, list) else 1}条记录")
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"⚠️ MongoDB查询失败: {e}")
|
||||
|
||||
# 2. 降级到通达信API
|
||||
print("🔄 MongoDB不可用,降级到通达信API")
|
||||
if ENHANCED_FETCHER_AVAILABLE:
|
||||
try:
|
||||
result = self._get_from_tdx_api(stock_code)
|
||||
if result:
|
||||
print(f"✅ 从通达信API获取成功: {len(result) if isinstance(result, list) else 1}条记录")
|
||||
# 尝试缓存到MongoDB(如果可用)
|
||||
self._cache_to_mongodb(result)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"⚠️ 通达信API查询失败: {e}")
|
||||
|
||||
# 3. 最后的降级方案
|
||||
print("❌ 所有数据源都不可用")
|
||||
return self._get_fallback_data(stock_code)
|
||||
|
||||
def _get_from_mongodb(self, stock_code: str = None) -> Optional[Dict[str, Any]]:
|
||||
"""从MongoDB获取数据"""
|
||||
try:
|
||||
mongodb_client = self.db_manager.get_mongodb_client()
|
||||
if not mongodb_client:
|
||||
return None
|
||||
|
||||
db = mongodb_client[self.db_manager.mongodb_config["database"]]
|
||||
collection = db['stock_basic_info']
|
||||
|
||||
if stock_code:
|
||||
# 获取单个股票
|
||||
result = collection.find_one({'code': stock_code})
|
||||
return result if result else None
|
||||
else:
|
||||
# 获取所有股票
|
||||
cursor = collection.find({})
|
||||
results = list(cursor)
|
||||
return results if results else None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MongoDB查询失败: {e}")
|
||||
return None
|
||||
|
||||
def _get_from_tdx_api(self, stock_code: str = None) -> Optional[Dict[str, Any]]:
|
||||
"""从通达信API获取数据"""
|
||||
try:
|
||||
if stock_code:
|
||||
# 获取单个股票信息
|
||||
if self.tdx_provider:
|
||||
# 使用现有的股票名称获取方法
|
||||
stock_name = self.tdx_provider._get_stock_name(stock_code)
|
||||
return {
|
||||
'code': stock_code,
|
||||
'name': stock_name,
|
||||
'market': self._get_market_name(stock_code),
|
||||
'category': self._get_stock_category(stock_code),
|
||||
'source': 'tdx_api',
|
||||
'updated_at': datetime.now().isoformat()
|
||||
}
|
||||
else:
|
||||
# 获取所有股票列表
|
||||
stock_df = enhanced_fetch_stock_list(
|
||||
type_='stock',
|
||||
enable_server_failover=True,
|
||||
max_retries=3
|
||||
)
|
||||
|
||||
if stock_df is not None and not stock_df.empty:
|
||||
# 转换为字典列表
|
||||
results = []
|
||||
for _, row in stock_df.iterrows():
|
||||
results.append({
|
||||
'code': row.get('code', ''),
|
||||
'name': row.get('name', ''),
|
||||
'market': row.get('market', ''),
|
||||
'category': row.get('category', ''),
|
||||
'source': 'tdx_api',
|
||||
'updated_at': datetime.now().isoformat()
|
||||
})
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"通达信API查询失败: {e}")
|
||||
return None
|
||||
|
||||
def _cache_to_mongodb(self, data: Any) -> bool:
|
||||
"""将数据缓存到MongoDB"""
|
||||
if not self.db_manager or not self.db_manager.mongodb_db:
|
||||
return False
|
||||
|
||||
try:
|
||||
collection = self.db_manager.mongodb_db['stock_basic_info']
|
||||
|
||||
if isinstance(data, list):
|
||||
# 批量插入
|
||||
for item in data:
|
||||
collection.update_one(
|
||||
{'code': item['code']},
|
||||
{'$set': item},
|
||||
upsert=True
|
||||
)
|
||||
print(f"💾 已缓存{len(data)}条记录到MongoDB")
|
||||
elif isinstance(data, dict):
|
||||
# 单条插入
|
||||
collection.update_one(
|
||||
{'code': data['code']},
|
||||
{'$set': data},
|
||||
upsert=True
|
||||
)
|
||||
print(f"💾 已缓存股票{data['code']}到MongoDB")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"缓存到MongoDB失败: {e}")
|
||||
return False
|
||||
|
||||
def _get_fallback_data(self, stock_code: str = None) -> Dict[str, Any]:
|
||||
"""最后的降级数据"""
|
||||
if stock_code:
|
||||
return {
|
||||
'code': stock_code,
|
||||
'name': f'股票{stock_code}',
|
||||
'market': self._get_market_name(stock_code),
|
||||
'category': '未知',
|
||||
'source': 'fallback',
|
||||
'updated_at': datetime.now().isoformat(),
|
||||
'error': '所有数据源都不可用'
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'error': '无法获取股票列表,请检查网络连接和数据库配置',
|
||||
'suggestion': '请确保MongoDB已配置或网络连接正常以访问通达信API'
|
||||
}
|
||||
|
||||
def _get_market_name(self, stock_code: str) -> str:
|
||||
"""根据股票代码判断市场"""
|
||||
if stock_code.startswith(('60', '68', '90')):
|
||||
return '上海'
|
||||
elif stock_code.startswith(('00', '30', '20')):
|
||||
return '深圳'
|
||||
else:
|
||||
return '未知'
|
||||
|
||||
def _get_stock_category(self, stock_code: str) -> str:
|
||||
"""根据股票代码判断类别"""
|
||||
if stock_code.startswith('60'):
|
||||
return '沪市主板'
|
||||
elif stock_code.startswith('68'):
|
||||
return '科创板'
|
||||
elif stock_code.startswith('00'):
|
||||
return '深市主板'
|
||||
elif stock_code.startswith('30'):
|
||||
return '创业板'
|
||||
elif stock_code.startswith('20'):
|
||||
return '深市B股'
|
||||
else:
|
||||
return '其他'
|
||||
|
||||
def get_stock_data_with_fallback(self, stock_code: str, start_date: str, end_date: str) -> str:
|
||||
"""
|
||||
获取股票数据(带降级机制)
|
||||
这是对现有get_china_stock_data函数的增强
|
||||
"""
|
||||
print(f"📊 获取股票数据: {stock_code} ({start_date} 到 {end_date})")
|
||||
|
||||
# 首先确保股票基础信息可用
|
||||
stock_info = self.get_stock_basic_info(stock_code)
|
||||
if stock_info and 'error' in stock_info:
|
||||
return f"❌ 无法获取股票{stock_code}的基础信息: {stock_info.get('error', '未知错误')}"
|
||||
|
||||
# 调用现有的get_china_stock_data函数
|
||||
try:
|
||||
from .tdx_utils import get_china_stock_data
|
||||
return get_china_stock_data(stock_code, start_date, end_date)
|
||||
except Exception as e:
|
||||
return f"❌ 获取股票数据失败: {str(e)}\n\n💡 建议:\n1. 检查网络连接\n2. 确认股票代码格式正确\n3. 检查MongoDB配置"
|
||||
|
||||
# 全局服务实例
|
||||
_stock_data_service = None
|
||||
|
||||
def get_stock_data_service() -> StockDataService:
|
||||
"""获取股票数据服务实例(单例模式)"""
|
||||
global _stock_data_service
|
||||
if _stock_data_service is None:
|
||||
_stock_data_service = StockDataService()
|
||||
return _stock_data_service
|
||||
|
|
@ -0,0 +1,856 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
通达信API数据获取工具
|
||||
支持A股、港股实时数据和历史数据
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# 导入数据库管理器
|
||||
try:
|
||||
from tradingagents.config.database_manager import get_database_manager
|
||||
DB_MANAGER_AVAILABLE = True
|
||||
except ImportError:
|
||||
DB_MANAGER_AVAILABLE = False
|
||||
print("⚠️ 数据库缓存管理器不可用,尝试文件缓存")
|
||||
|
||||
# 导入MongoDB股票信息查询
|
||||
try:
|
||||
import os
|
||||
from pymongo import MongoClient
|
||||
MONGODB_AVAILABLE = True
|
||||
except ImportError:
|
||||
MONGODB_AVAILABLE = False
|
||||
print("⚠️ pymongo未安装,无法从MongoDB获取股票名称")
|
||||
|
||||
try:
|
||||
from .cache_manager import get_cache
|
||||
FILE_CACHE_AVAILABLE = True
|
||||
except ImportError:
|
||||
FILE_CACHE_AVAILABLE = False
|
||||
print("⚠️ 文件缓存管理器不可用,将直接从API获取数据")
|
||||
|
||||
try:
|
||||
# 通达信Python接口
|
||||
import pytdx
|
||||
from pytdx.hq import TdxHq_API
|
||||
from pytdx.exhq import TdxExHq_API
|
||||
TDX_AVAILABLE = True
|
||||
except ImportError:
|
||||
TDX_AVAILABLE = False
|
||||
print("⚠️ pytdx库未安装,无法使用通达信API")
|
||||
print("💡 安装命令: pip install pytdx")
|
||||
|
||||
|
||||
class TongDaXinDataProvider:
|
||||
"""通达信数据提供器"""
|
||||
|
||||
def __init__(self):
|
||||
print(f"🔍 [DEBUG] 初始化通达信数据提供器...")
|
||||
self.api = None
|
||||
self.exapi = None # 扩展行情API
|
||||
self.connected = False
|
||||
|
||||
print(f"🔍 [DEBUG] 检查pytdx库可用性: {TDX_AVAILABLE}")
|
||||
if not TDX_AVAILABLE:
|
||||
error_msg = "pytdx库未安装,请运行: pip install pytdx"
|
||||
print(f"❌ [DEBUG] {error_msg}")
|
||||
raise ImportError(error_msg)
|
||||
print(f"✅ [DEBUG] pytdx库检查通过")
|
||||
|
||||
def connect(self):
|
||||
"""连接通达信服务器"""
|
||||
print(f"🔍 [DEBUG] 开始连接通达信服务器...")
|
||||
try:
|
||||
# 尝试从配置文件加载可用服务器
|
||||
print(f"🔍 [DEBUG] 加载服务器配置...")
|
||||
working_servers = self._load_working_servers()
|
||||
|
||||
# 如果没有配置文件,使用默认服务器列表
|
||||
if not working_servers:
|
||||
print(f"🔍 [DEBUG] 未找到配置文件,使用默认服务器列表")
|
||||
working_servers = [
|
||||
{'ip': '115.238.56.198', 'port': 7709},
|
||||
{'ip': '115.238.90.165', 'port': 7709},
|
||||
{'ip': '180.153.18.170', 'port': 7709},
|
||||
{'ip': '119.147.212.81', 'port': 7709}, # 备用
|
||||
]
|
||||
else:
|
||||
print(f"🔍 [DEBUG] 从配置文件加载了 {len(working_servers)} 个服务器")
|
||||
|
||||
# 尝试连接可用服务器
|
||||
print(f"🔍 [DEBUG] 创建通达信API实例...")
|
||||
self.api = TdxHq_API()
|
||||
print(f"🔍 [DEBUG] 开始尝试连接服务器...")
|
||||
|
||||
for i, server in enumerate(working_servers):
|
||||
try:
|
||||
print(f"🔍 [DEBUG] 尝试连接服务器 {i+1}/{len(working_servers)}: {server['ip']}:{server['port']}")
|
||||
result = self.api.connect(server['ip'], server['port'])
|
||||
print(f"🔍 [DEBUG] 连接结果: {result}")
|
||||
if result:
|
||||
print(f"✅ 通达信API连接成功: {server['ip']}:{server['port']}")
|
||||
self.connected = True
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"⚠️ 服务器 {server['ip']}:{server['port']} 连接失败: {e}")
|
||||
continue
|
||||
|
||||
print("❌ 所有通达信服务器连接失败")
|
||||
self.connected = False
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 通达信API连接失败: {e}")
|
||||
self.connected = False
|
||||
return False
|
||||
|
||||
def _load_working_servers(self):
|
||||
"""加载可用服务器配置"""
|
||||
try:
|
||||
import json
|
||||
import os
|
||||
|
||||
config_file = 'tdx_servers_config.json'
|
||||
if os.path.exists(config_file):
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
return config.get('working_servers', [])
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
def disconnect(self):
|
||||
"""断开连接"""
|
||||
try:
|
||||
if self.api:
|
||||
self.api.disconnect()
|
||||
if self.exapi:
|
||||
self.exapi.disconnect()
|
||||
self.connected = False
|
||||
print("✅ 通达信API连接已断开")
|
||||
except:
|
||||
pass
|
||||
|
||||
def is_connected(self):
|
||||
"""检查连接状态"""
|
||||
if not self.connected or not self.api:
|
||||
return False
|
||||
|
||||
# 尝试简单的API调用来验证连接是否有效
|
||||
try:
|
||||
# 获取市场信息作为连接测试
|
||||
result = self.api.get_security_count(0) # 获取深圳市场股票数量
|
||||
return result is not None and result > 0
|
||||
except Exception as e:
|
||||
print(f"🔍 [DEBUG] 连接测试失败: {e}")
|
||||
self.connected = False
|
||||
return False
|
||||
|
||||
def _get_stock_name(self, stock_code: str) -> str:
|
||||
"""
|
||||
获取股票名称
|
||||
优先级:缓存 -> MongoDB -> 常用股票映射 -> API获取(仅深圳市场) -> 默认格式
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
Returns:
|
||||
str: 股票名称
|
||||
"""
|
||||
global _stock_name_cache
|
||||
|
||||
# 首先检查缓存
|
||||
if stock_code in _stock_name_cache:
|
||||
return _stock_name_cache[stock_code]
|
||||
|
||||
# 优先从MongoDB获取
|
||||
mongodb_name = _get_stock_name_from_mongodb(stock_code)
|
||||
if mongodb_name:
|
||||
_stock_name_cache[stock_code] = mongodb_name
|
||||
return mongodb_name
|
||||
|
||||
# 检查常用股票映射表
|
||||
if stock_code in _common_stock_names:
|
||||
name = _common_stock_names[stock_code]
|
||||
_stock_name_cache[stock_code] = name
|
||||
return name
|
||||
|
||||
# 如果API不可用,直接返回默认格式
|
||||
if not self.connected:
|
||||
if not self.connect():
|
||||
default_name = f'股票{stock_code}'
|
||||
_stock_name_cache[stock_code] = default_name
|
||||
return default_name
|
||||
|
||||
try:
|
||||
# 仅对深圳市场尝试从API获取(上海市场的get_security_list不可用)
|
||||
market = self._get_market_code(stock_code)
|
||||
if market == 0: # 深圳市场
|
||||
try:
|
||||
for start_pos in range(0, 2000, 1000): # 分批获取
|
||||
stock_list = self.api.get_security_list(market, start_pos)
|
||||
if stock_list:
|
||||
for stock_info in stock_list:
|
||||
if stock_info.get('code') == stock_code:
|
||||
stock_name = stock_info.get('name', '').strip()
|
||||
if stock_name:
|
||||
_stock_name_cache[stock_code] = stock_name
|
||||
return stock_name
|
||||
except Exception as e:
|
||||
print(f"⚠️ 获取深圳股票列表失败: {e}")
|
||||
|
||||
# 如果都失败了,返回默认格式并缓存
|
||||
default_name = f'股票{stock_code}'
|
||||
_stock_name_cache[stock_code] = default_name
|
||||
return default_name
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ 获取股票名称失败: {e}")
|
||||
default_name = f'股票{stock_code}'
|
||||
_stock_name_cache[stock_code] = default_name
|
||||
return default_name
|
||||
|
||||
def get_real_time_data(self, stock_code: str) -> Dict:
|
||||
"""
|
||||
获取股票实时数据
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
Returns:
|
||||
Dict: 实时数据
|
||||
"""
|
||||
if not self.connected:
|
||||
if not self.connect():
|
||||
return {}
|
||||
|
||||
try:
|
||||
market = self._get_market_code(stock_code)
|
||||
|
||||
# 获取实时数据
|
||||
data = self.api.get_security_quotes([(market, stock_code)])
|
||||
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
quote = data[0]
|
||||
|
||||
# 安全获取字段,避免KeyError
|
||||
def safe_get(key, default=0):
|
||||
return quote.get(key, default)
|
||||
|
||||
return {
|
||||
'code': stock_code,
|
||||
'name': self._get_stock_name(stock_code), # 使用独立的股票名称获取方法
|
||||
'price': safe_get('price'),
|
||||
'last_close': safe_get('last_close'),
|
||||
'open': safe_get('open'),
|
||||
'high': safe_get('high'),
|
||||
'low': safe_get('low'),
|
||||
'volume': safe_get('vol'),
|
||||
'amount': safe_get('amount'),
|
||||
'change': safe_get('price') - safe_get('last_close'),
|
||||
'change_percent': ((safe_get('price') - safe_get('last_close')) / safe_get('last_close') * 100) if safe_get('last_close') > 0 else 0,
|
||||
'bid_prices': [safe_get(f'bid{i}') for i in range(1, 6)],
|
||||
'bid_volumes': [safe_get(f'bid_vol{i}') for i in range(1, 6)],
|
||||
'ask_prices': [safe_get(f'ask{i}') for i in range(1, 6)],
|
||||
'ask_volumes': [safe_get(f'ask_vol{i}') for i in range(1, 6)],
|
||||
'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取实时数据失败: {e}")
|
||||
return {}
|
||||
|
||||
def get_stock_history_data(self, stock_code: str, start_date: str, end_date: str, period: str = 'D') -> pd.DataFrame:
|
||||
"""
|
||||
获取股票历史数据
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
start_date: 开始日期 'YYYY-MM-DD'
|
||||
end_date: 结束日期 'YYYY-MM-DD'
|
||||
period: 周期 'D'=日线, 'W'=周线, 'M'=月线
|
||||
Returns:
|
||||
DataFrame: 历史数据
|
||||
"""
|
||||
if not self.connected:
|
||||
if not self.connect():
|
||||
return pd.DataFrame()
|
||||
|
||||
try:
|
||||
market = self._get_market_code(stock_code)
|
||||
|
||||
# 计算需要获取的数据量
|
||||
start_dt = datetime.strptime(start_date, '%Y-%m-%d')
|
||||
end_dt = datetime.strptime(end_date, '%Y-%m-%d')
|
||||
days_diff = (end_dt - start_dt).days
|
||||
|
||||
# 根据周期调整数据量
|
||||
if period == 'D':
|
||||
count = min(days_diff + 10, 800) # 日线最多800条
|
||||
elif period == 'W':
|
||||
count = min(days_diff // 7 + 10, 800)
|
||||
elif period == 'M':
|
||||
count = min(days_diff // 30 + 10, 800)
|
||||
else:
|
||||
count = 800
|
||||
|
||||
# 获取K线数据
|
||||
category_map = {'D': 9, 'W': 5, 'M': 6}
|
||||
category = category_map.get(period, 9)
|
||||
|
||||
data = self.api.get_security_bars(category, market, stock_code, 0, count)
|
||||
|
||||
if not data:
|
||||
return pd.DataFrame()
|
||||
|
||||
# 转换为DataFrame
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# 处理数据格式
|
||||
df['datetime'] = pd.to_datetime(df['datetime'])
|
||||
df = df.set_index('datetime')
|
||||
df = df.sort_index()
|
||||
|
||||
# 筛选日期范围
|
||||
df = df[start_date:end_date]
|
||||
|
||||
# 重命名列以匹配Yahoo Finance格式
|
||||
df = df.rename(columns={
|
||||
'open': 'Open',
|
||||
'high': 'High',
|
||||
'low': 'Low',
|
||||
'close': 'Close',
|
||||
'vol': 'Volume',
|
||||
'amount': 'Amount'
|
||||
})
|
||||
|
||||
# 添加股票代码信息
|
||||
df['Symbol'] = stock_code
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取历史数据失败: {e}")
|
||||
return pd.DataFrame()
|
||||
|
||||
def get_stock_technical_indicators(self, stock_code: str, period: int = 20) -> Dict:
|
||||
"""
|
||||
计算技术指标
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
period: 计算周期
|
||||
Returns:
|
||||
Dict: 技术指标数据
|
||||
"""
|
||||
try:
|
||||
# 获取最近的历史数据
|
||||
end_date = datetime.now().strftime('%Y-%m-%d')
|
||||
start_date = (datetime.now() - timedelta(days=period*2)).strftime('%Y-%m-%d')
|
||||
|
||||
df = self.get_stock_history_data(stock_code, start_date, end_date)
|
||||
|
||||
if df.empty:
|
||||
return {}
|
||||
|
||||
# 计算技术指标
|
||||
indicators = {}
|
||||
|
||||
# 移动平均线
|
||||
indicators['MA5'] = df['Close'].rolling(5).mean().iloc[-1] if len(df) >= 5 else None
|
||||
indicators['MA10'] = df['Close'].rolling(10).mean().iloc[-1] if len(df) >= 10 else None
|
||||
indicators['MA20'] = df['Close'].rolling(20).mean().iloc[-1] if len(df) >= 20 else None
|
||||
|
||||
# RSI
|
||||
if len(df) >= 14:
|
||||
delta = df['Close'].diff()
|
||||
gain = (delta.where(delta > 0, 0)).rolling(14).mean()
|
||||
loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
|
||||
rs = gain / loss
|
||||
indicators['RSI'] = (100 - (100 / (1 + rs))).iloc[-1]
|
||||
|
||||
# MACD
|
||||
if len(df) >= 26:
|
||||
exp1 = df['Close'].ewm(span=12).mean()
|
||||
exp2 = df['Close'].ewm(span=26).mean()
|
||||
macd = exp1 - exp2
|
||||
signal = macd.ewm(span=9).mean()
|
||||
indicators['MACD'] = macd.iloc[-1]
|
||||
indicators['MACD_Signal'] = signal.iloc[-1]
|
||||
indicators['MACD_Histogram'] = (macd - signal).iloc[-1]
|
||||
|
||||
# 布林带
|
||||
if len(df) >= 20:
|
||||
sma = df['Close'].rolling(20).mean()
|
||||
std = df['Close'].rolling(20).std()
|
||||
indicators['BB_Upper'] = (sma + 2 * std).iloc[-1]
|
||||
indicators['BB_Middle'] = sma.iloc[-1]
|
||||
indicators['BB_Lower'] = (sma - 2 * std).iloc[-1]
|
||||
|
||||
return indicators
|
||||
|
||||
except Exception as e:
|
||||
print(f"计算技术指标失败: {e}")
|
||||
return {}
|
||||
|
||||
def search_stocks(self, keyword: str) -> List[Dict]:
|
||||
"""
|
||||
搜索股票
|
||||
Args:
|
||||
keyword: 搜索关键词(股票代码或名称)
|
||||
Returns:
|
||||
List[Dict]: 搜索结果
|
||||
"""
|
||||
if not self.connected:
|
||||
if not self.connect():
|
||||
return []
|
||||
|
||||
try:
|
||||
# 通达信没有直接的搜索API,这里提供一个简化的实现
|
||||
# 实际使用中可以维护一个股票代码表
|
||||
|
||||
# 常见股票代码映射
|
||||
stock_mapping = {
|
||||
'平安银行': '000001',
|
||||
'万科A': '000002',
|
||||
'中国平安': '601318',
|
||||
'贵州茅台': '600519',
|
||||
'招商银行': '600036',
|
||||
'五粮液': '000858',
|
||||
'格力电器': '000651',
|
||||
'美的集团': '000333',
|
||||
'中国石化': '600028',
|
||||
'工商银行': '601398'
|
||||
}
|
||||
|
||||
results = []
|
||||
|
||||
# 按关键词搜索
|
||||
for name, code in stock_mapping.items():
|
||||
if keyword.lower() in name.lower() or keyword in code:
|
||||
# 获取实时数据
|
||||
realtime_data = self.get_real_time_data(code)
|
||||
if realtime_data:
|
||||
results.append({
|
||||
'code': code,
|
||||
'name': name,
|
||||
'price': realtime_data.get('price', 0),
|
||||
'change_percent': realtime_data.get('change_percent', 0)
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索股票失败: {e}")
|
||||
return []
|
||||
|
||||
def _get_market_code(self, stock_code: str) -> int:
|
||||
"""
|
||||
根据股票代码判断市场
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
Returns:
|
||||
int: 市场代码 (0=深圳, 1=上海)
|
||||
"""
|
||||
if stock_code.startswith(('000', '002', '003', '300')):
|
||||
return 0 # 深圳
|
||||
elif stock_code.startswith(('600', '601', '603', '605', '688')):
|
||||
return 1 # 上海
|
||||
else:
|
||||
return 0 # 默认深圳
|
||||
|
||||
def get_market_overview(self) -> Dict:
|
||||
"""获取市场概览"""
|
||||
if not self.connected:
|
||||
if not self.connect():
|
||||
return {}
|
||||
|
||||
try:
|
||||
# 获取主要指数数据
|
||||
indices = {
|
||||
'上证指数': ('1', '000001'),
|
||||
'深证成指': ('0', '399001'),
|
||||
'创业板指': ('0', '399006'),
|
||||
'科创50': ('1', '000688')
|
||||
}
|
||||
|
||||
market_data = {}
|
||||
|
||||
for name, (market, code) in indices.items():
|
||||
try:
|
||||
data = self.api.get_security_quotes([(int(market), code)])
|
||||
if data:
|
||||
quote = data[0]
|
||||
market_data[name] = {
|
||||
'price': quote['price'],
|
||||
'change': quote['price'] - quote['last_close'],
|
||||
'change_percent': ((quote['price'] - quote['last_close']) / quote['last_close'] * 100) if quote['last_close'] > 0 else 0,
|
||||
'volume': quote['vol']
|
||||
}
|
||||
except:
|
||||
continue
|
||||
|
||||
return market_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取市场概览失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
# 全局实例和缓存
|
||||
_tdx_provider = None
|
||||
_stock_name_cache = {} # 股票名称缓存,避免重复API调用
|
||||
_mongodb_client = None
|
||||
_mongodb_db = None
|
||||
|
||||
def _get_mongodb_connection():
|
||||
"""获取MongoDB连接"""
|
||||
global _mongodb_client, _mongodb_db
|
||||
|
||||
if not MONGODB_AVAILABLE:
|
||||
return None, None
|
||||
|
||||
if _mongodb_client is None or _mongodb_db is None:
|
||||
try:
|
||||
# 从环境变量获取MongoDB配置
|
||||
config = {
|
||||
'host': os.getenv('MONGODB_HOST', 'localhost'),
|
||||
'port': int(os.getenv('MONGODB_PORT', 27018)),
|
||||
'username': os.getenv('MONGODB_USERNAME'),
|
||||
'password': os.getenv('MONGODB_PASSWORD'),
|
||||
'database': os.getenv('MONGODB_DATABASE', 'tradingagents'),
|
||||
'auth_source': os.getenv('MONGODB_AUTH_SOURCE', 'admin')
|
||||
}
|
||||
|
||||
# 构建连接字符串
|
||||
if config.get('username') and config.get('password'):
|
||||
connection_string = f"mongodb://{config['username']}:{config['password']}@{config['host']}:{config['port']}/{config['auth_source']}"
|
||||
else:
|
||||
connection_string = f"mongodb://{config['host']}:{config['port']}/"
|
||||
|
||||
# 创建客户端
|
||||
_mongodb_client = MongoClient(
|
||||
connection_string,
|
||||
serverSelectionTimeoutMS=3000 # 3秒超时
|
||||
)
|
||||
|
||||
# 测试连接
|
||||
_mongodb_client.admin.command('ping')
|
||||
|
||||
# 选择数据库
|
||||
_mongodb_db = _mongodb_client[config['database']]
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ MongoDB连接失败: {e}")
|
||||
_mongodb_client = None
|
||||
_mongodb_db = None
|
||||
|
||||
return _mongodb_client, _mongodb_db
|
||||
|
||||
def _get_stock_name_from_mongodb(stock_code: str) -> Optional[str]:
|
||||
"""从MongoDB获取股票名称"""
|
||||
try:
|
||||
client, db = _get_mongodb_connection()
|
||||
if db is None:
|
||||
return None
|
||||
|
||||
collection = db['stock_basic_info']
|
||||
stock_info = collection.find_one({'code': stock_code})
|
||||
|
||||
if stock_info and 'name' in stock_info:
|
||||
return stock_info['name'].strip()
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ 从MongoDB获取股票名称失败: {e}")
|
||||
return None
|
||||
|
||||
# 精简的常用股票名称映射(仅包含最常见的股票)
|
||||
_common_stock_names = {
|
||||
# 深圳主板
|
||||
'000001': '平安银行',
|
||||
'000002': '万科A',
|
||||
'000858': '五粮液',
|
||||
'000895': '双汇发展',
|
||||
|
||||
# 深圳中小板
|
||||
'002594': '比亚迪',
|
||||
'002415': '海康威视',
|
||||
'002304': '洋河股份',
|
||||
|
||||
# 深圳创业板
|
||||
'300059': '东方财富',
|
||||
'300750': '宁德时代',
|
||||
'300015': '爱尔眼科',
|
||||
|
||||
# 上海主板
|
||||
'600519': '贵州茅台',
|
||||
'600036': '招商银行',
|
||||
'601398': '工商银行',
|
||||
'601127': '小康股份',
|
||||
'600000': '浦发银行',
|
||||
'601318': '中国平安',
|
||||
'600276': '恒瑞医药',
|
||||
'600887': '伊利股份',
|
||||
|
||||
# 科创板
|
||||
'688981': '中芯国际',
|
||||
'688599': '天合光能',
|
||||
}
|
||||
|
||||
def get_tdx_provider() -> TongDaXinDataProvider:
|
||||
"""获取通达信数据提供器实例"""
|
||||
global _tdx_provider
|
||||
if _tdx_provider is None:
|
||||
print(f"🔍 [DEBUG] 创建新的通达信数据提供器实例...")
|
||||
_tdx_provider = TongDaXinDataProvider()
|
||||
print(f"🔍 [DEBUG] 通达信数据提供器实例创建完成")
|
||||
else:
|
||||
print(f"🔍 [DEBUG] 使用现有的通达信数据提供器实例")
|
||||
# 检查连接状态,如果连接断开则重新创建
|
||||
if not _tdx_provider.is_connected():
|
||||
print(f"🔍 [DEBUG] 检测到连接断开,重新创建通达信数据提供器...")
|
||||
_tdx_provider = TongDaXinDataProvider()
|
||||
print(f"🔍 [DEBUG] 通达信数据提供器重新创建完成")
|
||||
return _tdx_provider
|
||||
|
||||
|
||||
def get_china_stock_data(stock_code: str, start_date: str, end_date: str) -> str:
|
||||
"""
|
||||
获取中国股票数据的主要接口函数(支持缓存)
|
||||
Args:
|
||||
stock_code: 股票代码 (如 '000001')
|
||||
start_date: 开始日期 'YYYY-MM-DD'
|
||||
end_date: 结束日期 'YYYY-MM-DD'
|
||||
Returns:
|
||||
str: 格式化的股票数据
|
||||
"""
|
||||
print(f"📊 正在获取中国股票数据: {stock_code} ({start_date} 到 {end_date})")
|
||||
|
||||
# 优先尝试从数据库缓存加载数据(使用统一的database_manager)
|
||||
try:
|
||||
from tradingagents.config.database_manager import get_database_manager
|
||||
db_manager = get_database_manager()
|
||||
if db_manager.is_mongodb_available():
|
||||
# 直接使用MongoDB客户端查询缓存数据
|
||||
mongodb_client = db_manager.get_mongodb_client()
|
||||
if mongodb_client:
|
||||
db = mongodb_client[db_manager.mongodb_config["database"]]
|
||||
collection = db.stock_data
|
||||
|
||||
# 查询最近的缓存数据
|
||||
from datetime import datetime, timedelta
|
||||
cutoff_time = datetime.utcnow() - timedelta(hours=6)
|
||||
|
||||
cached_doc = collection.find_one({
|
||||
"symbol": stock_code,
|
||||
"market_type": "china",
|
||||
"created_at": {"$gte": cutoff_time}
|
||||
}, sort=[("created_at", -1)])
|
||||
|
||||
if cached_doc and 'data' in cached_doc:
|
||||
print(f"🗄️ 从MongoDB缓存加载数据: {stock_code}")
|
||||
return cached_doc['data']
|
||||
except Exception as e:
|
||||
print(f"⚠️ 从MongoDB加载缓存失败: {e}")
|
||||
|
||||
# 如果数据库缓存不可用,尝试文件缓存
|
||||
if FILE_CACHE_AVAILABLE:
|
||||
cache = get_cache()
|
||||
cache_key = cache.find_cached_stock_data(
|
||||
symbol=stock_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
data_source="tdx",
|
||||
max_age_hours=6 # 6小时内的缓存有效
|
||||
)
|
||||
|
||||
if cache_key:
|
||||
cached_data = cache.load_stock_data(cache_key)
|
||||
if cached_data:
|
||||
print(f"💾 从文件缓存加载数据: {stock_code} -> {cache_key}")
|
||||
return cached_data
|
||||
|
||||
print(f"🌐 从通达信API获取数据: {stock_code}")
|
||||
|
||||
try:
|
||||
provider = get_tdx_provider()
|
||||
|
||||
# 获取历史数据
|
||||
df = provider.get_stock_history_data(stock_code, start_date, end_date)
|
||||
|
||||
if df.empty:
|
||||
error_msg = f"❌ 未能获取股票 {stock_code} 的历史数据"
|
||||
print(error_msg)
|
||||
return error_msg
|
||||
|
||||
# 获取实时数据
|
||||
realtime_data = provider.get_real_time_data(stock_code)
|
||||
|
||||
# 获取技术指标
|
||||
indicators = provider.get_stock_technical_indicators(stock_code)
|
||||
|
||||
# 格式化输出
|
||||
result = f"""
|
||||
# {stock_code} 股票数据分析
|
||||
|
||||
## 📊 实时行情
|
||||
- 股票名称: {realtime_data.get('name', 'N/A')}
|
||||
- 当前价格: ¥{realtime_data.get('price', 0):.2f}
|
||||
- 涨跌幅: {realtime_data.get('change_percent', 0):.2f}%
|
||||
- 成交量: {realtime_data.get('volume', 0):,}手
|
||||
- 更新时间: {realtime_data.get('update_time', 'N/A')}
|
||||
|
||||
## 📈 历史数据概览
|
||||
- 数据期间: {start_date} 至 {end_date}
|
||||
- 数据条数: {len(df)}条
|
||||
- 期间最高: ¥{df['High'].max():.2f}
|
||||
- 期间最低: ¥{df['Low'].min():.2f}
|
||||
- 期间涨幅: {((df['Close'].iloc[-1] - df['Close'].iloc[0]) / df['Close'].iloc[0] * 100):.2f}%
|
||||
|
||||
## 🔍 技术指标
|
||||
- MA5: ¥{indicators.get('MA5', 0):.2f}
|
||||
- MA10: ¥{indicators.get('MA10', 0):.2f}
|
||||
- MA20: ¥{indicators.get('MA20', 0):.2f}
|
||||
- RSI: {indicators.get('RSI', 0):.2f}
|
||||
- MACD: {indicators.get('MACD', 0):.4f}
|
||||
|
||||
## 📋 最近5日数据
|
||||
{df.tail().to_string()}
|
||||
|
||||
数据来源: 通达信API (实时数据)
|
||||
"""
|
||||
|
||||
# 优先保存到数据库缓存(使用统一的database_manager)
|
||||
try:
|
||||
from tradingagents.config.database_manager import get_database_manager
|
||||
db_manager = get_database_manager()
|
||||
if db_manager.is_mongodb_available():
|
||||
# 直接使用MongoDB客户端保存数据
|
||||
mongodb_client = db_manager.get_mongodb_client()
|
||||
if mongodb_client:
|
||||
db = mongodb_client[db_manager.mongodb_config["database"]]
|
||||
collection = db.stock_data
|
||||
|
||||
doc = {
|
||||
"symbol": stock_code,
|
||||
"market_type": "china",
|
||||
"data": result,
|
||||
"metadata": {
|
||||
'start_date': start_date,
|
||||
'end_date': end_date,
|
||||
'data_source': 'tdx',
|
||||
'realtime_data': realtime_data,
|
||||
'indicators': indicators,
|
||||
'history_count': len(df)
|
||||
},
|
||||
"created_at": datetime.utcnow(),
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
collection.replace_one(
|
||||
{"symbol": stock_code, "market_type": "china"},
|
||||
doc,
|
||||
upsert=True
|
||||
)
|
||||
print(f"💾 数据已保存到MongoDB: {stock_code}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 保存到MongoDB失败: {e}")
|
||||
|
||||
# 同时保存到文件缓存作为备份
|
||||
if FILE_CACHE_AVAILABLE:
|
||||
cache = get_cache()
|
||||
cache.save_stock_data(
|
||||
symbol=stock_code,
|
||||
data=result,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
data_source="tdx"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_details = traceback.format_exc()
|
||||
print(f"❌ [DEBUG] 通达信API调用失败:")
|
||||
print(f"❌ [DEBUG] 错误类型: {type(e).__name__}")
|
||||
print(f"❌ [DEBUG] 错误信息: {str(e)}")
|
||||
print(f"❌ [DEBUG] 详细堆栈:")
|
||||
print(error_details)
|
||||
|
||||
return f"""
|
||||
❌ 中国股票数据获取失败 - {stock_code}
|
||||
错误类型: {type(e).__name__}
|
||||
错误信息: {str(e)}
|
||||
|
||||
🔍 调试信息:
|
||||
{error_details}
|
||||
|
||||
💡 解决建议:
|
||||
1. 检查pytdx库是否已安装: pip install pytdx
|
||||
2. 确认股票代码格式正确 (如: 000001, 600519)
|
||||
3. 检查网络连接是否正常
|
||||
4. 尝试重新连接通达信服务器
|
||||
|
||||
注: 通达信API需要网络连接到通达信服务器
|
||||
"""
|
||||
|
||||
|
||||
def get_china_market_overview() -> str:
|
||||
"""获取中国股市概览"""
|
||||
try:
|
||||
provider = get_tdx_provider()
|
||||
market_data = provider.get_market_overview()
|
||||
|
||||
if not market_data:
|
||||
return "无法获取市场概览数据"
|
||||
|
||||
result = "# 中国股市概览\n\n"
|
||||
|
||||
for name, data in market_data.items():
|
||||
change_symbol = "📈" if data['change'] >= 0 else "📉"
|
||||
result += f"## {change_symbol} {name}\n"
|
||||
result += f"- 当前点位: {data['price']:.2f}\n"
|
||||
result += f"- 涨跌点数: {data['change']:+.2f}\n"
|
||||
result += f"- 涨跌幅: {data['change_percent']:+.2f}%\n"
|
||||
result += f"- 成交量: {data['volume']:,}\n\n"
|
||||
|
||||
result += f"更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
|
||||
result += "数据来源: 通达信API\n"
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return f"获取市场概览失败: {str(e)}"
|
||||
|
||||
# 在文件末尾添加以下函数
|
||||
|
||||
def get_china_stock_data_enhanced(stock_code: str, start_date: str, end_date: str) -> str:
|
||||
"""
|
||||
增强版中国股票数据获取函数(完整降级机制)
|
||||
这是get_china_stock_data的增强版本
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码 (如 '000001')
|
||||
start_date: 开始日期 'YYYY-MM-DD'
|
||||
end_date: 结束日期 'YYYY-MM-DD'
|
||||
Returns:
|
||||
str: 格式化的股票数据
|
||||
"""
|
||||
try:
|
||||
from .stock_data_service import get_stock_data_service
|
||||
service = get_stock_data_service()
|
||||
return service.get_stock_data_with_fallback(stock_code, start_date, end_date)
|
||||
except ImportError:
|
||||
# 如果新服务不可用,降级到原有函数
|
||||
print("⚠️ 增强服务不可用,使用原有函数")
|
||||
return get_china_stock_data(stock_code, start_date, end_date)
|
||||
except Exception as e:
|
||||
print(f"⚠️ 增强服务出错,降级到原有函数: {e}")
|
||||
return get_china_stock_data(stock_code, start_date, end_date)
|
||||
|
||||
# ... existing code ...
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
import os
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
||||
"data_cache_dir": os.path.join(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"dataflows/data_cache",
|
||||
),
|
||||
# LLM settings
|
||||
"llm_provider": "openai",
|
||||
"deep_think_llm": "o4-mini",
|
||||
"quick_think_llm": "gpt-4o-mini",
|
||||
"backend_url": "https://api.openai.com/v1",
|
||||
# Debate and discussion settings
|
||||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
"max_recur_limit": 100,
|
||||
# Tool settings
|
||||
"online_tools": True,
|
||||
}
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
# 文件差异报告
|
||||
# 当前文件: tradingagents\default_config.py
|
||||
# 中文版文件: TradingAgentsCN\tradingagents\default_config.py
|
||||
# 生成时间: 周日 2025/07/06
|
||||
|
||||
--- current/default_config.py+++ chinese_version/default_config.py@@ -3,7 +3,7 @@ DEFAULT_CONFIG = {
|
||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||
- "data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
||||
+ "data_dir": os.path.join(os.path.expanduser("~"), "Documents", "TradingAgents", "data"),
|
||||
"data_cache_dir": os.path.join(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"dataflows/data_cache",
|
||||
@@ -19,4 +19,7 @@ "max_recur_limit": 100,
|
||||
# Tool settings
|
||||
"online_tools": True,
|
||||
+
|
||||
+ # Note: Database and cache configuration is now managed by .env file and config.database_manager
|
||||
+ # No database/cache settings in default config to avoid configuration conflicts
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
# LLM Adapters for TradingAgents
|
||||
from .dashscope_adapter import ChatDashScope
|
||||
|
||||
__all__ = ["ChatDashScope"]
|
||||
|
|
@ -0,0 +1,288 @@
|
|||
"""
|
||||
阿里百炼大模型 (DashScope) 适配器
|
||||
为 TradingAgents 提供阿里百炼大模型的 LangChain 兼容接口
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Union, Iterator, AsyncIterator, Sequence
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import Field, SecretStr
|
||||
import dashscope
|
||||
from dashscope import Generation
|
||||
from ..config.config_manager import token_tracker
|
||||
|
||||
|
||||
class ChatDashScope(BaseChatModel):
|
||||
"""阿里百炼大模型的 LangChain 适配器"""
|
||||
|
||||
# 模型配置
|
||||
model: str = Field(default="qwen-turbo", description="DashScope 模型名称")
|
||||
api_key: Optional[SecretStr] = Field(default=None, description="DashScope API 密钥")
|
||||
temperature: float = Field(default=0.1, description="生成温度")
|
||||
max_tokens: int = Field(default=2000, description="最大生成token数")
|
||||
top_p: float = Field(default=0.9, description="核采样参数")
|
||||
|
||||
# 内部属性
|
||||
_client: Any = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""初始化 DashScope 客户端"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# 设置API密钥
|
||||
api_key = self.api_key
|
||||
if api_key is None:
|
||||
api_key = os.getenv("DASHSCOPE_API_KEY")
|
||||
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"DashScope API key not found. Please set DASHSCOPE_API_KEY environment variable "
|
||||
"or pass api_key parameter."
|
||||
)
|
||||
|
||||
# 配置 DashScope
|
||||
if isinstance(api_key, SecretStr):
|
||||
dashscope.api_key = api_key.get_secret_value()
|
||||
else:
|
||||
dashscope.api_key = api_key
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""返回LLM类型"""
|
||||
return "dashscope"
|
||||
|
||||
def _convert_messages_to_dashscope_format(self, messages: List[BaseMessage]) -> List[Dict[str, str]]:
|
||||
"""将 LangChain 消息格式转换为 DashScope 格式"""
|
||||
dashscope_messages = []
|
||||
|
||||
for message in messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
role = "system"
|
||||
elif isinstance(message, HumanMessage):
|
||||
role = "user"
|
||||
elif isinstance(message, AIMessage):
|
||||
role = "assistant"
|
||||
else:
|
||||
# 默认作为用户消息处理
|
||||
role = "user"
|
||||
|
||||
content = message.content
|
||||
if isinstance(content, list):
|
||||
# 处理多模态内容,目前只提取文本
|
||||
text_content = ""
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
text_content += item.get("text", "")
|
||||
content = text_content
|
||||
|
||||
dashscope_messages.append({
|
||||
"role": role,
|
||||
"content": str(content)
|
||||
})
|
||||
|
||||
return dashscope_messages
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""生成聊天回复"""
|
||||
|
||||
# 转换消息格式
|
||||
dashscope_messages = self._convert_messages_to_dashscope_format(messages)
|
||||
|
||||
# 准备请求参数
|
||||
request_params = {
|
||||
"model": self.model,
|
||||
"messages": dashscope_messages,
|
||||
"result_format": "message",
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
}
|
||||
|
||||
# 添加停止词
|
||||
if stop:
|
||||
request_params["stop"] = stop
|
||||
|
||||
# 合并额外参数
|
||||
request_params.update(kwargs)
|
||||
|
||||
try:
|
||||
# 调用 DashScope API
|
||||
response = Generation.call(**request_params)
|
||||
|
||||
if response.status_code == 200:
|
||||
# 解析响应
|
||||
output = response.output
|
||||
message_content = output.choices[0].message.content
|
||||
|
||||
# 提取token使用量信息
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
# DashScope API响应中包含usage信息
|
||||
if hasattr(response, 'usage') and response.usage:
|
||||
usage = response.usage
|
||||
# 根据API文档,usage可能包含input_tokens和output_tokens
|
||||
if hasattr(usage, 'input_tokens'):
|
||||
input_tokens = usage.input_tokens
|
||||
if hasattr(usage, 'output_tokens'):
|
||||
output_tokens = usage.output_tokens
|
||||
# 有些情况下可能是total_tokens
|
||||
elif hasattr(usage, 'total_tokens'):
|
||||
# 估算输入和输出token(如果没有分别提供)
|
||||
total_tokens = usage.total_tokens
|
||||
# 简单估算:假设输入占30%,输出占70%
|
||||
input_tokens = int(total_tokens * 0.3)
|
||||
output_tokens = int(total_tokens * 0.7)
|
||||
|
||||
# 记录token使用量
|
||||
if input_tokens > 0 or output_tokens > 0:
|
||||
try:
|
||||
# 生成会话ID(如果没有提供)
|
||||
session_id = kwargs.get('session_id', f"dashscope_{hash(str(messages))%10000}")
|
||||
analysis_type = kwargs.get('analysis_type', 'stock_analysis')
|
||||
|
||||
# 使用TokenTracker记录使用量
|
||||
token_tracker.track_usage(
|
||||
provider="dashscope",
|
||||
model_name=self.model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
session_id=session_id,
|
||||
analysis_type=analysis_type
|
||||
)
|
||||
except Exception as track_error:
|
||||
# 记录失败不应该影响主要功能
|
||||
print(f"Token tracking failed: {track_error}")
|
||||
|
||||
# 创建 AI 消息
|
||||
ai_message = AIMessage(content=message_content)
|
||||
|
||||
# 创建生成结果
|
||||
generation = ChatGeneration(message=ai_message)
|
||||
|
||||
return ChatResult(generations=[generation])
|
||||
else:
|
||||
raise Exception(f"DashScope API error: {response.code} - {response.message}")
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Error calling DashScope API: {str(e)}")
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""异步生成聊天回复"""
|
||||
# 目前使用同步方法,后续可以实现真正的异步
|
||||
return self._generate(messages, stop, run_manager, **kwargs)
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], type, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> "ChatDashScope":
|
||||
"""绑定工具到模型"""
|
||||
# 注意:DashScope 目前不直接支持工具调用
|
||||
# 这里我们返回一个新的实例,但实际上工具调用需要在应用层处理
|
||||
formatted_tools = []
|
||||
for tool in tools:
|
||||
if hasattr(tool, "name") and hasattr(tool, "description"):
|
||||
# 这是一个 BaseTool 实例
|
||||
formatted_tools.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": getattr(tool, "args_schema", {})
|
||||
})
|
||||
elif isinstance(tool, dict):
|
||||
formatted_tools.append(tool)
|
||||
else:
|
||||
# 尝试转换为 OpenAI 工具格式
|
||||
try:
|
||||
formatted_tools.append(convert_to_openai_tool(tool))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 创建新实例,保存工具信息
|
||||
new_instance = self.__class__(
|
||||
model=self.model,
|
||||
api_key=self.api_key,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
top_p=self.top_p,
|
||||
**kwargs
|
||||
)
|
||||
new_instance._tools = formatted_tools
|
||||
return new_instance
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""返回标识参数"""
|
||||
return {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
}
|
||||
|
||||
|
||||
# 支持的模型列表
|
||||
DASHSCOPE_MODELS = {
|
||||
# 通义千问系列
|
||||
"qwen-turbo": {
|
||||
"description": "通义千问 Turbo - 快速响应,适合日常对话",
|
||||
"context_length": 8192,
|
||||
"recommended_for": ["快速任务", "日常对话", "简单分析"]
|
||||
},
|
||||
"qwen-plus": {
|
||||
"description": "通义千问 Plus - 平衡性能和成本",
|
||||
"context_length": 32768,
|
||||
"recommended_for": ["复杂分析", "专业任务", "深度思考"]
|
||||
},
|
||||
"qwen-max": {
|
||||
"description": "通义千问 Max - 最强性能",
|
||||
"context_length": 32768,
|
||||
"recommended_for": ["最复杂任务", "专业分析", "高质量输出"]
|
||||
},
|
||||
"qwen-max-longcontext": {
|
||||
"description": "通义千问 Max 长文本版 - 支持超长上下文",
|
||||
"context_length": 1000000,
|
||||
"recommended_for": ["长文档分析", "大量数据处理", "复杂推理"]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, Dict[str, Any]]:
|
||||
"""获取可用的 DashScope 模型列表"""
|
||||
return DASHSCOPE_MODELS
|
||||
|
||||
|
||||
def create_dashscope_llm(
|
||||
model: str = "qwen-plus",
|
||||
api_key: Optional[str] = None,
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> ChatDashScope:
|
||||
"""创建 DashScope LLM 实例的便捷函数"""
|
||||
|
||||
return ChatDashScope(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
**kwargs
|
||||
)
|
||||
Loading…
Reference in New Issue