TradingAgents/tradingagents/dataflows/tushare_utils.py

588 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Tushare数据获取工具
支持A股实时数据和历史数据替换tushareAPI
Tushare是更稳定和专业的中国金融数据接口
"""
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Tuple
import warnings
import os
# Suppress every Python warning process-wide (this affects all modules,
# not only this one).
warnings.filterwarnings('ignore')

# Optional MongoDB-backed cache manager.
try:
    from tradingagents.config.database_manager import get_database_manager
    DB_MANAGER_AVAILABLE = True
except ImportError:
    DB_MANAGER_AVAILABLE = False
    print("⚠️ 数据库缓存管理器不可用,尝试文件缓存")

# Optional file cache. It is probed independently of the database manager so
# that FILE_CACHE_AVAILABLE is always defined: get_china_stock_data() reads
# it unconditionally, and also uses the file cache as a backup when MongoDB
# is available.
try:
    from .cache_manager import get_cache
    FILE_CACHE_AVAILABLE = True
except ImportError:
    FILE_CACHE_AVAILABLE = False
    print("⚠️ 文件缓存管理器不可用将直接从API获取数据")

# The tushare package itself; TushareDataProvider refuses to start without it.
try:
    import tushare as ts
    TUSHARE_AVAILABLE = True
except ImportError:
    TUSHARE_AVAILABLE = False
    print("⚠️ tushare库未安装无法使用Tushare API")
    print("💡 安装命令: pip install tushare")
# Process-wide memo of stock code -> display name, populated by
# TushareDataProvider._get_stock_name().
_stock_name_cache: Dict[str, str] = {}

# Built-in code -> name table for frequently used A-shares. It is consulted
# before any API lookup, and searched directly by search_stocks() when no
# Pro API client is available. Insertion order is the search result order.
_common_stock_names: Dict[str, str] = {
    '000001': '平安银行',
    '000002': '万科A',
    '000858': '五粮液',
    '000651': '格力电器',
    '000333': '美的集团',
    '600036': '招商银行',
    '600519': '贵州茅台',
    '601318': '中国平安',
    '600028': '中国石化',
    '601398': '工商银行',
    '600000': '浦发银行',
    '000725': '京东方A',
    '002415': '海康威视',
    '300059': '东方财富',
    '688001': '华兴源创',
    '688036': '传音控股',
}

# Lazily created singleton handed out by get_tushare_provider().
_tushare_provider: Optional["TushareDataProvider"] = None
class TushareDataProvider:
    """China A-share data provider backed by the ``tushare`` package.

    When a ``TUSHARE_TOKEN`` is available, a Pro API client is stored in
    ``self.pro`` and used for every request. Otherwise ``self.pro`` stays
    ``None`` and the legacy module-level functions of ``tushare``
    (``ts.get_hist_data`` / ``ts.get_realtime_quotes``) are used instead.
    """

    def __init__(self):
        """Verify ``tushare`` is importable, resolve the token, and connect.

        Raises:
            ImportError: if the ``tushare`` package is not installed.
        """
        print("🔍 [DEBUG] 初始化Tushare数据提供器...")
        self.pro = None
        self.connected = False
        self.token = None

        print(f"🔍 [DEBUG] 检查tushare库可用性: {TUSHARE_AVAILABLE}")
        if not TUSHARE_AVAILABLE:
            error_msg = "tushare库未安装请运行: pip install tushare"
            print(f"❌ [DEBUG] {error_msg}")
            raise ImportError(error_msg)
        print("✅ [DEBUG] tushare库检查通过")

        self.token = self._get_tushare_token()
        if not self.token:
            print("⚠️ [DEBUG] Tushare token未配置将使用免费接口有限制")

        self.connect()

    def _get_tushare_token(self) -> Optional[str]:
        """Return ``TUSHARE_TOKEN`` from the environment or a ``.env`` file, else None."""
        token = os.getenv('TUSHARE_TOKEN')
        if token:
            return token
        try:
            from dotenv import load_dotenv
        except ImportError:
            # python-dotenv is optional; without it only the real environment counts.
            return None
        load_dotenv()
        return os.getenv('TUSHARE_TOKEN') or None

    def connect(self) -> bool:
        """Create the Pro API client when a token is set and return the connection flag.

        Without a token no client is created. The provider then runs in
        free-interface mode, which is still reported as connected.
        """
        try:
            if self.token:
                ts.set_token(self.token)
                self.pro = ts.pro_api()
                print("✅ [DEBUG] Tushare Pro API连接成功")
            else:
                print("🔍 [DEBUG] 使用Tushare免费接口")
            self.connected = True
            return True
        except Exception as e:
            print(f"❌ [DEBUG] Tushare连接失败: {e}")
            self.connected = False
            return False

    def is_connected(self) -> bool:
        """Return True if the last ``connect()`` call succeeded."""
        return self.connected

    def _format_stock_code(self, stock_code: str) -> str:
        """Convert a 6-digit code to Tushare's ``ts_code`` form, e.g. ``600519`` -> ``600519.SH``.

        Codes that are not exactly 6 characters long (for example, ones that
        already carry a suffix) are returned unchanged.
        """
        if len(stock_code) != 6:
            return stock_code
        if stock_code.startswith('6'):
            return f"{stock_code}.SH"  # Shanghai
        if stock_code.startswith(('4', '8')):
            return f"{stock_code}.BJ"  # Beijing Stock Exchange
        return f"{stock_code}.SZ"  # Shenzhen (0xxxxx / 3xxxxx) and default

    def _get_stock_name(self, stock_code: str) -> str:
        """Resolve a display name via the memo cache, the built-in table, then the Pro API.

        Falls back to ``'股票<code>'``. That placeholder is memoized only
        when the lookup completed without error, so a transient API failure
        does not stick for the lifetime of the process.
        """
        if stock_code in _stock_name_cache:
            return _stock_name_cache[stock_code]

        name = _common_stock_names.get(stock_code)
        if name is None and self.pro:
            try:
                df = self.pro.stock_basic(ts_code=self._format_stock_code(stock_code),
                                          fields='ts_code,name')
                if df is not None and not df.empty:
                    name = df.iloc[0]['name']
            except Exception as e:
                print(f"⚠️ 从Tushare获取股票名称失败: {e}")
                return f'股票{stock_code}'

        if name is None:
            name = f'股票{stock_code}'
        _stock_name_cache[stock_code] = name
        return name

    def get_stock_history_data(self, stock_code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Return daily bars between ``start_date`` and ``end_date``, sorted by ``trade_date``.

        Args:
            stock_code: 6-digit code, e.g. ``'000001'``.
            start_date: ``'YYYY-MM-DD'``.
            end_date: ``'YYYY-MM-DD'``.

        Returns:
            The bars as a DataFrame. Callers in this module read the
            ``trade_date``, ``open``, ``high``, ``low``, ``close`` and
            ``vol`` columns. An empty DataFrame is returned when nothing was
            fetched or the request failed.
        """
        try:
            if self.pro:
                # The Pro API expects compact YYYYMMDD dates.
                df = self.pro.daily(ts_code=self._format_stock_code(stock_code),
                                    start_date=start_date.replace('-', ''),
                                    end_date=end_date.replace('-', ''))
            else:
                df = ts.get_hist_data(stock_code, start=start_date, end=end_date)
                if df is not None and not df.empty:
                    # In the legacy interface the dates are in the index and
                    # volume is named 'volume'; align with the Pro column names.
                    df = df.reset_index()
                    df['trade_date'] = df['date']
                    df = df.rename(columns={'volume': 'vol'})

            if df is None or df.empty:
                print(f"⚠️ 未获取到股票 {stock_code} 的历史数据")
                return pd.DataFrame()

            return df.sort_values('trade_date')
        except Exception as e:
            print(f"❌ 获取历史数据失败: {e}")
            return pd.DataFrame()

    @staticmethod
    def _parse_realtime_quote(row) -> Dict:
        """Convert one ``ts.get_realtime_quotes`` row (string fields) into a quote dict.

        The legacy quote feed provides ``pre_close`` rather than ready-made
        change columns, so the absolute and percentage change are derived
        from it. The percentage is 0.0 when ``pre_close`` is 0.
        """
        price = float(row['price'])
        pre_close = float(row['pre_close'])
        change = price - pre_close
        change_percent = change / pre_close * 100 if pre_close else 0.0
        return {
            'price': price,
            'change': change,
            'change_percent': change_percent,
            'volume': float(row['volume']),
            'turnover': float(row['amount']),
            'high': float(row['high']),
            'low': float(row['low']),
            'open': float(row['open']),
        }

    def get_real_time_data(self, stock_code: str) -> Optional[Dict]:
        """Return the latest quote as a dict, or None if it is unavailable.

        In Pro mode this is the first row of ``pro.daily(..., limit=1)``,
        intended to be the latest trading day's bar, so it is not intraday
        data. In free mode it parses ``ts.get_realtime_quotes``.
        """
        try:
            if self.pro:
                df = self.pro.daily(ts_code=self._format_stock_code(stock_code),
                                    trade_date='', limit=1)
                if df is not None and not df.empty:
                    row = df.iloc[0]
                    return {
                        'price': row['close'],
                        'change': row.get('change', 0),
                        'change_percent': row.get('pct_chg', 0),
                        'volume': row['vol'],
                        'turnover': row.get('amount', 0),
                        'high': row['high'],
                        'low': row['low'],
                        'open': row['open'],
                    }
            else:
                df = ts.get_realtime_quotes(stock_code)
                if df is not None and not df.empty:
                    return self._parse_realtime_quote(df.iloc[0])
        except Exception as e:
            print(f"⚠️ 获取实时数据失败: {e}")
        return None

    @staticmethod
    def _moving_average(closes: pd.Series, window: int) -> float:
        """Return the latest ``window``-period simple moving average.

        Falls back to the last close when fewer than ``window`` points exist.
        """
        if len(closes) < window:
            return float(closes.iloc[-1])
        return float(closes.rolling(window).mean().iloc[-1])

    @staticmethod
    def _calculate_rsi(prices: pd.Series, period: int = 14) -> float:
        """Return the latest RSI, using simple moving averages over ``period`` price changes.

        Returns the neutral value 50.0 when RSI is undefined: too little
        history, or no price movement at all in the window (0/0).
        """
        delta = prices.diff()
        gain = delta.clip(lower=0).rolling(window=period).mean()
        loss = (-delta).clip(lower=0).rolling(window=period).mean()
        # loss == 0 with gain > 0 gives gain/loss == inf, hence RSI == 100.
        rsi = 100 - 100 / (1 + gain / loss)
        if rsi.empty or pd.isna(rsi.iloc[-1]):
            return 50.0
        return float(rsi.iloc[-1])

    def get_stock_technical_indicators(self, stock_code: str, lookback_days: int = 60) -> Dict:
        """Compute MA5, MA10, MA20 and RSI(14) from recent daily closes.

        Args:
            stock_code: 6-digit code.
            lookback_days: Calendar days of history to fetch. The default of
                60 leaves enough trading days for MA20 and RSI(14). A 30-day
                window can fall short of 20 trading days around market
                holidays.

        Returns:
            A dict with keys ``'MA5'``, ``'MA10'``, ``'MA20'`` and ``'RSI'``,
            each rounded to 2 decimals. Returns ``{}`` when no history is
            available or the computation fails.
        """
        try:
            today = datetime.now()
            end_date = today.strftime('%Y-%m-%d')
            start_date = (today - timedelta(days=lookback_days)).strftime('%Y-%m-%d')
            df = self.get_stock_history_data(stock_code, start_date, end_date)
            if df.empty:
                return {}

            closes = df['close'].astype(float)
            return {
                'MA5': round(self._moving_average(closes, 5), 2),
                'MA10': round(self._moving_average(closes, 10), 2),
                'MA20': round(self._moving_average(closes, 20), 2),
                'RSI': round(self._calculate_rsi(closes), 2),
            }
        except Exception as e:
            print(f"⚠️ 计算技术指标失败: {e}")
            return {}

    def _search_result(self, code: str, name: str) -> Dict:
        """Build one search hit, attaching the current price and change (0 when unavailable)."""
        realtime_data = self.get_real_time_data(code) or {}
        return {
            'code': code,
            'name': name,
            'price': realtime_data.get('price', 0),
            'change_percent': realtime_data.get('change_percent', 0),
        }

    def search_stocks(self, keyword: str) -> List[Dict]:
        """Search stocks whose name or code contains ``keyword``.

        In Pro mode the full ``stock_basic`` listing is filtered and at most
        10 hits are returned. Otherwise the built-in name table is searched.
        Each hit triggers one real-time quote request.
        """
        try:
            if self.pro:
                df = self.pro.stock_basic(fields='ts_code,symbol,name')
                if df is None or df.empty:
                    return []
                # Literal matching: names such as "*ST..." contain regex metacharacters.
                mask = (df['name'].str.contains(keyword, na=False, regex=False)
                        | df['symbol'].str.contains(keyword, na=False, regex=False))
                candidates = [(row['symbol'], row['name'])
                              for _, row in df[mask].head(10).iterrows()]
            else:
                lowered = keyword.lower()
                candidates = [(code, name) for code, name in _common_stock_names.items()
                              if lowered in name.lower() or keyword in code]
            return [self._search_result(code, name) for code, name in candidates]
        except Exception as e:
            print(f"搜索股票失败: {e}")
            return []
