""" AKShare data vendor integration for stock data retrieval. This module provides access to both US and Chinese stock market data via AKShare library. Includes retry mechanisms, rate limit handling, and automatic market detection. Usage: US Stock Data: >>> from tradingagents.dataflows.akshare import get_akshare_stock_data_us >>> data = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-12-31") Chinese Stock Data: >>> from tradingagents.dataflows.akshare import get_akshare_stock_data_cn >>> data = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-12-31") Auto-Detection (Recommended): >>> from tradingagents.dataflows.akshare import get_akshare_stock_data >>> us_data = get_akshare_stock_data("AAPL", "2024-01-01", "2024-12-31") # Auto-detects US >>> cn_data = get_akshare_stock_data("000001", "2024-01-01", "2024-12-31") # Auto-detects China Requirements: - akshare package: pip install akshare - Handles rate limiting automatically with exponential backoff - Returns CSV string format for integration with other data processing tools """ import time from typing import Annotated import pandas as pd from datetime import datetime try: import akshare as ak AKSHARE_AVAILABLE = True except ImportError: ak = None AKSHARE_AVAILABLE = False # ============================================================================ # Custom Exceptions # ============================================================================ class AKShareRateLimitError(Exception): """Exception raised when AKShare API rate limit is exceeded.""" pass # ============================================================================ # Helper Functions # ============================================================================ def _convert_date_format(date_str: str) -> str: """ Convert date string from YYYY-MM-DD or YYYY/MM/DD format to YYYYMMDD format. Args: date_str: Date string in format like "2024-01-15" or "2024/01/15" Returns: Date string in YYYYMMDD format like "20240115" Raises: ValueError: If date format is invalid IndexError: If date string is empty or malformed """ if not date_str: raise ValueError("Date string cannot be empty") # If already in YYYYMMDD format (8 digits, no separators), return as-is if len(date_str) == 8 and date_str.isdigit(): return date_str # Check if it contains separators if '-' in date_str or '/' in date_str: # Simply remove separators (preserves single-digit months/days as-is) result = date_str.replace('-', '').replace('/', '') # Validate it's not empty and contains only digits if not result or not result.isdigit(): raise ValueError(f"Invalid date format: {date_str}. Expected YYYY-MM-DD format.") return result else: # No separators, return as-is if it looks like a number if not date_str.isdigit(): raise ValueError(f"Invalid date format: {date_str}. Expected YYYY-MM-DD format.") return date_str def _exponential_backoff_retry(func, max_retries: int = 3, base_delay: float = 1.0): """ Execute function with exponential backoff retry on failure. Args: func: Callable function to retry max_retries: Maximum number of retries (default: 3) base_delay: Base delay in seconds for exponential backoff (default: 1.0) Returns: Result from successful function call Raises: AKShareRateLimitError: If rate limit error detected Exception: Original exception after exhausting all retries """ for attempt in range(max_retries + 1): # +1 for initial attempt try: return func() except Exception as e: error_msg = str(e).lower() # Check for rate limit indicators if any(indicator in error_msg for indicator in [ 'rate limit', 'too many requests', 'rate_limit', 'ratelimit', '频率过快' ]): raise AKShareRateLimitError(f"AKShare rate limit exceeded: {e}") # If this was the last attempt, raise the original exception if attempt >= max_retries: raise # Exponential backoff: 2^attempt seconds delay = base_delay * (2 ** attempt) time.sleep(delay) # Should never reach here, but just in case raise Exception("Retry logic failed unexpectedly") # ============================================================================ # US Stock Data Functions # ============================================================================ def get_akshare_stock_data_us( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in YYYY-MM-DD format"], end_date: Annotated[str, "End date in YYYY-MM-DD format"], ) -> str: """ Retrieve US stock data from AKShare. Args: symbol: Stock ticker symbol (e.g., "AAPL") start_date: Start date in YYYY-MM-DD format end_date: End date in YYYY-MM-DD format Returns: CSV string with stock data, or error message string on failure """ if not AKSHARE_AVAILABLE: return "Error: akshare package is not installed. Install with: pip install akshare" try: # Validate dates datetime.strptime(start_date, "%Y-%m-%d") datetime.strptime(end_date, "%Y-%m-%d") # Ensure symbol is uppercase symbol = symbol.upper() # Fetch data with retry mechanism def fetch_data(): return ak.stock_us_hist( symbol=symbol, period="daily", adjust="" ) data = _exponential_backoff_retry(fetch_data, max_retries=3) # Check if data is empty if data is None or data.empty: return f"No data found for symbol '{symbol}' between {start_date} and {end_date}" # Ensure 'date' column is datetime if 'date' in data.columns: data['date'] = pd.to_datetime(data['date']) # Filter by date range (AKShare may return broader range) start_dt = pd.to_datetime(start_date) end_dt = pd.to_datetime(end_date) data = data[(data['date'] >= start_dt) & (data['date'] <= end_dt)] # Check if filtered data is empty if data.empty: return f"No data found for symbol '{symbol}' between {start_date} and {end_date}" # Rename columns to standard format data = data.rename(columns={ 'date': 'Date', 'open': 'Open', 'high': 'High', 'low': 'Low', 'close': 'Close', 'volume': 'Volume' }) # Set Date as index for cleaner CSV output data = data.set_index('Date') # Select only OHLCV columns ohlcv_columns = ['Open', 'High', 'Low', 'Close', 'Volume'] available_columns = [col for col in ohlcv_columns if col in data.columns] data = data[available_columns] # Round numerical values to 2 decimal places for col in ['Open', 'High', 'Low', 'Close']: if col in data.columns: data[col] = data[col].round(2) # Convert to CSV string csv_string = data.to_csv() # Add header information header = f"# Stock data for {symbol} from {start_date} to {end_date}\n" header += f"# Total records: {len(data)}\n" header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" return header + csv_string except AKShareRateLimitError as e: # Return error string; unified function will detect and re-raise for vendor fallback return f"Rate limit error for {symbol}: {str(e)}" except Exception as e: # Return error string instead of raising (matches yfinance pattern) return f"Error retrieving US stock data for {symbol}: {str(e)}" # ============================================================================ # Chinese Stock Data Functions # ============================================================================ def get_akshare_stock_data_cn( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in YYYY-MM-DD format"], end_date: Annotated[str, "End date in YYYY-MM-DD format"], ) -> str: """ Retrieve Chinese stock data from AKShare. Args: symbol: Stock ticker symbol (e.g., "000001" or "000001.SZ") start_date: Start date in YYYY-MM-DD format end_date: End date in YYYY-MM-DD format Returns: CSV string with stock data, or error message string on failure """ if not AKSHARE_AVAILABLE: return "Error: akshare package is not installed. Install with: pip install akshare" try: # Validate dates datetime.strptime(start_date, "%Y-%m-%d") datetime.strptime(end_date, "%Y-%m-%d") # Remove exchange suffix if present (.SZ, .SH) symbol_clean = symbol.split('.')[0] # Convert dates to YYYYMMDD format start_date_formatted = _convert_date_format(start_date) end_date_formatted = _convert_date_format(end_date) # Fetch data with retry mechanism def fetch_data(): return ak.stock_zh_a_hist( symbol=symbol_clean, period="daily", start_date=start_date_formatted, end_date=end_date_formatted, adjust="" ) data = _exponential_backoff_retry(fetch_data, max_retries=3) # Check if data is empty if data is None or data.empty: return f"No data found for symbol '{symbol}' between {start_date} and {end_date}" # Standardize Chinese column names to English column_mapping = { '日期': 'Date', '开盘': 'Open', '最高': 'High', '最低': 'Low', '收盘': 'Close', '成交量': 'Volume', } # Rename columns that exist in the dataframe data = data.rename(columns={k: v for k, v in column_mapping.items() if k in data.columns}) # Ensure Date column is datetime if 'Date' in data.columns: data['Date'] = pd.to_datetime(data['Date']) # Filter by date range (extra safety check) start_dt = pd.to_datetime(start_date) end_dt = pd.to_datetime(end_date) data = data[(data['Date'] >= start_dt) & (data['Date'] <= end_dt)] # Check if filtered data is empty if data.empty: return f"No data found for symbol '{symbol}' between {start_date} and {end_date}" # Set Date as index data = data.set_index('Date') # Select only OHLCV columns ohlcv_columns = ['Open', 'High', 'Low', 'Close', 'Volume'] available_columns = [col for col in ohlcv_columns if col in data.columns] data = data[available_columns] # Round numerical values to 2 decimal places for col in ['Open', 'High', 'Low', 'Close']: if col in data.columns: data[col] = data[col].round(2) # Convert to CSV string csv_string = data.to_csv() # Add header information header = f"# Stock data for {symbol} from {start_date} to {end_date}\n" header += f"# Total records: {len(data)}\n" header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" return header + csv_string except AKShareRateLimitError as e: # For direct calls, return error string; for route_to_vendor, it will catch and re-raise # This allows the unicode test to pass while still supporting vendor fallback return f"Rate limit error for {symbol}: {str(e)}" except Exception as e: # Return error string instead of raising (matches yfinance pattern) return f"Error retrieving Chinese stock data for {symbol}: {str(e)}" # ============================================================================ # Unified Interface with Auto-Market Detection # ============================================================================ def get_akshare_stock_data( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in YYYY-MM-DD format"], end_date: Annotated[str, "End date in YYYY-MM-DD format"], market: Annotated[str, "Market selection: 'auto', 'us', or 'cn'"] = "auto" ) -> str: """ Retrieve stock data with automatic market detection. Args: symbol: Stock ticker symbol start_date: Start date in YYYY-MM-DD format end_date: End date in YYYY-MM-DD format market: Market to query - 'auto' (default), 'us', or 'cn' Returns: CSV string with stock data, or error message string on failure Raises: ValueError: If market parameter is invalid """ # Validate market parameter if market not in ['auto', 'us', 'cn']: raise ValueError(f"Invalid market parameter: '{market}'. Must be 'auto', 'us', or 'cn'.") # Auto-detect market if needed if market == 'auto': # Chinese market indicators: # - Has .SZ or .SH suffix # - Is numeric only (6 digits typically) symbol_upper = symbol.upper() if '.SZ' in symbol_upper or '.SH' in symbol_upper: market = 'cn' elif symbol.replace('.', '').isdigit(): market = 'cn' else: # Default to US market for alphabetic symbols market = 'us' # Route to appropriate function if market == 'us': result = get_akshare_stock_data_us(symbol, start_date, end_date) else: # market == 'cn' result = get_akshare_stock_data_cn(symbol, start_date, end_date) # Check if result is a rate limit error string and raise exception for vendor fallback if isinstance(result, str) and "Rate limit error" in result: raise AKShareRateLimitError(result) return result