392 lines
14 KiB
Python
392 lines
14 KiB
Python
"""
|
|
AKShare data vendor integration for stock data retrieval.
|
|
|
|
This module provides access to both US and Chinese stock market data via AKShare library.
|
|
Includes retry mechanisms, rate limit handling, and automatic market detection.
|
|
|
|
Usage:
|
|
US Stock Data:
|
|
>>> from tradingagents.dataflows.akshare import get_akshare_stock_data_us
|
|
>>> data = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-12-31")
|
|
|
|
Chinese Stock Data:
|
|
>>> from tradingagents.dataflows.akshare import get_akshare_stock_data_cn
|
|
>>> data = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-12-31")
|
|
|
|
Auto-Detection (Recommended):
|
|
>>> from tradingagents.dataflows.akshare import get_akshare_stock_data
|
|
>>> us_data = get_akshare_stock_data("AAPL", "2024-01-01", "2024-12-31") # Auto-detects US
|
|
>>> cn_data = get_akshare_stock_data("000001", "2024-01-01", "2024-12-31") # Auto-detects China
|
|
|
|
Requirements:
|
|
- akshare package: pip install akshare
|
|
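- Retries transient failures with exponential backoff; rate-limit errors are not retried but surfaced immediately (get_akshare_stock_data raises AKShareRateLimitError) so callers can fall back to another vendor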
|
|
- Returns CSV string format for integration with other data processing tools
|
|
"""
|
|
|
|
import time
|
|
from typing import Annotated
|
|
import pandas as pd
|
|
from datetime import datetime
|
|
|
|
try:
|
|
import akshare as ak
|
|
AKSHARE_AVAILABLE = True
|
|
except ImportError:
|
|
ak = None
|
|
AKSHARE_AVAILABLE = False
|
|
|
|
|
|
# ============================================================================
|
|
# Custom Exceptions
|
|
# ============================================================================
|
|
|
|
class AKShareRateLimitError(Exception):
    """Signal that AKShare refused a request because of rate limiting.

    Raised (rather than retried) so that callers can stop querying AKShare
    and fall back to a different data vendor.
    """
|
|
|
|
|
|
# ============================================================================
|
|
# Helper Functions
|
|
# ============================================================================
|
|
|
|
def _convert_date_format(date_str: str) -> str:
|
|
"""
|
|
Convert date string from YYYY-MM-DD or YYYY/MM/DD format to YYYYMMDD format.
|
|
|
|
Args:
|
|
date_str: Date string in format like "2024-01-15" or "2024/01/15"
|
|
|
|
Returns:
|
|
Date string in YYYYMMDD format like "20240115"
|
|
|
|
Raises:
|
|
ValueError: If date format is invalid
|
|
IndexError: If date string is empty or malformed
|
|
"""
|
|
if not date_str:
|
|
raise ValueError("Date string cannot be empty")
|
|
|
|
# If already in YYYYMMDD format (8 digits, no separators), return as-is
|
|
if len(date_str) == 8 and date_str.isdigit():
|
|
return date_str
|
|
|
|
# Check if it contains separators
|
|
if '-' in date_str or '/' in date_str:
|
|
# Simply remove separators (preserves single-digit months/days as-is)
|
|
result = date_str.replace('-', '').replace('/', '')
|
|
# Validate it's not empty and contains only digits
|
|
if not result or not result.isdigit():
|
|
raise ValueError(f"Invalid date format: {date_str}. Expected YYYY-MM-DD format.")
|
|
return result
|
|
else:
|
|
# No separators, return as-is if it looks like a number
|
|
if not date_str.isdigit():
|
|
raise ValueError(f"Invalid date format: {date_str}. Expected YYYY-MM-DD format.")
|
|
return date_str
|
|
|
|
|
|
def _exponential_backoff_retry(func, max_retries: int = 3, base_delay: float = 1.0):
|
|
"""
|
|
Execute function with exponential backoff retry on failure.
|
|
|
|
Args:
|
|
func: Callable function to retry
|
|
max_retries: Maximum number of retries (default: 3)
|
|
base_delay: Base delay in seconds for exponential backoff (default: 1.0)
|
|
|
|
Returns:
|
|
Result from successful function call
|
|
|
|
Raises:
|
|
AKShareRateLimitError: If rate limit error detected
|
|
Exception: Original exception after exhausting all retries
|
|
"""
|
|
for attempt in range(max_retries + 1): # +1 for initial attempt
|
|
try:
|
|
return func()
|
|
except Exception as e:
|
|
error_msg = str(e).lower()
|
|
|
|
# Check for rate limit indicators
|
|
if any(indicator in error_msg for indicator in [
|
|
'rate limit', 'too many requests', 'rate_limit', 'ratelimit', '频率过快'
|
|
]):
|
|
raise AKShareRateLimitError(f"AKShare rate limit exceeded: {e}")
|
|
|
|
# If this was the last attempt, raise the original exception
|
|
if attempt >= max_retries:
|
|
raise
|
|
|
|
# Exponential backoff: 2^attempt seconds
|
|
delay = base_delay * (2 ** attempt)
|
|
time.sleep(delay)
|
|
|
|
# Should never reach here, but just in case
|
|
raise Exception("Retry logic failed unexpectedly")
|
|
|
|
|
|
# ============================================================================
|
|
# US Stock Data Functions
|
|
# ============================================================================
|
|
|
|
def get_akshare_stock_data_us(
    symbol: Annotated[str, "ticker symbol of the company"],
    start_date: Annotated[str, "Start date in YYYY-MM-DD format"],
    end_date: Annotated[str, "End date in YYYY-MM-DD format"],
) -> str:
    """
    Retrieve daily US stock OHLCV data from AKShare as a CSV string.

    ``ak.stock_us_hist`` is called without a date range, and its rows are
    then filtered to the inclusive ``[start_date, end_date]`` window. The
    result is reduced to Open/High/Low/Close/Volume, price columns are
    rounded to 2 decimals, and a few ``#`` comment lines describing the
    request are prepended.

    Args:
        symbol: Stock ticker symbol (e.g., "AAPL"); upper-cased before use.
        start_date: Start date in YYYY-MM-DD format (inclusive).
        end_date: End date in YYYY-MM-DD format (inclusive).

    Returns:
        CSV string with stock data on success. Failures are reported as a
        message string instead of an exception:
        - "Error: akshare package is not installed. ..." if akshare is missing
        - "No data found for symbol ..." if no rows fall in the range
        - "Rate limit error for ..." if AKShare rate-limited the request
          (get_akshare_stock_data turns this into AKShareRateLimitError)
        - "Error retrieving US stock data for ..." for any other failure
    """
    if not AKSHARE_AVAILABLE:
        return "Error: akshare package is not installed. Install with: pip install akshare"

    try:
        # Validate dates (strptime raises ValueError on malformed input)
        datetime.strptime(start_date, "%Y-%m-%d")
        datetime.strptime(end_date, "%Y-%m-%d")

        symbol = symbol.upper()

        def fetch_data():
            return ak.stock_us_hist(
                symbol=symbol,
                period="daily",
                adjust="",
            )

        data = _exponential_backoff_retry(fetch_data, max_retries=3)

        if data is None or data.empty:
            return f"No data found for symbol '{symbol}' between {start_date} and {end_date}"

        # Normalise column names. Lower-case English names are what this
        # function has always expected; the Chinese names (the ones used by
        # the A-share endpoint) are mapped too in case the US endpoint returns
        # them. NOTE(review): confirm which set stock_us_hist actually returns.
        column_mapping = {
            'date': 'Date',
            'open': 'Open',
            'high': 'High',
            'low': 'Low',
            'close': 'Close',
            'volume': 'Volume',
            '日期': 'Date',
            '开盘': 'Open',
            '最高': 'High',
            '最低': 'Low',
            '收盘': 'Close',
            '成交量': 'Volume',
        }
        # rename() ignores mapping keys that are not columns.
        data = data.rename(columns=column_mapping)

        # Without a date column the range filter below cannot work; fail with
        # a readable message instead of a bare KeyError.
        if 'Date' not in data.columns:
            raise ValueError("AKShare response has no date column")

        data['Date'] = pd.to_datetime(data['Date'])

        # Filter by date range (AKShare may return a broader range)
        start_dt = pd.to_datetime(start_date)
        end_dt = pd.to_datetime(end_date)
        data = data[(data['Date'] >= start_dt) & (data['Date'] <= end_dt)]

        if data.empty:
            return f"No data found for symbol '{symbol}' between {start_date} and {end_date}"

        # Date as index for cleaner CSV output, then keep only OHLCV columns.
        data = data.set_index('Date')
        ohlcv_columns = ['Open', 'High', 'Low', 'Close', 'Volume']
        data = data[[col for col in ohlcv_columns if col in data.columns]]

        # DataFrame.round with a dict rounds only the listed columns (absent
        # ones are ignored) and returns a new frame, so no in-place writes on
        # a sliced DataFrame are needed.
        data = data.round({col: 2 for col in ('Open', 'High', 'Low', 'Close')})

        header = (
            f"# Stock data for {symbol} from {start_date} to {end_date}\n"
            f"# Total records: {len(data)}\n"
            f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
        )
        return header + data.to_csv()

    except AKShareRateLimitError as e:
        # Returned as a string so direct callers always get text;
        # get_akshare_stock_data detects this prefix and re-raises.
        return f"Rate limit error for {symbol}: {str(e)}"

    except Exception as e:
        # Return error string instead of raising (matches yfinance pattern)
        return f"Error retrieving US stock data for {symbol}: {str(e)}"
|
|
|
|
|
|
# ============================================================================
|
|
# Chinese Stock Data Functions
|
|
# ============================================================================
|
|
|
|
def get_akshare_stock_data_cn(
    symbol: Annotated[str, "ticker symbol of the company"],
    start_date: Annotated[str, "Start date in YYYY-MM-DD format"],
    end_date: Annotated[str, "End date in YYYY-MM-DD format"],
) -> str:
    """
    Retrieve daily Chinese A-share OHLCV data from AKShare as a CSV string.

    The requested range is passed to ``ak.stock_zh_a_hist`` as YYYYMMDD
    strings, and the response is filtered to ``[start_date, end_date]`` again.
    Chinese column names are mapped to English. The result is reduced to
    Open/High/Low/Close/Volume, price columns are rounded to 2 decimals, and
    a few ``#`` comment lines describing the request are prepended.

    Args:
        symbol: Stock ticker symbol (e.g., "000001" or "000001.SZ"); any
            exchange suffix after the first '.' is dropped before the request.
        start_date: Start date in YYYY-MM-DD format (inclusive).
        end_date: End date in YYYY-MM-DD format (inclusive).

    Returns:
        CSV string with stock data on success. Failures are reported as a
        message string instead of an exception:
        - "Error: akshare package is not installed. ..." if akshare is missing
        - "No data found for symbol ..." if no rows fall in the range
        - "Rate limit error for ..." if AKShare rate-limited the request
          (get_akshare_stock_data turns this into AKShareRateLimitError)
        - "Error retrieving Chinese stock data for ..." for any other failure
    """
    if not AKSHARE_AVAILABLE:
        return "Error: akshare package is not installed. Install with: pip install akshare"

    try:
        # Validate dates and keep the parsed values. strptime accepts
        # non-padded input such as "2024-1-5"; formatting the parsed value
        # guarantees a real zero-padded YYYYMMDD ("20240105") for the API,
        # whereas just stripping separators would produce "202415".
        start_parsed = datetime.strptime(start_date, "%Y-%m-%d")
        end_parsed = datetime.strptime(end_date, "%Y-%m-%d")
        start_date_formatted = start_parsed.strftime("%Y%m%d")
        end_date_formatted = end_parsed.strftime("%Y%m%d")

        # Remove exchange suffix if present (e.g. "000001.SZ" -> "000001")
        symbol_clean = symbol.strip().split('.')[0]

        def fetch_data():
            return ak.stock_zh_a_hist(
                symbol=symbol_clean,
                period="daily",
                start_date=start_date_formatted,
                end_date=end_date_formatted,
                adjust="",
            )

        data = _exponential_backoff_retry(fetch_data, max_retries=3)

        if data is None or data.empty:
            return f"No data found for symbol '{symbol}' between {start_date} and {end_date}"

        # Chinese column names -> English (date, open, high, low, close, volume)
        column_mapping = {
            '日期': 'Date',
            '开盘': 'Open',
            '最高': 'High',
            '最低': 'Low',
            '收盘': 'Close',
            '成交量': 'Volume',
        }
        # rename() ignores mapping keys that are not columns.
        data = data.rename(columns=column_mapping)

        # Without a date column the range filter below cannot work; fail with
        # a readable message instead of a bare KeyError.
        if 'Date' not in data.columns:
            raise ValueError("AKShare response has no date column")

        data['Date'] = pd.to_datetime(data['Date'])

        # Filter by date range (extra safety check on the API's own filtering)
        start_dt = pd.Timestamp(start_parsed)
        end_dt = pd.Timestamp(end_parsed)
        data = data[(data['Date'] >= start_dt) & (data['Date'] <= end_dt)]

        if data.empty:
            return f"No data found for symbol '{symbol}' between {start_date} and {end_date}"

        # Date as index for cleaner CSV output, then keep only OHLCV columns.
        data = data.set_index('Date')
        ohlcv_columns = ['Open', 'High', 'Low', 'Close', 'Volume']
        data = data[[col for col in ohlcv_columns if col in data.columns]]

        # DataFrame.round with a dict rounds only the listed columns (absent
        # ones are ignored) and returns a new frame, so no in-place writes on
        # a sliced DataFrame are needed.
        data = data.round({col: 2 for col in ('Open', 'High', 'Low', 'Close')})

        header = (
            f"# Stock data for {symbol} from {start_date} to {end_date}\n"
            f"# Total records: {len(data)}\n"
            f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
        )
        return header + data.to_csv()

    except AKShareRateLimitError as e:
        # Returned as a string so direct callers always get text;
        # get_akshare_stock_data detects this prefix and re-raises.
        return f"Rate limit error for {symbol}: {str(e)}"

    except Exception as e:
        # Return error string instead of raising (matches yfinance pattern)
        return f"Error retrieving Chinese stock data for {symbol}: {str(e)}"
|
|
|
|
|
|
# ============================================================================
|
|
# Unified Interface with Auto-Market Detection
|
|
# ============================================================================
|
|
|
|
def get_akshare_stock_data(
    symbol: Annotated[str, "ticker symbol of the company"],
    start_date: Annotated[str, "Start date in YYYY-MM-DD format"],
    end_date: Annotated[str, "End date in YYYY-MM-DD format"],
    market: Annotated[str, "Market selection: 'auto', 'us', or 'cn'"] = "auto"
) -> str:
    """
    Retrieve stock data, routing to the US or Chinese fetcher.

    With ``market='auto'`` a symbol is treated as Chinese when it ends in an
    exchange suffix (.SZ, .SH, .SS or .BJ), or when it consists only of
    digits and dots (e.g. "000001"). Every other symbol is treated as US.

    Args:
        symbol: Stock ticker symbol.
        start_date: Start date in YYYY-MM-DD format.
        end_date: End date in YYYY-MM-DD format.
        market: 'auto' (default), 'us' or 'cn'. Matching ignores case and
            surrounding whitespace.

    Returns:
        CSV string with stock data, or an error message string on failure
        (see get_akshare_stock_data_us / get_akshare_stock_data_cn).

    Raises:
        ValueError: If ``market`` is not one of 'auto', 'us', 'cn'.
        AKShareRateLimitError: If the selected fetcher reported a rate-limit
            error, so callers can fall back to another vendor.
    """
    normalized_market = market.strip().lower() if isinstance(market, str) else market
    if normalized_market not in ('auto', 'us', 'cn'):
        raise ValueError(f"Invalid market parameter: '{market}'. Must be 'auto', 'us', or 'cn'.")

    if normalized_market == 'auto':
        candidate = symbol.strip().upper()
        # Chinese market indicators:
        # - an exchange suffix (.SS is the Yahoo-style Shanghai suffix); the
        #   CN fetcher strips everything after the first '.'
        # - numeric-only code (typically 6 digits)
        if candidate.endswith(('.SZ', '.SH', '.SS', '.BJ')):
            normalized_market = 'cn'
        elif candidate.replace('.', '').isdigit():
            normalized_market = 'cn'
        else:
            # Default to US market for alphabetic symbols
            normalized_market = 'us'

    if normalized_market == 'us':
        result = get_akshare_stock_data_us(symbol, start_date, end_date)
    else:  # 'cn'
        result = get_akshare_stock_data_cn(symbol, start_date, end_date)

    # The fetchers report rate limiting as a string starting with this prefix;
    # convert it back into an exception for vendor fallback. startswith avoids
    # matching the phrase inside unrelated error text.
    if isinstance(result, str) and result.startswith("Rate limit error"):
        raise AKShareRateLimitError(result)

    return result
|