def get_tushare_provider() -> TushareDataProvider:
    """Return the shared TushareDataProvider, building it on demand.

    A new instance is created on first use. An existing instance that
    reports ``is_connected() == False`` is replaced by a fresh one.

    Raises:
        ImportError: propagated from TushareDataProvider when tushare is missing.
    """
    global _tushare_provider

    if _tushare_provider is not None:
        print(f"🔍 [DEBUG] 使用现有的Tushare数据提供器实例")
        if _tushare_provider.is_connected():
            return _tushare_provider
        # The cached provider lost (or never had) a connection: rebuild it.
        print(f"🔍 [DEBUG] 检测到连接断开重新创建Tushare数据提供器...")
        _tushare_provider = TushareDataProvider()
        print(f"🔍 [DEBUG] Tushare数据提供器重新创建完成")
        return _tushare_provider

    print(f"🔍 [DEBUG] 创建新的Tushare数据提供器实例...")
    _tushare_provider = TushareDataProvider()
    print(f"🔍 [DEBUG] Tushare数据提供器实例创建完成")
    return _tushare_provider
# Cache lookup in MongoDB: fresh (< 6 h) report for exactly this date range.
def _load_china_stock_from_mongodb(stock_code: str, start_date: str, end_date: str) -> Optional[str]:
    """Return a MongoDB-cached report for ``stock_code`` and this exact date range, or None.

    Any failure (database manager missing, MongoDB unreachable, query error)
    is logged and treated as a cache miss.
    """
    try:
        from tradingagents.config.database_manager import get_database_manager
        db_manager = get_database_manager()
        if not db_manager.is_mongodb_available():
            return None
        mongodb_client = db_manager.get_mongodb_client()
        if not mongodb_client:
            return None
        collection = mongodb_client[db_manager.mongodb_config["database"]].stock_data
        # Naive UTC timestamps, matching what _save_china_stock_data() writes.
        cutoff_time = datetime.utcnow() - timedelta(hours=6)
        # Saved documents record their date range under "metadata". Matching
        # it prevents returning a report generated for a different window.
        cached_doc = collection.find_one({
            "symbol": stock_code,
            "market_type": "china",
            "metadata.start_date": start_date,
            "metadata.end_date": end_date,
            "created_at": {"$gte": cutoff_time},
        }, sort=[("created_at", -1)])
        if cached_doc and 'data' in cached_doc:
            print(f"🗄️ 从MongoDB缓存加载数据: {stock_code}")
            return cached_doc['data']
    except Exception as e:
        print(f"⚠️ 从MongoDB加载缓存失败: {e}")
    return None


