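#!/usr/bin/env python3
"""
Tushare data access utilities.

Supports real-time and historical data for A-share stocks through the
Tushare API (per the original note, this module replaces an earlier data API).
Tushare is a more stable and professional interface for Chinese financial data.
"""
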
import os
import warnings
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Tuple

import numpy as np
import pandas as pd

warnings.filterwarnings('ignore')

# Optional dependency: database cache manager.
try:
    from tradingagents.config.database_manager import get_database_manager
    DB_MANAGER_AVAILABLE = True
except ImportError:
    DB_MANAGER_AVAILABLE = False
    print("⚠️ 数据库缓存管理器不可用,尝试文件缓存")

# Optional dependency: file cache manager.
try:
    from .cache_manager import get_cache
    FILE_CACHE_AVAILABLE = True
except ImportError:
    FILE_CACHE_AVAILABLE = False
    print("⚠️ 文件缓存管理器不可用,将直接从API获取数据")

# Optional dependency: the tushare package itself. TushareDataProvider refuses
# to initialise when this flag is False.
try:
    import tushare as ts
    TUSHARE_AVAILABLE = True
except ImportError:
    TUSHARE_AVAILABLE = False
    print("⚠️ tushare库未安装,无法使用Tushare API")
    print("💡 安装命令: pip install tushare")

# Cache of stock names already resolved, keyed by 6-digit stock code.
_stock_name_cache = {}

# Fallback mapping of common stock codes to their names.
_common_stock_names = {
    '000001': '平安银行',
    '000002': '万科A',
    '000858': '五粮液',
    '000651': '格力电器',
    '000333': '美的集团',
    '600036': '招商银行',
    '600519': '贵州茅台',
    '601318': '中国平安',
    '600028': '中国石化',
    '601398': '工商银行',
    '600000': '浦发银行',
    '000725': '京东方A',
    '002415': '海康威视',
    '300059': '东方财富',
    '688001': '华兴源创',
    '688036': '传音控股'
}

# Module-wide TushareDataProvider instance, created lazily by
# get_tushare_provider().
_tushare_provider = None


class TushareDataProvider:
    """Data provider backed by Tushare.

    When a token is available (``TUSHARE_TOKEN``) the Tushare Pro API is used;
    otherwise the legacy free ``tushare`` functions are called. Every public
    data method is best-effort: failures are printed and an empty result
    (empty DataFrame, ``None``, ``{}`` or ``[]``) is returned instead of raising.
    """

    def __init__(self):
        """Verify that tushare is importable, resolve the token and connect.

        Raises:
            ImportError: if the ``tushare`` package is not installed.
        """
        print("🔍 [DEBUG] 初始化Tushare数据提供器...")
        self.pro = None
        self.connected = False
        self.token = None

        print(f"🔍 [DEBUG] 检查tushare库可用性: {TUSHARE_AVAILABLE}")
        if not TUSHARE_AVAILABLE:
            error_msg = "tushare库未安装,请运行: pip install tushare"
            print(f"❌ [DEBUG] {error_msg}")
            raise ImportError(error_msg)
        print("✅ [DEBUG] tushare库检查通过")

        # Resolve the Tushare token; without one only the free interface is used.
        self.token = self._get_tushare_token()
        if not self.token:
            print("⚠️ [DEBUG] Tushare token未配置,将使用免费接口(有限制)")

        self.connect()

    def _get_tushare_token(self) -> Optional[str]:
        """Return the Tushare token from the environment or a .env file, else None."""
        token = os.getenv('TUSHARE_TOKEN')
        if token:
            return token

        # Second chance: load a .env file (only if python-dotenv is installed).
        try:
            from dotenv import load_dotenv
        except ImportError:
            return None
        load_dotenv()
        return os.getenv('TUSHARE_TOKEN') or None

    def connect(self) -> bool:
        """Create the Pro API client when a token is set.

        Returns:
            True on success (including the token-less free mode), False if
            creating the client raised.
        """
        try:
            if self.token:
                ts.set_token(self.token)
                self.pro = ts.pro_api()
                print("✅ [DEBUG] Tushare Pro API连接成功")
            else:
                # Free interface: nothing to set up.
                print("🔍 [DEBUG] 使用Tushare免费接口")

            self.connected = True
            return True
        except Exception as e:
            print(f"❌ [DEBUG] Tushare连接失败: {e}")
            self.connected = False
            return False

    def is_connected(self) -> bool:
        """Return the connection flag set by ``connect``."""
        return self.connected

    def _format_stock_code(self, stock_code: str) -> str:
        """Convert a bare 6-character code into Tushare ``ts_code`` form.

        Anything that is not exactly 6 characters long (for example a code that
        already carries a suffix) is returned unchanged.
        """
        if len(stock_code) != 6:
            return stock_code

        # Shanghai: main board / STAR (6xxxxx), funds (5xxxxx), B shares (900xxx).
        if stock_code.startswith(('6', '5', '900')):
            return f"{stock_code}.SH"
        # Beijing Stock Exchange: 4xxxxx, 8xxxxx and the newer 92xxxx range.
        if stock_code.startswith(('4', '8', '92')):
            return f"{stock_code}.BJ"
        # Everything else (0xxxxx, 2xxxxx, 3xxxxx, ...) defaults to Shenzhen.
        return f"{stock_code}.SZ"

    def _get_stock_name(self, stock_code: str) -> str:
        """Resolve a stock name via cache, built-in table, then the Pro API.

        Falls back to ``'股票<code>'`` when nothing is found.
        """
        if stock_code in _stock_name_cache:
            return _stock_name_cache[stock_code]

        if stock_code in _common_stock_names:
            name = _common_stock_names[stock_code]
            _stock_name_cache[stock_code] = name
            return name

        lookup_failed = False
        try:
            if self.pro:
                ts_code = self._format_stock_code(stock_code)
                df = self.pro.stock_basic(ts_code=ts_code, fields='ts_code,name')
                if not df.empty:
                    name = df.iloc[0]['name']
                    _stock_name_cache[stock_code] = name
                    return name
        except Exception as e:
            lookup_failed = True
            print(f"⚠️ 从Tushare获取股票名称失败: {e}")

        default_name = f'股票{stock_code}'
        # Do not pin the placeholder after a failed API call, so that a later
        # call can still resolve the real name.
        if not lookup_failed:
            _stock_name_cache[stock_code] = default_name
        return default_name

    def get_stock_history_data(self, stock_code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetch daily bars for ``stock_code`` between two 'YYYY-MM-DD' dates.

        Returns:
            A DataFrame sorted ascending by ``trade_date`` with at least the
            ``trade_date`` and ``vol`` columns, or an empty DataFrame when no
            data is available or the request failed.
        """
        try:
            if self.pro:
                # The Pro API is called with dates in 'YYYYMMDD' form.
                df = self.pro.daily(
                    ts_code=self._format_stock_code(stock_code),
                    start_date=start_date.replace('-', ''),
                    end_date=end_date.replace('-', ''),
                )
            else:
                # Free interface: normalise its columns to the Pro naming.
                df = ts.get_hist_data(stock_code, start=start_date, end=end_date)
                if df is not None:
                    df = df.reset_index()
                    df['trade_date'] = df['date']
                    df = df.rename(columns={'volume': 'vol'})

            if df is None or df.empty:
                print(f"⚠️ 未获取到股票 {stock_code} 的历史数据")
                return pd.DataFrame()

            return df.sort_values('trade_date')
        except Exception as e:
            print(f"❌ 获取历史数据失败: {e}")
            return pd.DataFrame()

    def get_real_time_data(self, stock_code: str) -> Optional[Dict]:
        """Return the latest quote as a dict of plain Python floats, or None.

        Keys: price, change, change_percent, volume, turnover, high, low, open.
        With the Pro API this is the most recent daily bar rather than a live tick.
        """
        try:
            if self.pro:
                ts_code = self._format_stock_code(stock_code)
                df = self.pro.daily(ts_code=ts_code, trade_date='', limit=1)
                if df is not None and not df.empty:
                    row = df.iloc[0]
                    # float() turns numpy scalars into built-in floats so the
                    # dict can be stored as cache metadata (BSON cannot encode
                    # numpy types).
                    return {
                        'price': float(row['close']),
                        'change': float(row['change']) if 'change' in row else 0,
                        'change_percent': float(row['pct_chg']) if 'pct_chg' in row else 0,
                        'volume': float(row['vol']),
                        'turnover': float(row['amount']) if 'amount' in row else 0,
                        'high': float(row['high']),
                        'low': float(row['low']),
                        'open': float(row['open'])
                    }
            else:
                df = ts.get_realtime_quotes(stock_code)
                if df is not None and not df.empty:
                    row = df.iloc[0]
                    price = float(row['price'])
                    pre_close = float(row['pre_close']) if 'pre_close' in row else 0.0
                    # NOTE(review): get_realtime_quotes appears to return
                    # price/pre_close but no 'change'/'changepercent' columns;
                    # derive them when absent instead of failing on a KeyError.
                    if 'change' in row:
                        change = float(row['change'])
                    else:
                        change = price - pre_close if pre_close else 0.0
                    if 'changepercent' in row:
                        change_percent = float(row['changepercent'])
                    else:
                        change_percent = change / pre_close * 100 if pre_close else 0.0
                    return {
                        'price': price,
                        'change': change,
                        'change_percent': change_percent,
                        'volume': float(row['volume']),
                        'turnover': float(row['amount']) if 'amount' in row else 0,
                        'high': float(row['high']),
                        'low': float(row['low']),
                        'open': float(row['open'])
                    }
        except Exception as e:
            print(f"⚠️ 获取实时数据失败: {e}")

        return None

    @staticmethod
    def _calculate_rsi(prices, period=14) -> float:
        """Return the simple-moving-average RSI of the last ``period`` changes.

        Returns the neutral value 50 when the series is empty, too short, or
        completely flat (cases where the ratio is undefined).
        """
        delta = prices.diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
        rsi = 100 - (100 / (1 + gain / loss))
        if rsi.empty:
            return 50.0
        latest = rsi.iloc[-1]
        return 50.0 if pd.isna(latest) else float(latest)

    def get_stock_technical_indicators(self, stock_code: str) -> Dict:
        """Compute MA5/MA10/MA20 and RSI(14) from recent daily closes.

        Returns:
            ``{'MA5', 'MA10', 'MA20', 'RSI'}`` rounded to 2 decimals, or ``{}``
            when no history is available or the computation failed. A moving
            average with too little history falls back to the latest close.
        """
        try:
            # 60 calendar days: 30 could leave fewer than 20 trading days
            # around long holidays, silently degrading MA20.
            now = datetime.now()
            end_date = now.strftime('%Y-%m-%d')
            start_date = (now - timedelta(days=60)).strftime('%Y-%m-%d')

            df = self.get_stock_history_data(stock_code, start_date, end_date)
            if df.empty:
                return {}

            closes = df['close'].astype(float)
            latest_close = closes.iloc[-1]

            def moving_average(window: int) -> float:
                if len(closes) >= window:
                    return closes.rolling(window).mean().iloc[-1]
                return latest_close

            return {
                'MA5': round(float(moving_average(5)), 2),
                'MA10': round(float(moving_average(10)), 2),
                'MA20': round(float(moving_average(20)), 2),
                'RSI': round(self._calculate_rsi(closes), 2)
            }
        except Exception as e:
            print(f"⚠️ 计算技术指标失败: {e}")
            return {}

    def search_stocks(self, keyword: str) -> List[Dict]:
        """Search stocks whose name or code contains ``keyword``.

        Returns:
            A list of ``{'code', 'name', 'price', 'change_percent'}`` dicts
            (at most 10 when the Pro API is used), or ``[]`` on failure.
        """
        try:
            matches = []  # (code, name) pairs

            if self.pro:
                df = self.pro.stock_basic(fields='ts_code,symbol,name')
                if not df.empty:
                    # regex=False: keywords such as '*ST' are literal text,
                    # not regular expressions.
                    by_name = df['name'].str.contains(keyword, na=False, regex=False)
                    by_code = df['symbol'].astype(str).str.contains(keyword, na=False, regex=False)
                    for _, row in df[by_name | by_code].head(10).iterrows():
                        matches.append((row['symbol'], row['name']))
            else:
                # Fallback: search the built-in code -> name table.
                for code, name in _common_stock_names.items():
                    if keyword.lower() in name.lower() or keyword in code:
                        matches.append((code, name))

            results = []
            for code, name in matches:
                realtime_data = self.get_real_time_data(code)
                results.append({
                    'code': code,
                    'name': name,
                    'price': realtime_data.get('price', 0) if realtime_data else 0,
                    'change_percent': realtime_data.get('change_percent', 0) if realtime_data else 0
                })
            return results
        except Exception as e:
            print(f"搜索股票失败: {e}")
            return []


def get_tushare_provider() -> TushareDataProvider:
    """Return the module-wide TushareDataProvider, creating it on first use.

    An existing instance that reports it is not connected is replaced by a
    freshly constructed one. Exceptions raised by the constructor propagate
    to the caller.
    """
    global _tushare_provider

    if _tushare_provider is None:
        print("🔍 [DEBUG] 创建新的Tushare数据提供器实例...")
        _tushare_provider = TushareDataProvider()
        print("🔍 [DEBUG] Tushare数据提供器实例创建完成")
        return _tushare_provider

    print("🔍 [DEBUG] 使用现有的Tushare数据提供器实例")
    # Rebuild the provider when its connection flag has been cleared.
    if not _tushare_provider.is_connected():
        print("🔍 [DEBUG] 检测到连接断开,重新创建Tushare数据提供器...")
        _tushare_provider = TushareDataProvider()
        print("🔍 [DEBUG] Tushare数据提供器重新创建完成")
    return _tushare_provider


def get_china_stock_data(stock_code: str, start_date: str, end_date: str) -> str:
    """Return a formatted text report for a Chinese stock, using caches when possible.

    Lookup order: MongoDB cache, file cache, then the Tushare API. A freshly
    built report is written back to both caches on a best-effort basis.

    Args:
        stock_code: Stock code, e.g. '000001'.
        start_date: Start date as 'YYYY-MM-DD'.
        end_date: End date as 'YYYY-MM-DD'.

    Returns:
        str: The formatted report, or a human-readable error message when the
        data could not be fetched (this function does not raise for API errors).
    """
    print(f"📊 正在获取中国股票数据: {stock_code} ({start_date} 到 {end_date})")

    cached_report = _load_report_from_mongodb(stock_code, start_date, end_date)
    if cached_report is not None:
        return cached_report

    cached_report = _load_report_from_file_cache(stock_code, start_date, end_date)
    if cached_report:
        return cached_report

    print(f"🌐 从Tushare API获取数据: {stock_code}")

    try:
        provider = get_tushare_provider()

        df = provider.get_stock_history_data(stock_code, start_date, end_date)
        if df.empty:
            error_msg = f"❌ 未能获取股票 {stock_code} 的历史数据"
            print(error_msg)
            return error_msg

        realtime_data = provider.get_real_time_data(stock_code)
        indicators = provider.get_stock_technical_indicators(stock_code)
        stock_name = provider._get_stock_name(stock_code)

        result = _format_stock_report(
            stock_code, stock_name, start_date, end_date, df, realtime_data, indicators
        )
        print(f"✅ 股票数据获取成功: {stock_code}")

        # Cache writes are best-effort and must never cost us the report.
        _save_report_to_mongodb(
            stock_code, start_date, end_date, result, realtime_data, indicators, len(df)
        )
        _save_report_to_file_cache(stock_code, start_date, end_date, result)

        return result

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print("❌ [DEBUG] Tushare API调用失败:")
        print(f"❌ [DEBUG] 错误类型: {type(e).__name__}")
        print(f"❌ [DEBUG] 错误信息: {str(e)}")
        print("❌ [DEBUG] 详细堆栈:")
        print(error_details)

        return "\n".join([
            "",
            f"❌ 中国股票数据获取失败 - {stock_code}",
            f"错误类型: {type(e).__name__}",
            f"错误信息: {str(e)}",
            "",
            "🔍 调试信息:",
            f"{error_details}",
            "",
            "💡 解决建议:",
            "1. 检查tushare库是否已安装: pip install tushare",
            "2. 确认股票代码格式正确 (如: 000001, 600519)",
            "3. 检查网络连接是否正常",
            "4. 配置Tushare token以获得更好的服务: TUSHARE_TOKEN=your_token",
            "5. 检查Tushare API服务状态",
            "",
            "📚 Tushare文档: https://tushare.pro/",
            "",
        ])


def _load_report_from_mongodb(stock_code: str, start_date: str, end_date: str) -> Optional[str]:
    """Return a report cached in MongoDB within the last 6 hours, else None."""
    try:
        from tradingagents.config.database_manager import get_database_manager
        db_manager = get_database_manager()
        if not db_manager.is_mongodb_available():
            return None
        mongodb_client = db_manager.get_mongodb_client()
        if not mongodb_client:
            return None

        collection = mongodb_client[db_manager.mongodb_config["database"]].stock_data

        # created_at is written with datetime.utcnow(), so compare in naive UTC.
        cutoff_time = datetime.utcnow() - timedelta(hours=6)
        cached_doc = collection.find_one({
            "symbol": stock_code,
            "market_type": "china",
            # The report text embeds the date range, so only reuse a document
            # that was built for exactly the requested period.
            "metadata.start_date": start_date,
            "metadata.end_date": end_date,
            "created_at": {"$gte": cutoff_time}
        }, sort=[("created_at", -1)])

        if cached_doc and 'data' in cached_doc:
            print(f"🗄️ 从MongoDB缓存加载数据: {stock_code}")
            return cached_doc['data']
    except Exception as e:
        print(f"⚠️ 从MongoDB加载缓存失败: {e}")
    return None


def _load_report_from_file_cache(stock_code: str, start_date: str, end_date: str) -> Optional[str]:
    """Return a report from the file cache (max age 6 hours), else None."""
    if not FILE_CACHE_AVAILABLE:
        return None
    try:
        cache = get_cache()
        try:
            cache_key = cache.find_cached_stock_data(
                symbol=stock_code,
                start_date=start_date,
                end_date=end_date,
                data_source="tushare",
                max_age_hours=6
            )
        except TypeError:
            # Cache managers without max_age_hours support: retry without it.
            print("⚠️ 缓存管理器不支持max_age_hours参数,使用默认缓存策略")
            cache_key = cache.find_cached_stock_data(
                symbol=stock_code,
                start_date=start_date,
                end_date=end_date,
                data_source="tushare"
            )

        if cache_key:
            cached_data = cache.load_stock_data(cache_key)
            if cached_data:
                print(f"💾 从文件缓存加载数据: {stock_code} -> {cache_key}")
                return cached_data
    except Exception as e:
        # A broken cache must not prevent fetching fresh data.
        print(f"⚠️ 从文件缓存加载失败: {e}")
    return None


def _format_stock_report(stock_code: str, stock_name: str, start_date: str, end_date: str,
                         df: pd.DataFrame, realtime_data: Optional[Dict], indicators: Dict) -> str:
    """Render the report text from history, latest quote and indicators."""
    lines = [
        "",
        f"# {stock_code} ({stock_name}) 股票数据分析",
        "",
        "## 基本信息",
        f"- 股票代码: {stock_code}",
        f"- 股票名称: {stock_name}",
        "- 数据源: Tushare API",
        f"- 数据时间: {start_date} 到 {end_date}",
        "",
        "## 实时行情 (最新交易日)",
    ]

    if realtime_data:
        lines += [
            f"- 最新价格: ¥{realtime_data['price']:.2f}",
            f"- 涨跌幅: {realtime_data['change_percent']:.2f}%",
            f"- 成交量: {realtime_data['volume']:,.0f}",
            f"- 最高价: ¥{realtime_data['high']:.2f}",
            f"- 最低价: ¥{realtime_data['low']:.2f}",
            f"- 开盘价: ¥{realtime_data['open']:.2f}",
        ]
    else:
        lines.append("- 实时数据获取失败")

    if indicators:
        lines += [
            "",
            "## 技术指标",
            f"- MA5: ¥{indicators.get('MA5', 'N/A')}",
            f"- MA10: ¥{indicators.get('MA10', 'N/A')}",
            f"- MA20: ¥{indicators.get('MA20', 'N/A')}",
            f"- RSI: {indicators.get('RSI', 'N/A')}",
        ]

    lines += [
        "",
        f"## 历史数据统计 ({len(df)}个交易日)",
        f"- 最高价: ¥{df['high'].max():.2f}",
        f"- 最低价: ¥{df['low'].min():.2f}",
        f"- 平均价: ¥{df['close'].mean():.2f}",
        f"- 总成交量: {df['vol'].sum():,.0f}",
        "",
        "## 最近5个交易日",
    ]

    # The history is sorted ascending, so tail(5) is the five latest sessions.
    for _, row in df.tail(5).iterrows():
        lines.append(f"- {row['trade_date']}: 收盘¥{row['close']:.2f}, 成交量{row['vol']:,.0f}")

    return "\n".join(lines) + "\n"


def _save_report_to_mongodb(stock_code: str, start_date: str, end_date: str, report: str,
                            realtime_data: Optional[Dict], indicators: Dict,
                            history_count: int) -> None:
    """Upsert the report into MongoDB; failures are printed, never raised."""
    try:
        from tradingagents.config.database_manager import get_database_manager
        db_manager = get_database_manager()
        if not db_manager.is_mongodb_available():
            return
        mongodb_client = db_manager.get_mongodb_client()
        if not mongodb_client:
            return

        collection = mongodb_client[db_manager.mongodb_config["database"]].stock_data
        now = datetime.utcnow()
        doc = {
            "symbol": stock_code,
            "market_type": "china",
            "data": report,
            "metadata": {
                'start_date': start_date,
                'end_date': end_date,
                'data_source': 'tushare',
                'realtime_data': realtime_data,
                'indicators': indicators,
                'history_count': history_count
            },
            "created_at": now,
            "updated_at": now
        }

        # One document per symbol: a new report replaces the previous one.
        collection.replace_one(
            {"symbol": stock_code, "market_type": "china"},
            doc,
            upsert=True
        )
        print(f"💾 数据已保存到MongoDB: {stock_code}")
    except Exception as e:
        print(f"⚠️ 保存到MongoDB失败: {e}")


def _save_report_to_file_cache(stock_code: str, start_date: str, end_date: str, report: str) -> None:
    """Write the report to the file cache as a backup; failures are printed."""
    if not FILE_CACHE_AVAILABLE:
        return
    try:
        get_cache().save_stock_data(
            symbol=stock_code,
            data=report,
            start_date=start_date,
            end_date=end_date,
            data_source="tushare"
        )
    except Exception as e:
        print(f"⚠️ 保存到文件缓存失败: {e}")


def get_china_stock_data_enhanced(stock_code: str, start_date: str, end_date: str) -> str:
    """Fetch Chinese stock data through the fallback-capable stock data service.

    Delegates to ``get_stock_data_service().get_stock_data_with_fallback``.
    If that service cannot be imported, or raises, the call falls back to
    ``get_china_stock_data`` with the same arguments.

    Args:
        stock_code: Stock code, e.g. '000001'.
        start_date: Start date as 'YYYY-MM-DD'.
        end_date: End date as 'YYYY-MM-DD'.

    Returns:
        str: The formatted stock data.
    """
    try:
        from .stock_data_service import get_stock_data_service
        service = get_stock_data_service()
        return service.get_stock_data_with_fallback(stock_code, start_date, end_date)
    except ImportError:
        # Enhanced service not installed: use the plain implementation.
        print("⚠️ 增强服务不可用,使用原有函数")
        return get_china_stock_data(stock_code, start_date, end_date)
    except Exception as e:
        print(f"⚠️ 增强服务出错,降级到原有函数: {e}")
        return get_china_stock_data(stock_code, start_date, end_date)