#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 统一的股票数据获取服务 实现MongoDB -> Tushare API的完整降级机制 """ import pandas as pd from typing import Dict, List, Optional, Any from datetime import datetime, timedelta import logging try: from tradingagents.config.database_manager import get_database_manager DATABASE_MANAGER_AVAILABLE = True except ImportError: DATABASE_MANAGER_AVAILABLE = False try: from .tushare_utils import get_tushare_provider, TushareDataProvider TUSHARE_AVAILABLE = True except ImportError: TUSHARE_AVAILABLE = False try: import sys import os # 添加utils目录到路径 utils_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'utils') if utils_path not in sys.path: sys.path.append(utils_path) from enhanced_stock_list_fetcher import enhanced_fetch_stock_list ENHANCED_FETCHER_AVAILABLE = True except ImportError: ENHANCED_FETCHER_AVAILABLE = False logger = logging.getLogger(__name__) class StockDataService: """ 统一的股票数据获取服务 实现完整的降级机制:MongoDB -> Tushare API -> 缓存 -> 错误处理 """ def __init__(self): self.db_manager = None self._init_services() def _init_services(self): """初始化服务""" # 尝试初始化数据库管理器 if DATABASE_MANAGER_AVAILABLE: try: self.db_manager = get_database_manager() if self.db_manager.is_mongodb_available(): print("✅ MongoDB连接成功") else: print("⚠️ MongoDB连接失败,将使用Tushare API") except Exception as e: print(f"⚠️ 数据库管理器初始化失败: {e}") self.db_manager = None def get_stock_basic_info(self, stock_code: str = None) -> Optional[Dict[str, Any]]: """ 获取股票基础信息(单个股票或全部股票) Args: stock_code: 股票代码,如果为None则返回所有股票 Returns: Dict: 股票基础信息 """ print(f"📊 获取股票基础信息: {stock_code or '全部股票'}") # 1. 优先从MongoDB获取 if self.db_manager and self.db_manager.is_mongodb_available(): try: result = self._get_from_mongodb(stock_code) if result: print(f"✅ 从MongoDB获取成功: {len(result) if isinstance(result, list) else 1}条记录") return result except Exception as e: print(f"⚠️ MongoDB查询失败: {e}") # 2. 降级到Tushare API print("🔄 MongoDB不可用,降级到Tushare API") try: from .tushare_utils import get_tushare_provider provider = get_tushare_provider() if provider and provider.is_connected(): # 获取股票基本信息 stock_name = provider.get_stock_name(stock_code) if stock_name: result = { 'code': stock_code, 'name': stock_name, 'market': self._get_market_name(stock_code), 'category': self._get_stock_category(stock_code), 'source': 'tushare_api', 'updated_at': datetime.now().isoformat() } print(f"✅ 从Tushare API获取成功: {stock_code}") # 尝试缓存到MongoDB(如果可用) self._cache_to_mongodb(result) return result except Exception as e: print(f"⚠️ Tushare API查询失败: {e}") # 3. 最后的降级方案 print("❌ 所有数据源都不可用") return self._get_fallback_data(stock_code) def _get_from_mongodb(self, stock_code: str = None) -> Optional[Dict[str, Any]]: """从MongoDB获取数据""" try: mongodb_client = self.db_manager.get_mongodb_client() if not mongodb_client: return None db = mongodb_client[self.db_manager.mongodb_config["database"]] collection = db['stock_basic_info'] if stock_code: # 获取单个股票 result = collection.find_one({'code': stock_code}) return result if result else None else: # 获取所有股票 cursor = collection.find({}) results = list(cursor) return results if results else None except Exception as e: logger.error(f"MongoDB查询失败: {e}") return None def _cache_to_mongodb(self, data: Any) -> bool: """将数据缓存到MongoDB""" if not self.db_manager or not self.db_manager.mongodb_db: return False try: collection = self.db_manager.mongodb_db['stock_basic_info'] if isinstance(data, list): # 批量插入 for item in data: collection.update_one( {'code': item['code']}, {'$set': item}, upsert=True ) print(f"💾 已缓存{len(data)}条记录到MongoDB") elif isinstance(data, dict): # 单条插入 collection.update_one( {'code': data['code']}, {'$set': data}, upsert=True ) print(f"💾 已缓存股票{data['code']}到MongoDB") return True except Exception as e: logger.error(f"缓存到MongoDB失败: {e}") return False def _get_fallback_data(self, stock_code: str = None) -> Dict[str, Any]: """最后的降级数据""" if stock_code: return { 'code': stock_code, 'name': f'股票{stock_code}', 'market': self._get_market_name(stock_code), 'category': '未知', 'source': 'fallback', 'updated_at': datetime.now().isoformat(), 'error': '所有数据源都不可用' } else: return { 'error': '无法获取股票列表,请检查网络连接和数据库配置', 'suggestion': '请确保MongoDB已配置或网络连接正常以访问tushareAPI' } def _get_market_name(self, stock_code: str) -> str: """根据股票代码判断市场""" if stock_code.startswith(('60', '68', '90')): return '上海' elif stock_code.startswith(('00', '30', '20')): return '深圳' else: return '未知' def _get_stock_category(self, stock_code: str) -> str: """根据股票代码判断类别""" if stock_code.startswith('60'): return '沪市主板' elif stock_code.startswith('68'): return '科创板' elif stock_code.startswith('00'): return '深市主板' elif stock_code.startswith('30'): return '创业板' elif stock_code.startswith('20'): return '深市B股' else: return '其他' def get_stock_data_with_fallback(self, stock_code: str, start_date: str, end_date: str) -> str: """ 获取股票数据(带降级机制) 这是对现有get_china_stock_data函数的增强 """ print(f"📊 获取股票数据: {stock_code} ({start_date} 到 {end_date})") # 首先确保股票基础信息可用 stock_info = self.get_stock_basic_info(stock_code) if stock_info and 'error' in stock_info: return f"❌ 无法获取股票{stock_code}的基础信息: {stock_info.get('error', '未知错误')}" # 调用现有的get_china_stock_data函数 try: from .tushare_utils import get_china_stock_data return get_china_stock_data(stock_code, start_date, end_date) except Exception as e: return f"❌ 获取股票数据失败: {str(e)}\n\n💡 建议:\n1. 检查网络连接\n2. 确认股票代码格式正确\n3. 检查Tushare配置\n4. 配置TUSHARE_TOKEN以获得更好的服务" # 全局服务实例 _stock_data_service = None def get_stock_data_service() -> StockDataService: """获取股票数据服务实例(单例模式)""" global _stock_data_service if _stock_data_service is None: _stock_data_service = StockDataService() return _stock_data_service