# Cache lookup in the file cache; errors are logged and treated as a miss.
def _load_china_stock_from_file_cache(stock_code: str, start_date: str, end_date: str) -> Optional[str]:
    """Return a file-cached report (requested max age 6 h), or None on miss or error."""
    if not FILE_CACHE_AVAILABLE:
        return None
    try:
        cache = get_cache()
        try:
            cache_key = cache.find_cached_stock_data(
                symbol=stock_code,
                start_date=start_date,
                end_date=end_date,
                data_source="tushare",
                max_age_hours=6,
            )
        except TypeError:
            # Older cache managers do not accept max_age_hours.
            print("⚠️ 缓存管理器不支持max_age_hours参数使用默认缓存策略")
            cache_key = cache.find_cached_stock_data(
                symbol=stock_code,
                start_date=start_date,
                end_date=end_date,
                data_source="tushare",
            )
        if cache_key:
            cached_data = cache.load_stock_data(cache_key)
            if cached_data:
                print(f"💾 从文件缓存加载数据: {stock_code} -> {cache_key}")
                return cached_data
    except Exception as e:
        print(f"⚠️ 从文件缓存加载数据失败: {e}")
    return None


# Render the markdown-style text report from the fetched pieces.
def _format_china_stock_report(stock_code: str, stock_name: str, start_date: str, end_date: str,
                               df: pd.DataFrame, realtime_data: Optional[Dict],
                               indicators: Dict) -> str:
    """Build the report text. Every line ends with a newline and the text starts with a blank line."""
    lines = [
        "",
        f"# {stock_code} ({stock_name}) 股票数据分析",
        "## 基本信息",
        f"- 股票代码: {stock_code}",
        f"- 股票名称: {stock_name}",
        "- 数据源: Tushare API",
        f"- 数据时间: {start_date} 至 {end_date}",
        "## 实时行情 (最新交易日)",
    ]
    if realtime_data:
        lines += [
            f"- 最新价格: ¥{realtime_data['price']:.2f}",
            f"- 涨跌幅: {realtime_data['change_percent']:.2f}%",
            f"- 成交量: {realtime_data['volume']:,.0f}",
            f"- 最高价: ¥{realtime_data['high']:.2f}",
            f"- 最低价: ¥{realtime_data['low']:.2f}",
            f"- 开盘价: ¥{realtime_data['open']:.2f}",
        ]
    else:
        lines.append("- 实时数据获取失败")

    if indicators:
        lines += [
            "",
            "## 技术指标",
            f"- MA5: ¥{indicators.get('MA5', 'N/A')}",
            f"- MA10: ¥{indicators.get('MA10', 'N/A')}",
            f"- MA20: ¥{indicators.get('MA20', 'N/A')}",
            f"- RSI: {indicators.get('RSI', 'N/A')}",
        ]

    lines += [
        "",
        f"## 历史数据统计 ({len(df)}个交易日)",
        f"- 最高价: ¥{df['high'].max():.2f}",
        f"- 最低价: ¥{df['low'].min():.2f}",
        f"- 平均价: ¥{df['close'].mean():.2f}",
        f"- 总成交量: {df['vol'].sum():,.0f}",
        "## 最近5个交易日",
    ]
    lines += [
        f"- {row['trade_date']}: 收盘¥{row['close']:.2f}, 成交量{row['vol']:,.0f}"
        for _, row in df.tail(5).iterrows()
    ]
    return "".join(f"{line}\n" for line in lines)


