# 文件差异报告 # 当前文件: tradingagents\dataflows\optimized_us_data.py # 中文版文件: TradingAgentsCN\tradingagents\dataflows\optimized_us_data.py # 生成时间: 周日 2025/07/06 --- current/optimized_us_data.py+++ chinese_version/optimized_us_data.py@@ -1,7 +1,7 @@ #!/usr/bin/env python3 """ -Optimized US Stock Data Fetcher -Integrates caching strategy to reduce API calls and improve response speed +优化的美股数据获取工具 +集成缓存策略,减少API调用,提高响应速度 """ import os @@ -16,24 +16,24 @@ class OptimizedUSDataProvider: - """Optimized US Stock Data Provider - Integrates caching and API rate limiting""" + """优化的美股数据提供器 - 集成缓存和API限制处理""" def __init__(self): self.cache = get_cache() self.config = get_config() self.last_api_call = 0 - self.min_api_interval = 1.0 # Minimum API call interval (seconds) - - print("📊 Optimized US stock data provider initialized") + self.min_api_interval = 1.0 # 最小API调用间隔(秒) + + print("📊 优化美股数据提供器初始化完成") def _wait_for_rate_limit(self): - """Wait for API rate limit""" + """等待API限制""" current_time = time.time() time_since_last_call = current_time - self.last_api_call if time_since_last_call < self.min_api_interval: wait_time = self.min_api_interval - time_since_last_call - print(f"⏳ API rate limit wait {wait_time:.1f}s...") + print(f"⏳ API限制等待 {wait_time:.1f}s...") time.sleep(wait_time) self.last_api_call = time.time() @@ -41,364 +41,292 @@ def get_stock_data(self, symbol: str, start_date: str, end_date: str, force_refresh: bool = False) -> str: """ - Get US stock data - prioritize cache usage + 获取美股数据 - 优先使用缓存 Args: - symbol: Stock symbol - start_date: Start date (YYYY-MM-DD) - end_date: End date (YYYY-MM-DD) - force_refresh: Whether to force refresh cache - + symbol: 股票代码 + start_date: 开始日期 (YYYY-MM-DD) + end_date: 结束日期 (YYYY-MM-DD) + force_refresh: 是否强制刷新缓存 + Returns: - Formatted stock data string + 格式化的股票数据字符串 """ + print(f"📈 获取美股数据: {symbol} ({start_date} 到 {end_date})") + + # 检查缓存(除非强制刷新) + if not force_refresh: + # 优先查找FINNHUB缓存 + cache_key = self.cache.find_cached_stock_data( + symbol=symbol, + start_date=start_date, + end_date=end_date, + data_source="finnhub" + ) + + # 如果没有FINNHUB缓存,查找Yahoo Finance缓存 + if not cache_key: + cache_key = self.cache.find_cached_stock_data( + symbol=symbol, + start_date=start_date, + end_date=end_date, + data_source="yfinance" + ) + + if cache_key: + cached_data = self.cache.load_stock_data(cache_key) + if cached_data: + print(f"⚡ 从缓存加载美股数据: {symbol}") + return cached_data + + # 缓存未命中,从API获取 - 优先使用FINNHUB + formatted_data = None + data_source = None + + # 尝试FINNHUB API(优先) try: - # Check cache first (unless force refresh) - if not force_refresh: - cache_key = self.cache.find_cached_stock_data( - symbol, start_date, end_date, "optimized_yfinance" - ) - - if cache_key and self.cache.is_cache_valid(cache_key, symbol): - cached_data = self.cache.load_stock_data(cache_key) - if cached_data: - print(f"📖 Using cached data for {symbol}") - if isinstance(cached_data, pd.DataFrame): - return self._format_stock_data(cached_data, symbol) - else: - return cached_data - - # Fetch new data from API - print(f"🌐 Fetching new data for {symbol} from {start_date} to {end_date}") - - # Wait for rate limit + print(f"🌐 从FINNHUB API获取数据: {symbol}") self._wait_for_rate_limit() - - # Try Yahoo Finance first + + formatted_data = self._get_data_from_finnhub(symbol, start_date, end_date) + if formatted_data and "❌" not in formatted_data: + data_source = "finnhub" + print(f"✅ FINNHUB数据获取成功: {symbol}") + else: + print(f"⚠️ FINNHUB数据获取失败,尝试备用方案") + formatted_data = None + + except Exception as e: + print(f"❌ FINNHUB API调用失败: {e}") + formatted_data = None + + # 备用方案:Yahoo Finance API + if not formatted_data: try: - data = self._fetch_from_yfinance(symbol, start_date, end_date) - if data is not None and not data.empty: - # Cache the DataFrame - cache_key = self.cache.save_stock_data( - symbol, data, start_date, end_date, "optimized_yfinance" - ) + print(f"🌐 从Yahoo Finance API获取数据: {symbol}") + self._wait_for_rate_limit() + + # 获取数据 + ticker = yf.Ticker(symbol.upper()) + data = ticker.history(start=start_date, end=end_date) + + if data.empty: + error_msg = f"未找到股票 '{symbol}' 在 {start_date} 到 {end_date} 期间的数据" + print(f"❌ {error_msg}") + else: + # 格式化数据 + formatted_data = self._format_stock_data(symbol, data, start_date, end_date) + data_source = "yfinance" + print(f"✅ Yahoo Finance数据获取成功: {symbol}") + + except Exception as e: + print(f"❌ Yahoo Finance API调用失败: {e}") + formatted_data = None + + # 如果所有API都失败,生成备用数据 + if not formatted_data: + error_msg = "所有美股数据源都不可用" + print(f"❌ {error_msg}") + return self._generate_fallback_data(symbol, start_date, end_date, error_msg) + + # 保存到缓存 + self.cache.save_stock_data( + symbol=symbol, + data=formatted_data, + start_date=start_date, + end_date=end_date, + data_source=data_source + ) + + return formatted_data + + def _format_stock_data(self, symbol: str, data: pd.DataFrame, + start_date: str, end_date: str) -> str: + """格式化股票数据为字符串""" + + # 移除时区信息 + if data.index.tz is not None: + data.index = data.index.tz_localize(None) + + # 四舍五入数值 + numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"] + for col in numeric_columns: + if col in data.columns: + data[col] = data[col].round(2) + + # 获取最新价格和统计信息 + latest_price = data['Close'].iloc[-1] + price_change = data['Close'].iloc[-1] - data['Close'].iloc[0] + price_change_pct = (price_change / data['Close'].iloc[0]) * 100 + + # 计算技术指标 + data['MA5'] = data['Close'].rolling(window=5).mean() + data['MA10'] = data['Close'].rolling(window=10).mean() + data['MA20'] = data['Close'].rolling(window=20).mean() + + # 计算RSI + delta = data['Close'].diff() + gain = (delta.where(delta > 0, 0)).rolling(window=14).mean() + loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean() + rs = gain / loss + rsi = 100 - (100 / (1 + rs)) + + # 格式化输出 + result = f"""# {symbol} 美股数据分析 + +## 📊 基本信息 +- 股票代码: {symbol} +- 数据期间: {start_date} 至 {end_date} +- 数据条数: {len(data)}条 +- 最新价格: ${latest_price:.2f} +- 期间涨跌: ${price_change:+.2f} ({price_change_pct:+.2f}%) + +## 📈 价格统计 +- 期间最高: ${data['High'].max():.2f} +- 期间最低: ${data['Low'].min():.2f} +- 平均成交量: {data['Volume'].mean():,.0f} + +## 🔍 技术指标 +- MA5: ${data['MA5'].iloc[-1]:.2f} +- MA10: ${data['MA10'].iloc[-1]:.2f} +- MA20: ${data['MA20'].iloc[-1]:.2f} +- RSI: {rsi.iloc[-1]:.2f} + +## 📋 最近5日数据 +{data.tail().to_string()} + +数据来源: Yahoo Finance API +更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} +""" + + return result + + def _try_get_old_cache(self, symbol: str, start_date: str, end_date: str) -> Optional[str]: + """尝试获取过期的缓存数据作为备用""" + try: + # 查找任何相关的缓存,不考虑TTL + for metadata_file in self.cache.metadata_dir.glob(f"*_meta.json"): + try: + import json + with open(metadata_file, 'r', encoding='utf-8') as f: + metadata = json.load(f) - # Format and return - formatted_data = self._format_stock_data(data, symbol) - print(f"✅ Successfully fetched and cached data for {symbol}") - return formatted_data - else: - print(f"⚠️ No data returned from Yahoo Finance for {symbol}") - - except Exception as e: - print(f"❌ Yahoo Finance error for {symbol}: {e}") - - # Fallback: Try FINNHUB (if API key available) - try: - finnhub_data = self._fetch_from_finnhub(symbol, start_date, end_date) - if finnhub_data: - # Cache the string data - cache_key = self.cache.save_stock_data( - symbol, finnhub_data, start_date, end_date, "optimized_finnhub" - ) - print(f"✅ Successfully fetched data from FINNHUB for {symbol}") - return finnhub_data - - except Exception as e: - print(f"❌ FINNHUB error for {symbol}: {e}") - - # If all fails, return error message - error_msg = f"❌ Failed to fetch data for {symbol} from {start_date} to {end_date}" - print(error_msg) - return error_msg - + if (metadata.get('symbol') == symbol and + metadata.get('data_type') == 'stock_data' and + metadata.get('market_type') == 'us'): + + cache_key = metadata_file.stem.replace('_meta', '') + cached_data = self.cache.load_stock_data(cache_key) + if cached_data: + return cached_data + "\n\n⚠️ 注意: 使用的是过期缓存数据" + except Exception: + continue + except Exception: + pass + + return None + + def _get_data_from_finnhub(self, symbol: str, start_date: str, end_date: str) -> str: + """从FINNHUB API获取股票数据""" + try: + import finnhub + import os + from datetime import datetime, timedelta + + # 获取API密钥 + api_key = os.getenv('FINNHUB_API_KEY') + if not api_key: + return None + + client = finnhub.Client(api_key=api_key) + + # 获取实时报价 + quote = client.quote(symbol.upper()) + if not quote or 'c' not in quote: + return None + + # 获取公司信息 + profile = client.company_profile2(symbol=symbol.upper()) + company_name = profile.get('name', symbol.upper()) if profile else symbol.upper() + + # 格式化数据 + current_price = quote.get('c', 0) + change = quote.get('d', 0) + change_percent = quote.get('dp', 0) + + formatted_data = f"""# {symbol.upper()} 美股数据分析 + +## 📊 实时行情 +- 股票名称: {company_name} +- 当前价格: ${current_price:.2f} +- 涨跌额: ${change:+.2f} +- 涨跌幅: {change_percent:+.2f}% +- 开盘价: ${quote.get('o', 0):.2f} +- 最高价: ${quote.get('h', 0):.2f} +- 最低价: ${quote.get('l', 0):.2f} +- 前收盘: ${quote.get('pc', 0):.2f} +- 更新时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} + +## 📈 数据概览 +- 数据期间: {start_date} 至 {end_date} +- 数据来源: FINNHUB API (实时数据) +- 当前价位相对位置: {((current_price - quote.get('l', current_price)) / max(quote.get('h', current_price) - quote.get('l', current_price), 0.01) * 100):.1f}% +- 日内振幅: {((quote.get('h', 0) - quote.get('l', 0)) / max(quote.get('pc', 1), 0.01) * 100):.2f}% + +生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} +""" + + return formatted_data + except Exception as e: - error_msg = f"❌ Unexpected error fetching data for {symbol}: {e}" - print(error_msg) - return error_msg - - def _fetch_from_yfinance(self, symbol: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: - """Fetch data from Yahoo Finance""" - try: - ticker = yf.Ticker(symbol) - data = ticker.history(start=start_date, end=end_date) - - if data.empty: - print(f"⚠️ No data available for {symbol} in the specified date range") - return None - - # Reset index to make Date a column - data = data.reset_index() - - # Ensure we have the required columns - required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume'] - missing_columns = [col for col in required_columns if col not in data.columns] - - if missing_columns: - print(f"⚠️ Missing columns for {symbol}: {missing_columns}") - return None - - return data[required_columns] - - except Exception as e: - print(f"❌ Yahoo Finance fetch error for {symbol}: {e}") + print(f"❌ FINNHUB数据获取失败: {e}") return None - - def _fetch_from_finnhub(self, symbol: str, start_date: str, end_date: str) -> Optional[str]: - """Fetch data from FINNHUB API""" - try: - # Check if FINNHUB API key is available - finnhub_api_key = os.getenv('FINNHUB_API_KEY') - if not finnhub_api_key: - print("⚠️ FINNHUB API key not found, skipping FINNHUB data fetch") - return None - - import finnhub - - # Initialize FINNHUB client - finnhub_client = finnhub.Client(api_key=finnhub_api_key) - - # Convert dates to timestamps - start_timestamp = int(datetime.strptime(start_date, '%Y-%m-%d').timestamp()) - end_timestamp = int(datetime.strptime(end_date, '%Y-%m-%d').timestamp()) - - # Fetch candle data - candle_data = finnhub_client.stock_candles(symbol, 'D', start_timestamp, end_timestamp) - - if candle_data['s'] != 'ok': - print(f"⚠️ FINNHUB returned status: {candle_data['s']} for {symbol}") - return None - - # Format data - formatted_data = self._format_finnhub_data(candle_data, symbol) - return formatted_data - - except ImportError: - print("⚠️ finnhub-python package not installed, skipping FINNHUB data fetch") - return None - except Exception as e: - print(f"❌ FINNHUB fetch error for {symbol}: {e}") - return None - - def _format_stock_data(self, data: pd.DataFrame, symbol: str) -> str: - """Format DataFrame stock data into string""" - try: - # Ensure Date column is properly formatted - if 'Date' in data.columns: - data['Date'] = pd.to_datetime(data['Date']).dt.strftime('%Y-%m-%d') - - # Round numerical columns to 2 decimal places - numeric_columns = ['Open', 'High', 'Low', 'Close'] - for col in numeric_columns: - if col in data.columns: - data[col] = data[col].round(2) - - # Format volume as integer - if 'Volume' in data.columns: - data['Volume'] = data['Volume'].astype(int) - - # Create formatted string - formatted_lines = [f"Stock Data for {symbol}:"] - formatted_lines.append("Date,Open,High,Low,Close,Volume") - - for _, row in data.iterrows(): - line = f"{row['Date']},{row['Open']},{row['High']},{row['Low']},{row['Close']},{row['Volume']}" - formatted_lines.append(line) - - # Add summary statistics - if len(data) > 0: - formatted_lines.append(f"\nSummary for {symbol}:") - formatted_lines.append(f"Period: {data['Date'].iloc[0]} to {data['Date'].iloc[-1]}") - formatted_lines.append(f"Total trading days: {len(data)}") - formatted_lines.append(f"Average volume: {data['Volume'].mean():,.0f}") - formatted_lines.append(f"Price range: ${data['Low'].min():.2f} - ${data['High'].max():.2f}") - - # Calculate basic statistics - start_price = data['Open'].iloc[0] - end_price = data['Close'].iloc[-1] - price_change = end_price - start_price - price_change_pct = (price_change / start_price) * 100 - - formatted_lines.append(f"Period return: {price_change_pct:+.2f}% (${price_change:+.2f})") - - return "\n".join(formatted_lines) - - except Exception as e: - print(f"❌ Error formatting stock data for {symbol}: {e}") - return f"Error formatting data for {symbol}: {str(e)}" - - def _format_finnhub_data(self, candle_data: Dict, symbol: str) -> str: - """Format FINNHUB candle data into string""" - try: - # Extract data arrays - timestamps = candle_data['t'] - opens = candle_data['o'] - highs = candle_data['h'] - lows = candle_data['l'] - closes = candle_data['c'] - volumes = candle_data['v'] - - # Create formatted string - formatted_lines = [f"Stock Data for {symbol} (FINNHUB):"] - formatted_lines.append("Date,Open,High,Low,Close,Volume") - - for i in range(len(timestamps)): - date = datetime.fromtimestamp(timestamps[i]).strftime('%Y-%m-%d') - line = f"{date},{opens[i]:.2f},{highs[i]:.2f},{lows[i]:.2f},{closes[i]:.2f},{int(volumes[i])}" - formatted_lines.append(line) - - # Add summary - if len(timestamps) > 0: - start_date = datetime.fromtimestamp(timestamps[0]).strftime('%Y-%m-%d') - end_date = datetime.fromtimestamp(timestamps[-1]).strftime('%Y-%m-%d') - - formatted_lines.append(f"\nSummary for {symbol}:") - formatted_lines.append(f"Period: {start_date} to {end_date}") - formatted_lines.append(f"Total trading days: {len(timestamps)}") - formatted_lines.append(f"Average volume: {sum(volumes)/len(volumes):,.0f}") - formatted_lines.append(f"Price range: ${min(lows):.2f} - ${max(highs):.2f}") - - # Calculate return - price_change = closes[-1] - opens[0] - price_change_pct = (price_change / opens[0]) * 100 - formatted_lines.append(f"Period return: {price_change_pct:+.2f}% (${price_change:+.2f})") - - return "\n".join(formatted_lines) - - except Exception as e: - print(f"❌ Error formatting FINNHUB data for {symbol}: {e}") - return f"Error formatting FINNHUB data for {symbol}: {str(e)}" - - def get_stock_with_indicators(self, symbol: str, start_date: str, end_date: str, - indicators: list = None) -> str: - """ - Get stock data with technical indicators - - Args: - symbol: Stock symbol - start_date: Start date (YYYY-MM-DD) - end_date: End date (YYYY-MM-DD) - indicators: List of indicators to calculate ['sma_20', 'rsi', 'macd'] - - Returns: - Formatted stock data with indicators - """ - try: - # Get basic stock data - basic_data = self.get_stock_data(symbol, start_date, end_date) - - if basic_data.startswith("❌"): - return basic_data - - # If no indicators requested, return basic data - if not indicators: - return basic_data - - # Fetch DataFrame for indicator calculation - data_df = self._fetch_from_yfinance(symbol, start_date, end_date) - if data_df is None or data_df.empty: - return basic_data - - # Calculate indicators - indicator_data = self._calculate_indicators(data_df, indicators) - - # Combine basic data with indicators - combined_data = basic_data + "\n\nTechnical Indicators:\n" + indicator_data - - return combined_data - - except Exception as e: - error_msg = f"❌ Error getting stock data with indicators for {symbol}: {e}" - print(error_msg) - return error_msg - - def _calculate_indicators(self, data: pd.DataFrame, indicators: list) -> str: - """Calculate technical indicators""" - try: - indicator_lines = [] - - for indicator in indicators: - if indicator == 'sma_20': - data['SMA_20'] = data['Close'].rolling(window=20).mean() - latest_sma = data['SMA_20'].iloc[-1] - indicator_lines.append(f"SMA(20): ${latest_sma:.2f}") - - elif indicator == 'rsi': - # Simple RSI calculation - delta = data['Close'].diff() - gain = (delta.where(delta > 0, 0)).rolling(window=14).mean() - loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean() - rs = gain / loss - rsi = 100 - (100 / (1 + rs)) - latest_rsi = rsi.iloc[-1] - indicator_lines.append(f"RSI(14): {latest_rsi:.2f}") - - elif indicator == 'macd': - # Simple MACD calculation - ema_12 = data['Close'].ewm(span=12).mean() - ema_26 = data['Close'].ewm(span=26).mean() - macd_line = ema_12 - ema_26 - signal_line = macd_line.ewm(span=9).mean() - latest_macd = macd_line.iloc[-1] - latest_signal = signal_line.iloc[-1] - indicator_lines.append(f"MACD: {latest_macd:.4f}, Signal: {latest_signal:.4f}") - - return "\n".join(indicator_lines) - - except Exception as e: - print(f"❌ Error calculating indicators: {e}") - return f"Error calculating indicators: {str(e)}" - - -# Global provider instance -_global_provider = None + + def _generate_fallback_data(self, symbol: str, start_date: str, end_date: str, error_msg: str) -> str: + """生成备用数据""" + return f"""# {symbol} 美股数据获取失败 + +## ❌ 错误信息 +{error_msg} + +## 📊 模拟数据(仅供演示) +- 股票代码: {symbol} +- 数据期间: {start_date} 至 {end_date} +- 最新价格: ${random.uniform(100, 300):.2f} +- 模拟涨跌: {random.uniform(-5, 5):+.2f}% + +## ⚠️ 重要提示 +由于API限制或网络问题,无法获取实时数据。 +建议稍后重试或检查网络连接。 + +生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} +""" + + +# 全局实例 +_us_data_provider = None def get_optimized_us_data_provider() -> OptimizedUSDataProvider: + """获取全局美股数据提供器实例""" + global _us_data_provider + if _us_data_provider is None: + _us_data_provider = OptimizedUSDataProvider() + return _us_data_provider + + +def get_us_stock_data_cached(symbol: str, start_date: str, end_date: str, + force_refresh: bool = False) -> str: """ - Get global optimized US data provider instance + 获取美股数据的便捷函数 + + Args: + symbol: 股票代码 + start_date: 开始日期 (YYYY-MM-DD) + end_date: 结束日期 (YYYY-MM-DD) + force_refresh: 是否强制刷新缓存 Returns: - OptimizedUSDataProvider instance + 格式化的股票数据字符串 """ - global _global_provider - if _global_provider is None: - _global_provider = OptimizedUSDataProvider() - return _global_provider - - -# Convenience functions -def get_optimized_stock_data(symbol: str, start_date: str, end_date: str, - force_refresh: bool = False) -> str: - """Get optimized stock data (convenience function)""" provider = get_optimized_us_data_provider() return provider.get_stock_data(symbol, start_date, end_date, force_refresh) - - -def get_stock_with_indicators(symbol: str, start_date: str, end_date: str, - indicators: list = None) -> str: - """Get stock data with technical indicators (convenience function)""" - provider = get_optimized_us_data_provider() - return provider.get_stock_with_indicators(symbol, start_date, end_date, indicators) - - -if __name__ == "__main__": - # Test the optimized data provider - print("🧪 Testing Optimized US Data Provider...") - - # Initialize provider - provider = OptimizedUSDataProvider() - - # Test data fetch - data = provider.get_stock_data("AAPL", "2024-01-01", "2024-01-31") - print("Sample data:") - print(data[:500] + "..." if len(data) > 500 else data) - - # Test with indicators - data_with_indicators = provider.get_stock_with_indicators( - "AAPL", "2024-01-01", "2024-01-31", - indicators=['sma_20', 'rsi', 'macd'] - ) - print("\nData with indicators:") - print(data_with_indicators[-500:] if len(data_with_indicators) > 500 else data_with_indicators) - - print("✅ Optimized data provider test completed!")