TradingAgents/tradingagents/dataflows/akshare.py

392 lines
14 KiB
Python

"""
AKShare data vendor integration for stock data retrieval.
This module provides access to both US and Chinese stock market data via AKShare library.
Includes retry mechanisms, rate limit handling, and automatic market detection.
Usage:
US Stock Data:
>>> from tradingagents.dataflows.akshare import get_akshare_stock_data_us
>>> data = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-12-31")
Chinese Stock Data:
>>> from tradingagents.dataflows.akshare import get_akshare_stock_data_cn
>>> data = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-12-31")
Auto-Detection (Recommended):
>>> from tradingagents.dataflows.akshare import get_akshare_stock_data
>>> us_data = get_akshare_stock_data("AAPL", "2024-01-01", "2024-12-31") # Auto-detects US
>>> cn_data = get_akshare_stock_data("000001", "2024-01-01", "2024-12-31") # Auto-detects China
Requirements:
- akshare package: pip install akshare
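- Retries transient failures with exponential backoff; rate-limit errors are not retried and are raised as AKShareRateLimitError by get_akshare_stock_data so callers can fall back to another vendor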
- Returns CSV string format for integration with other data processing tools
"""
import time
from typing import Annotated
import pandas as pd
from datetime import datetime
try:
import akshare as ak
AKSHARE_AVAILABLE = True
except ImportError:
ak = None
AKSHARE_AVAILABLE = False
# ============================================================================
# Custom Exceptions
# ============================================================================
class AKShareRateLimitError(Exception):
    """
    Signal that AKShare rejected a request because of rate limiting.

    ``_exponential_backoff_retry`` raises this instead of retrying when an
    error message contains a rate-limit indicator, and
    ``get_akshare_stock_data`` raises it when a fetcher reports a rate
    limit, so a caller can fall back to another data vendor.
    """
# ============================================================================
# Helper Functions
# ============================================================================
def _convert_date_format(date_str: str) -> str:
"""
Convert date string from YYYY-MM-DD or YYYY/MM/DD format to YYYYMMDD format.
Args:
date_str: Date string in format like "2024-01-15" or "2024/01/15"
Returns:
Date string in YYYYMMDD format like "20240115"
Raises:
ValueError: If date format is invalid
IndexError: If date string is empty or malformed
"""
if not date_str:
raise ValueError("Date string cannot be empty")
# If already in YYYYMMDD format (8 digits, no separators), return as-is
if len(date_str) == 8 and date_str.isdigit():
return date_str
# Check if it contains separators
if '-' in date_str or '/' in date_str:
# Simply remove separators (preserves single-digit months/days as-is)
result = date_str.replace('-', '').replace('/', '')
# Validate it's not empty and contains only digits
if not result or not result.isdigit():
raise ValueError(f"Invalid date format: {date_str}. Expected YYYY-MM-DD format.")
return result
else:
# No separators, return as-is if it looks like a number
if not date_str.isdigit():
raise ValueError(f"Invalid date format: {date_str}. Expected YYYY-MM-DD format.")
return date_str
def _exponential_backoff_retry(func, max_retries: int = 3, base_delay: float = 1.0):
"""
Execute function with exponential backoff retry on failure.
Args:
func: Callable function to retry
max_retries: Maximum number of retries (default: 3)
base_delay: Base delay in seconds for exponential backoff (default: 1.0)
Returns:
Result from successful function call
Raises:
AKShareRateLimitError: If rate limit error detected
Exception: Original exception after exhausting all retries
"""
for attempt in range(max_retries + 1): # +1 for initial attempt
try:
return func()
except Exception as e:
error_msg = str(e).lower()
# Check for rate limit indicators
if any(indicator in error_msg for indicator in [
'rate limit', 'too many requests', 'rate_limit', 'ratelimit', '频率过快'
]):
raise AKShareRateLimitError(f"AKShare rate limit exceeded: {e}")
# If this was the last attempt, raise the original exception
if attempt >= max_retries:
raise
# Exponential backoff: 2^attempt seconds
delay = base_delay * (2 ** attempt)
time.sleep(delay)
# Should never reach here, but just in case
raise Exception("Retry logic failed unexpectedly")
# ============================================================================
# US Stock Data Functions
# ============================================================================
def get_akshare_stock_data_us(
    symbol: Annotated[str, "ticker symbol of the company"],
    start_date: Annotated[str, "Start date in YYYY-MM-DD format"],
    end_date: Annotated[str, "End date in YYYY-MM-DD format"],
) -> str:
    """
    Retrieve daily US stock OHLCV data from AKShare as a CSV string.

    The symbol is upper-cased and passed to ``ak.stock_us_hist``. The
    returned history is filtered client-side to the inclusive
    [start_date, end_date] range.

    Column names are normalized to Date/Open/High/Low/Close/Volume. Both
    lowercase English names (``date``, ``open``, ...) and Chinese names
    (``日期``, ``开盘``, ...) are recognized.

    Args:
        symbol: Stock ticker symbol (e.g., "AAPL")
        start_date: Start date in YYYY-MM-DD format (inclusive)
        end_date: End date in YYYY-MM-DD format (inclusive)

    Returns:
        CSV string indexed by Date, preceded by '#' header lines. On failure
        a message string is returned instead: a "No data found ..." notice,
        "Rate limit error for ..." when AKShare throttles the request, or
        "Error retrieving US stock data for ..." for any other error.
    """
    if not AKSHARE_AVAILABLE:
        return "Error: akshare package is not installed. Install with: pip install akshare"

    try:
        # Validate dates; a ValueError is reported by the generic handler below.
        datetime.strptime(start_date, "%Y-%m-%d")
        datetime.strptime(end_date, "%Y-%m-%d")

        # Ensure symbol is uppercase
        symbol = symbol.upper()

        # Fetch data with retry mechanism
        def fetch_data():
            return ak.stock_us_hist(
                symbol=symbol,
                period="daily",
                adjust=""
            )

        data = _exponential_backoff_retry(fetch_data, max_retries=3)

        if data is None or data.empty:
            return f"No data found for symbol '{symbol}' between {start_date} and {end_date}"

        # Normalize column names; both English and Chinese names map to the
        # same standard OHLCV headers.
        column_mapping = {
            'date': 'Date', '日期': 'Date',
            'open': 'Open', '开盘': 'Open',
            'high': 'High', '最高': 'High',
            'low': 'Low', '最低': 'Low',
            'close': 'Close', '收盘': 'Close',
            'volume': 'Volume', '成交量': 'Volume',
        }
        data = data.rename(columns=column_mapping)

        if 'Date' not in data.columns:
            raise ValueError(
                f"AKShare response has no date column (columns: {list(data.columns)})"
            )
        data['Date'] = pd.to_datetime(data['Date'])

        # Filter by date range (the request is not date-bounded)
        start_dt = pd.to_datetime(start_date)
        end_dt = pd.to_datetime(end_date)
        data = data[(data['Date'] >= start_dt) & (data['Date'] <= end_dt)]

        if data.empty:
            return f"No data found for symbol '{symbol}' between {start_date} and {end_date}"

        # Date index for cleaner CSV output; keep only the OHLCV columns present.
        data = data.set_index('Date')
        ohlcv_columns = ['Open', 'High', 'Low', 'Close', 'Volume']
        data = data[[col for col in ohlcv_columns if col in data.columns]]

        # DataFrame.round returns a new frame, avoiding chained assignment
        # (SettingWithCopyWarning) on the filtered slice.
        data = data.round(
            {col: 2 for col in ('Open', 'High', 'Low', 'Close') if col in data.columns}
        )

        csv_string = data.to_csv()

        header = f"# Stock data for {symbol} from {start_date} to {end_date}\n"
        header += f"# Total records: {len(data)}\n"
        header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"

        return header + csv_string

    except AKShareRateLimitError as e:
        # Return error string; unified function will detect and re-raise for vendor fallback
        return f"Rate limit error for {symbol}: {str(e)}"
    except Exception as e:
        # Return error string instead of raising (matches yfinance pattern)
        return f"Error retrieving US stock data for {symbol}: {str(e)}"
# ============================================================================
# Chinese Stock Data Functions
# ============================================================================
def get_akshare_stock_data_cn(
    symbol: Annotated[str, "ticker symbol of the company"],
    start_date: Annotated[str, "Start date in YYYY-MM-DD format"],
    end_date: Annotated[str, "End date in YYYY-MM-DD format"],
) -> str:
    """
    Retrieve daily Chinese A-share OHLCV data from AKShare as a CSV string.

    Everything from the first '.' onwards (e.g. an exchange suffix) is
    dropped before calling ``ak.stock_zh_a_hist``, so "000001.SZ" and
    "000001" are queried identically. Dates are passed to AKShare in
    YYYYMMDD form. The result is then filtered again client-side to the
    inclusive date range.

    Column names are normalized to Date/Open/High/Low/Close/Volume. Both
    Chinese names (``日期``, ``开盘``, ...) and lowercase English names
    (``date``, ``open``, ...) are recognized.

    Args:
        symbol: Stock ticker symbol (e.g., "000001" or "000001.SZ")
        start_date: Start date in YYYY-MM-DD format (inclusive)
        end_date: End date in YYYY-MM-DD format (inclusive)

    Returns:
        CSV string indexed by Date, preceded by '#' header lines. On failure
        a message string is returned instead: a "No data found ..." notice,
        "Rate limit error for ..." when AKShare throttles the request, or
        "Error retrieving Chinese stock data for ..." for any other error.
    """
    if not AKSHARE_AVAILABLE:
        return "Error: akshare package is not installed. Install with: pip install akshare"

    try:
        # Validate dates; a ValueError is reported by the generic handler below.
        datetime.strptime(start_date, "%Y-%m-%d")
        datetime.strptime(end_date, "%Y-%m-%d")

        # Remove exchange suffix if present (.SZ, .SH)
        symbol_clean = symbol.split('.')[0]

        # AKShare expects YYYYMMDD dates
        start_date_formatted = _convert_date_format(start_date)
        end_date_formatted = _convert_date_format(end_date)

        # Fetch data with retry mechanism
        def fetch_data():
            return ak.stock_zh_a_hist(
                symbol=symbol_clean,
                period="daily",
                start_date=start_date_formatted,
                end_date=end_date_formatted,
                adjust=""
            )

        data = _exponential_backoff_retry(fetch_data, max_retries=3)

        if data is None or data.empty:
            return f"No data found for symbol '{symbol}' between {start_date} and {end_date}"

        # Normalize column names; both Chinese and English names map to the
        # same standard OHLCV headers.
        column_mapping = {
            '日期': 'Date', 'date': 'Date',
            '开盘': 'Open', 'open': 'Open',
            '最高': 'High', 'high': 'High',
            '最低': 'Low', 'low': 'Low',
            '收盘': 'Close', 'close': 'Close',
            '成交量': 'Volume', 'volume': 'Volume',
        }
        data = data.rename(columns=column_mapping)

        if 'Date' not in data.columns:
            raise ValueError(
                f"AKShare response has no date column (columns: {list(data.columns)})"
            )
        data['Date'] = pd.to_datetime(data['Date'])

        # Filter by date range (extra safety check on top of the API bounds)
        start_dt = pd.to_datetime(start_date)
        end_dt = pd.to_datetime(end_date)
        data = data[(data['Date'] >= start_dt) & (data['Date'] <= end_dt)]

        if data.empty:
            return f"No data found for symbol '{symbol}' between {start_date} and {end_date}"

        # Date index for cleaner CSV output; keep only the OHLCV columns present.
        data = data.set_index('Date')
        ohlcv_columns = ['Open', 'High', 'Low', 'Close', 'Volume']
        data = data[[col for col in ohlcv_columns if col in data.columns]]

        # DataFrame.round returns a new frame, avoiding chained assignment
        # (SettingWithCopyWarning) on the filtered slice.
        data = data.round(
            {col: 2 for col in ('Open', 'High', 'Low', 'Close') if col in data.columns}
        )

        csv_string = data.to_csv()

        header = f"# Stock data for {symbol} from {start_date} to {end_date}\n"
        header += f"# Total records: {len(data)}\n"
        header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"

        return header + csv_string

    except AKShareRateLimitError as e:
        # Returned as a string for direct callers; get_akshare_stock_data
        # detects this prefix and raises AKShareRateLimitError for vendor fallback.
        return f"Rate limit error for {symbol}: {str(e)}"
    except Exception as e:
        # Return error string instead of raising (matches yfinance pattern)
        return f"Error retrieving Chinese stock data for {symbol}: {str(e)}"
# ============================================================================
# Unified Interface with Auto-Market Detection
# ============================================================================
def get_akshare_stock_data(
    symbol: Annotated[str, "ticker symbol of the company"],
    start_date: Annotated[str, "Start date in YYYY-MM-DD format"],
    end_date: Annotated[str, "End date in YYYY-MM-DD format"],
    market: Annotated[str, "Market selection: 'auto', 'us', or 'cn'"] = "auto"
) -> str:
    """
    Retrieve stock data, routing to the US or Chinese AKShare fetcher.

    Surrounding whitespace is stripped from ``symbol``. With
    ``market='auto'``, a symbol is treated as a Chinese A-share if either
    condition holds:
        - it ends with an exchange suffix (.SZ, .SH, .SS or .BJ,
          case-insensitive);
        - it is all digits once dots are removed.
    Every other symbol is treated as a US ticker.

    Args:
        symbol: Stock ticker symbol (e.g. "AAPL", "000001", "600519.SS")
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        market: Market to query - 'auto' (default), 'us', or 'cn'

    Returns:
        CSV string with stock data, or an error / "no data" message string
        from the selected fetcher.

    Raises:
        ValueError: If market parameter is invalid
        AKShareRateLimitError: If the selected fetcher reported a rate limit,
            so a vendor router can fall back to another data source.
    """
    if market not in ('auto', 'us', 'cn'):
        raise ValueError(f"Invalid market parameter: '{market}'. Must be 'auto', 'us', or 'cn'.")

    # Whitespace around a ticker is never meaningful and would break both
    # suffix detection and the upstream lookup.
    symbol = symbol.strip()

    if market == 'auto':
        symbol_upper = symbol.upper()
        # Match suffixes only at the end; .SS is the Yahoo-style Shanghai suffix.
        has_cn_suffix = symbol_upper.endswith(('.SZ', '.SH', '.SS', '.BJ'))
        is_numeric_code = symbol.replace('.', '').isdigit()
        market = 'cn' if (has_cn_suffix or is_numeric_code) else 'us'

    if market == 'us':
        result = get_akshare_stock_data_us(symbol, start_date, end_date)
    else:  # market == 'cn'
        result = get_akshare_stock_data_cn(symbol, start_date, end_date)

    # The fetchers report rate limits as a string starting with this prefix;
    # convert it back into an exception so vendor fallback can trigger.
    if isinstance(result, str) and result.startswith("Rate limit error"):
        raise AKShareRateLimitError(result)

    return result