# Persist a report to MongoDB (one document per symbol) and to the file cache.
def _save_china_stock_data(stock_code: str, start_date: str, end_date: str, result: str,
                           realtime_data: Optional[Dict], indicators: Dict,
                           history_count: int) -> None:
    """Best-effort save of ``result``. Failures are logged and never propagate to the caller."""
    try:
        from tradingagents.config.database_manager import get_database_manager
        db_manager = get_database_manager()
        if db_manager.is_mongodb_available():
            mongodb_client = db_manager.get_mongodb_client()
            if mongodb_client:
                collection = mongodb_client[db_manager.mongodb_config["database"]].stock_data
                now = datetime.utcnow()
                doc = {
                    "symbol": stock_code,
                    "market_type": "china",
                    "data": result,
                    "metadata": {
                        'start_date': start_date,
                        'end_date': end_date,
                        'data_source': 'tushare',
                        'realtime_data': realtime_data,
                        'indicators': indicators,
                        'history_count': history_count,
                    },
                    "created_at": now,
                    "updated_at": now,
                }
                collection.replace_one(
                    {"symbol": stock_code, "market_type": "china"},
                    doc,
                    upsert=True,
                )
                print(f"💾 数据已保存到MongoDB: {stock_code}")
    except Exception as e:
        print(f"⚠️ 保存到MongoDB失败: {e}")

    # The file cache is also written as a backup.
    if FILE_CACHE_AVAILABLE:
        try:
            get_cache().save_stock_data(
                symbol=stock_code,
                data=result,
                start_date=start_date,
                end_date=end_date,
                data_source="tushare",
            )
        except Exception as e:
            print(f"⚠️ 保存到文件缓存失败: {e}")


def get_china_stock_data(stock_code: str, start_date: str, end_date: str) -> str:
    """Return a formatted China A-share report, served from cache when possible.

    Lookup order: MongoDB cache, then file cache, then the Tushare API. A
    freshly fetched report is saved to both caches. Cache problems never
    turn a successful fetch into an error.

    Args:
        stock_code: 6-digit code, e.g. ``'000001'``.
        start_date: Start date, ``'YYYY-MM-DD'``.
        end_date: End date, ``'YYYY-MM-DD'``.

    Returns:
        The report text on success. If no history was found, a short
        ``'❌ …'`` message. If the fetch failed, a diagnostic text that
        includes the traceback.
    """
    print(f"📊 正在获取中国股票数据: {stock_code} ({start_date} 至 {end_date})")

    cached = _load_china_stock_from_mongodb(stock_code, start_date, end_date)
    if cached is not None:
        return cached
    cached = _load_china_stock_from_file_cache(stock_code, start_date, end_date)
    if cached is not None:
        return cached

    print(f"🌐 从Tushare API获取数据: {stock_code}")
    try:
        provider = get_tushare_provider()
        df = provider.get_stock_history_data(stock_code, start_date, end_date)
        if df.empty:
            error_msg = f"❌ 未能获取股票 {stock_code} 的历史数据"
            print(error_msg)
            return error_msg

        realtime_data = provider.get_real_time_data(stock_code)
        indicators = provider.get_stock_technical_indicators(stock_code)
        stock_name = provider._get_stock_name(stock_code)
        result = _format_china_stock_report(stock_code, stock_name, start_date, end_date,
                                            df, realtime_data, indicators)
        print(f"✅ 股票数据获取成功: {stock_code}")
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"❌ [DEBUG] Tushare API调用失败:")
        print(f"❌ [DEBUG] 错误类型: {type(e).__name__}")
        print(f"❌ [DEBUG] 错误信息: {str(e)}")
        print(f"❌ [DEBUG] 详细堆栈:")
        print(error_details)
        error_lines = [
            "",
            f"❌ 中国股票数据获取失败 - {stock_code}",
            f"错误类型: {type(e).__name__}",
            f"错误信息: {str(e)}",
            "🔍 调试信息:",
            error_details,
            "💡 解决建议:",
            "1. 检查tushare库是否已安装: pip install tushare",
            "2. 确认股票代码格式正确 (如: 000001, 600519)",
            "3. 检查网络连接是否正常",
            "4. 配置Tushare token以获得更好的服务: TUSHARE_TOKEN=your_token",
            "5. 检查Tushare API服务状态",
            "📚 Tushare文档: https://tushare.pro/",
        ]
        return "".join(f"{line}\n" for line in error_lines)

    _save_china_stock_data(stock_code, start_date, end_date, result,
                           realtime_data, indicators, len(df))
    return result
def get_china_stock_data_enhanced(stock_code: str, start_date: str, end_date: str) -> str:
    """Fetch a China stock report through the stock data service, degrading to Tushare.

    Delegates to ``get_stock_data_service().get_stock_data_with_fallback``.
    If the service module cannot be imported, or the service raises, the
    plain ``get_china_stock_data`` path is used instead.

    Args:
        stock_code: 6-digit code, e.g. ``'000001'``.
        start_date: Start date, ``'YYYY-MM-DD'``.
        end_date: End date, ``'YYYY-MM-DD'``.

    Returns:
        The formatted report string.
    """
    try:
        from .stock_data_service import get_stock_data_service
        return get_stock_data_service().get_stock_data_with_fallback(
            stock_code, start_date, end_date
        )
    except ImportError:
        print("⚠️ 增强服务不可用,使用原有函数")
    except Exception as err:
        print(f"⚠️ 增强服务出错,降级到原有函数: {err}")
    # Both failure paths end up on the original implementation.
    return get_china_stock_data(stock_code, start_date, end_date)