# TradingAgents/tests/test_akshare.py (833 lines, 34 KiB, Python)

"""
Test suite for AKShare data vendor integration.

This module tests:
    1. Date format conversion helper (_convert_date_format)
    2. Exponential backoff retry mechanism (_exponential_backoff_retry)
    3. US stock data retrieval (get_akshare_stock_data_us)
    4. Chinese stock data retrieval (get_akshare_stock_data_cn)
    5. Auto-market detection (get_akshare_stock_data)
    6. AKShareRateLimitError exception handling
    7. Integration with vendor routing system (interface.py)

Test Coverage:
    - Unit tests for individual helper functions
    - Integration tests for stock data retrieval functions
    - Edge cases (empty data, network errors, rate limits)
    - Vendor fallback behavior with rate limit errors
"""
import pytest
import pandas as pd
import time
import sys
from unittest.mock import Mock, patch, MagicMock, call
from datetime import datetime
from typing import Callable, Any
# Evict any cached copy of our wrapper module and of akshare itself, then
# register a MagicMock under the name 'akshare' before the modules under test
# are imported below.
sys.modules.pop('tradingagents.dataflows.akshare', None)
sys.modules.pop('akshare', None)
sys.modules['akshare'] = mock_akshare = MagicMock()
# Import modules under test
from tradingagents.dataflows.akshare import (
AKShareRateLimitError,
_convert_date_format,
_exponential_backoff_retry,
get_akshare_stock_data_us,
get_akshare_stock_data_cn,
get_akshare_stock_data,
)
from tradingagents.dataflows.interface import route_to_vendor, VENDOR_METHODS
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def sample_us_dataframe():
    """Return a five-row daily OHLCV frame with lower-case English columns.

    Dates run from 2024-01-01 to 2024-01-05; prices rise by 1.0 per day and
    volume by 100000 per day. The tests use it as the US-market payload.
    """
    day_offsets = range(5)
    columns = {
        'date': pd.date_range('2024-01-01', periods=5, freq='D'),
        'open': [150.0 + step for step in day_offsets],
        'high': [152.0 + step for step in day_offsets],
        'low': [149.0 + step for step in day_offsets],
        'close': [151.0 + step for step in day_offsets],
        'volume': [1000000 + 100000 * step for step in day_offsets],
    }
    return pd.DataFrame(columns)
@pytest.fixture
def sample_cn_dataframe():
    """Return a five-row daily OHLCV frame keyed by Chinese column names.

    Dates are plain ``YYYY-MM-DD`` strings for 2024-01-01 through 2024-01-05.
    The tests use it as the Chinese-market payload.
    """
    trading_days = [f'2024-01-0{day}' for day in range(1, 6)]
    columns = {
        '日期': trading_days,
        '开盘': [10.0, 10.1, 10.2, 10.3, 10.4],
        '最高': [10.2, 10.3, 10.4, 10.5, 10.6],
        '最低': [9.9, 10.0, 10.1, 10.2, 10.3],
        '收盘': [10.1, 10.2, 10.3, 10.4, 10.5],
        '成交量': [500000, 550000, 600000, 650000, 700000],
    }
    return pd.DataFrame(columns)
@pytest.fixture
def sample_standardized_dataframe():
    """Return a five-row daily OHLCV frame with capitalised English columns.

    Same values as the US sample frame, but with ``Date``/``Open``/``High``/
    ``Low``/``Close``/``Volume`` as the column names.
    """
    day_offsets = range(5)
    columns = {
        'Date': pd.date_range('2024-01-01', periods=5, freq='D'),
        'Open': [150.0 + step for step in day_offsets],
        'High': [152.0 + step for step in day_offsets],
        'Low': [149.0 + step for step in day_offsets],
        'Close': [151.0 + step for step in day_offsets],
        'Volume': [1000000 + 100000 * step for step in day_offsets],
    }
    return pd.DataFrame(columns)
@pytest.fixture
def mock_akshare():
    """Patch the ``ak`` name inside the akshare wrapper module and yield the mock.

    The patch is undone when the requesting test finishes.
    """
    patcher = patch('tradingagents.dataflows.akshare.ak')
    mock_ak = patcher.start()
    try:
        yield mock_ak
    finally:
        patcher.stop()
@pytest.fixture
def mock_time_sleep():
    """Replace ``time.sleep`` (reached via the wrapper module) with a mock.

    Yields the mock so retry tests can inspect the requested delays without
    actually waiting; the patch is undone when the test finishes.
    """
    patcher = patch('tradingagents.dataflows.akshare.time.sleep')
    mock_sleep = patcher.start()
    try:
        yield mock_sleep
    finally:
        patcher.stop()
# ============================================================================
# Test Date Format Conversion
# ============================================================================
class TestConvertDateFormat:
    """Behavioural checks for the ``_convert_date_format`` helper."""

    def test_standard_date_format_with_hyphen(self):
        """A hyphenated YYYY-MM-DD date comes back as YYYYMMDD."""
        assert _convert_date_format("2024-01-15") == "20240115"

    def test_standard_date_format_with_single_digits(self):
        """Pin the output for a date whose month and day are not zero-padded."""
        # NOTE(review): the expected value is "202415", not the zero-padded
        # "20240105" — confirm this is the intended contract of the helper.
        assert _convert_date_format("2024-1-5") == "202415"

    def test_handles_slash_separator(self):
        """A slash-separated YYYY/MM/DD date comes back as YYYYMMDD."""
        assert _convert_date_format("2024/01/15") == "20240115"

    def test_preserves_yyyymmdd_format(self):
        """Input that is already YYYYMMDD is returned unchanged."""
        assert _convert_date_format("20240115") == "20240115"

    def test_handles_various_date_formats(self):
        """Several hyphenated dates each map to their YYYYMMDD form."""
        expected_by_input = {
            "2024-12-31": "20241231",
            "2024-01-01": "20240101",
            "2023-06-15": "20230615",
        }
        for raw_date, converted in expected_by_input.items():
            assert _convert_date_format(raw_date) == converted

    def test_empty_string_raises_error(self):
        """An empty string is rejected with ValueError or IndexError."""
        rejected = (ValueError, IndexError)
        with pytest.raises(rejected):
            _convert_date_format("")

    def test_invalid_format_raises_error(self):
        """A non-date string is rejected with ValueError or IndexError."""
        rejected = (ValueError, IndexError)
        with pytest.raises(rejected):
            _convert_date_format("not-a-date")
# ============================================================================
# Test Exponential Backoff Retry
# ============================================================================
class TestExponentialBackoffRetry:
    """Test the _exponential_backoff_retry helper function.

    Every test requests the ``mock_time_sleep`` fixture so that no test can
    block on a real backoff delay, even if the helper regresses.
    """

    def test_returns_on_first_success(self, mock_time_sleep):
        """Test that successful function returns immediately without retries."""
        mock_func = Mock(return_value="success")
        result = _exponential_backoff_retry(mock_func, max_retries=3)
        assert result == "success"
        assert mock_func.call_count == 1
        assert mock_time_sleep.call_count == 0

    def test_retries_on_failure(self, mock_time_sleep):
        """Test that function retries on failure up to max_retries."""
        mock_func = Mock(side_effect=[
            Exception("First failure"),
            Exception("Second failure"),
            "success",
        ])
        result = _exponential_backoff_retry(mock_func, max_retries=3)
        assert result == "success"
        assert mock_func.call_count == 3
        # Should sleep after 1st and 2nd failures
        assert mock_time_sleep.call_count == 2

    def test_exponential_delay(self, mock_time_sleep):
        """Test that delays increase exponentially."""
        mock_func = Mock(side_effect=[
            Exception("Failure 1"),
            Exception("Failure 2"),
            "success",
        ])
        _exponential_backoff_retry(mock_func, max_retries=3)
        # Verify exponential backoff: 2^0 = 1 second, then 2^1 = 2 seconds
        assert mock_time_sleep.call_args_list == [call(1), call(2)]

    def test_raises_after_max_retries(self, mock_time_sleep):
        """Test that original error is raised after exhausting retries."""
        error_msg = "Persistent failure"
        mock_func = Mock(side_effect=Exception(error_msg))
        with pytest.raises(Exception, match=error_msg):
            _exponential_backoff_retry(mock_func, max_retries=3)
        # Checked outside the `with` block: statements after the raising call
        # inside it would never run.
        assert mock_func.call_count == 4  # Initial + 3 retries
        assert mock_time_sleep.call_count == 3

    def test_raises_rate_limit_error(self, mock_time_sleep):
        """Test that rate limit errors are raised as AKShareRateLimitError."""
        mock_func = Mock(side_effect=Exception("Rate limit exceeded"))
        with pytest.raises(AKShareRateLimitError):
            _exponential_backoff_retry(mock_func, max_retries=2)

    def test_handles_timeout_errors(self, mock_time_sleep):
        """Test handling of timeout errors."""
        from requests.exceptions import Timeout
        mock_func = Mock(side_effect=[Timeout("Network timeout"), "success"])
        result = _exponential_backoff_retry(mock_func, max_retries=3)
        assert result == "success"
        assert mock_func.call_count == 2

    def test_max_retries_zero(self, mock_time_sleep):
        """Test that max_retries=0 means one attempt and no backoff sleep."""
        mock_func = Mock(side_effect=Exception("Failure"))
        with pytest.raises(Exception):
            _exponential_backoff_retry(mock_func, max_retries=0)
        assert mock_func.call_count == 1  # Only initial call, no retries
        # No retries means no backoff; sleep is patched so a regression fails
        # here instead of stalling the test run.
        assert mock_time_sleep.call_count == 0

    def test_preserves_function_arguments(self, mock_time_sleep):
        """Test that function arguments are preserved across retries."""
        mock_func = Mock(side_effect=[Exception("Fail"), "success"])
        result = _exponential_backoff_retry(
            lambda: mock_func("arg1", kwarg1="value1"),
            max_retries=2,
        )
        assert result == "success"
        # One failed attempt plus one successful retry, both with the same
        # arguments. Comparing the full list also pins the number of calls.
        assert mock_func.call_args_list == [call("arg1", kwarg1="value1")] * 2
# ============================================================================
# Test US Stock Data Retrieval
# ============================================================================
class TestGetAkshareStockDataUs:
    """Test the get_akshare_stock_data_us function.

    Tests that make the mocked API raise also request ``mock_time_sleep``:
    the function retries failed calls with a backoff sleep, and without the
    patch those tests would really wait out the delays.
    """

    def test_returns_dataframe_on_success(self, mock_akshare, sample_us_dataframe):
        """Test successful data retrieval returns the data as a CSV string."""
        mock_akshare.stock_us_hist.return_value = sample_us_dataframe
        result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-05")
        assert isinstance(result, str)  # Returns CSV string, not a DataFrame
        assert "AAPL" in result
        assert "2024-01-01" in result
        mock_akshare.stock_us_hist.assert_called_once_with(
            symbol="AAPL",
            period="daily",
            adjust="",
        )

    def test_filters_data_by_date_range(self, mock_akshare):
        """Test that data is properly filtered by date range."""
        # Create DataFrame with wider date range than the one requested
        full_df = pd.DataFrame({
            'date': pd.to_datetime(['2023-12-28', '2024-01-02', '2024-01-03', '2024-01-04', '2024-01-08']),
            'open': [145.0, 150.0, 151.0, 152.0, 157.0],
            'high': [147.0, 152.0, 153.0, 154.0, 159.0],
            'low': [144.0, 149.0, 150.0, 151.0, 156.0],
            'close': [146.0, 151.0, 152.0, 153.0, 158.0],
            'volume': [900000, 1000000, 1100000, 1200000, 1500000],
        })
        mock_akshare.stock_us_hist.return_value = full_df
        result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-05")
        # Rows outside the requested range must be gone ...
        assert "2023-12-28" not in result
        assert "2024-01-08" not in result
        # ... and every row inside it must survive, not just one of them.
        for in_range_day in ("2024-01-02", "2024-01-03", "2024-01-04"):
            assert in_range_day in result

    def test_returns_error_string_on_failure(self, mock_akshare, mock_time_sleep):
        """Test that exceptions return error string instead of raising."""
        mock_akshare.stock_us_hist.side_effect = Exception("API error")
        result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-05")
        assert isinstance(result, str)
        assert "error" in result.lower() or "failed" in result.lower()

    def test_handles_empty_data(self, mock_akshare):
        """Test handling of empty DataFrame from API."""
        mock_akshare.stock_us_hist.return_value = pd.DataFrame()
        result = get_akshare_stock_data_us("INVALID", "2024-01-01", "2024-01-05")
        assert isinstance(result, str)
        assert "no data" in result.lower() or "empty" in result.lower()

    def test_handles_network_timeout(self, mock_akshare, mock_time_sleep):
        """Test handling of network timeout errors."""
        from requests.exceptions import Timeout
        mock_akshare.stock_us_hist.side_effect = Timeout("Connection timeout")
        result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-05")
        assert isinstance(result, str)
        assert "timeout" in result.lower() or "error" in result.lower()

    def test_standardizes_output_format(self, mock_akshare, sample_us_dataframe):
        """Test that output format matches expected CSV structure."""
        mock_akshare.stock_us_hist.return_value = sample_us_dataframe
        result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-05")
        # Should contain CSV header information
        assert "Stock data for" in result or "AAPL" in result
        # Should contain column headers
        lines = result.split('\n')
        assert len(lines) > 1  # Has header + data rows

    def test_handles_malformed_dates(self, mock_akshare):
        """Test that an invalid start date yields a string rather than raising."""
        result = get_akshare_stock_data_us("AAPL", "invalid-date", "2024-01-05")
        assert isinstance(result, str)

    def test_symbol_case_handling(self, mock_akshare, sample_us_dataframe):
        """Test that a lower-case symbol reaches the API upper-cased or unchanged."""
        mock_akshare.stock_us_hist.return_value = sample_us_dataframe
        get_akshare_stock_data_us("aapl", "2024-01-01", "2024-01-05")
        mock_akshare.stock_us_hist.assert_called_once()
        passed_symbol = mock_akshare.stock_us_hist.call_args[1]['symbol']
        assert passed_symbol in ("AAPL", "aapl")
# ============================================================================
# Test Chinese Stock Data Retrieval
# ============================================================================
class TestGetAkshareStockDataCn:
    """Test the get_akshare_stock_data_cn function.

    Tests that make the mocked API raise also request ``mock_time_sleep`` so
    that retry backoff never turns into a real wait.
    """

    def test_returns_dataframe_on_success(self, mock_akshare, sample_cn_dataframe):
        """Test successful data retrieval returns the data as a string."""
        mock_akshare.stock_zh_a_hist.return_value = sample_cn_dataframe
        result = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-05")
        assert isinstance(result, str)
        assert "000001" in result
        mock_akshare.stock_zh_a_hist.assert_called_once()

    def test_converts_date_format(self, mock_akshare, sample_cn_dataframe):
        """Test that dates are converted to YYYYMMDD format for API."""
        mock_akshare.stock_zh_a_hist.return_value = sample_cn_dataframe
        get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-05")
        api_call = mock_akshare.stock_zh_a_hist.call_args
        assert api_call is not None
        # Look at positional and keyword arguments alike, so the check does
        # not depend on how the wrapper passes the dates.
        passed_values = [*api_call[0], *api_call[1].values()]
        assert "20240101" in passed_values
        assert "20240105" in passed_values

    def test_returns_error_string_on_failure(self, mock_akshare, mock_time_sleep):
        """Test that exceptions return error string."""
        mock_akshare.stock_zh_a_hist.side_effect = Exception("API error")
        result = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-05")
        assert isinstance(result, str)
        assert "error" in result.lower() or "failed" in result.lower()

    def test_standardizes_column_names(self, mock_akshare, sample_cn_dataframe):
        """Test that Chinese column names are mapped to English."""
        mock_akshare.stock_zh_a_hist.return_value = sample_cn_dataframe
        result = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-05")
        # Output should contain English column names
        result_lower = result.lower()
        assert "date" in result_lower or "open" in result_lower or "close" in result_lower

    def test_handles_empty_data(self, mock_akshare):
        """Test handling of empty DataFrame."""
        mock_akshare.stock_zh_a_hist.return_value = pd.DataFrame()
        result = get_akshare_stock_data_cn("INVALID", "2024-01-01", "2024-01-05")
        assert isinstance(result, str)
        assert "no data" in result.lower() or "empty" in result.lower()

    def test_handles_symbol_with_suffix(self, mock_akshare, sample_cn_dataframe):
        """Test that a symbol with a .SZ suffix still yields a string result."""
        mock_akshare.stock_zh_a_hist.return_value = sample_cn_dataframe
        result = get_akshare_stock_data_cn("000001.SZ", "2024-01-01", "2024-01-05")
        assert isinstance(result, str)

    def test_handles_rate_limit_error(self, mock_akshare, mock_time_sleep):
        """Test that a Chinese rate-limit message surfaces as AKShareRateLimitError."""
        mock_akshare.stock_zh_a_hist.side_effect = Exception("访问频率过快")  # Chinese rate limit message
        # Drives the retry helper directly around the mocked API call
        with pytest.raises(AKShareRateLimitError):
            _exponential_backoff_retry(
                lambda: mock_akshare.stock_zh_a_hist(),
                max_retries=1,
            )

    def test_standardizes_output_format(self, mock_akshare, sample_cn_dataframe):
        """Test that output format is standardized."""
        mock_akshare.stock_zh_a_hist.return_value = sample_cn_dataframe
        result = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-05")
        # Should be CSV-like format: more than one line
        lines = result.split('\n')
        assert len(lines) > 1
# ============================================================================
# Test Auto-Market Detection
# ============================================================================
class TestGetAkshareStockData:
    """Test the get_akshare_stock_data function with auto-market detection.

    The market-specific functions are patched, so these tests observe only
    which one the dispatcher calls and what it hands back.
    """

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_us')
    def test_auto_detects_us_market(self, mock_us_func):
        """Test that US symbols are automatically detected."""
        mock_us_func.return_value = "US data"
        result = get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-05")
        assert result == "US data"
        mock_us_func.assert_called_once_with("AAPL", "2024-01-01", "2024-01-05")

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_cn')
    def test_auto_detects_cn_market_with_sz_suffix(self, mock_cn_func):
        """Test that Chinese symbols with .SZ suffix are detected."""
        mock_cn_func.return_value = "CN data"
        result = get_akshare_stock_data("000001.SZ", "2024-01-01", "2024-01-05")
        assert result == "CN data"
        mock_cn_func.assert_called_once_with("000001.SZ", "2024-01-01", "2024-01-05")

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_cn')
    def test_auto_detects_cn_market_with_sh_suffix(self, mock_cn_func):
        """Test that Chinese symbols with .SH suffix are detected."""
        mock_cn_func.return_value = "CN data"
        result = get_akshare_stock_data("600000.SH", "2024-01-01", "2024-01-05")
        assert result == "CN data"
        mock_cn_func.assert_called_once_with("600000.SH", "2024-01-01", "2024-01-05")

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_cn')
    def test_auto_detects_cn_market_numeric_only(self, mock_cn_func):
        """Test that numeric-only symbols default to Chinese market."""
        mock_cn_func.return_value = "CN data"
        result = get_akshare_stock_data("000001", "2024-01-01", "2024-01-05")
        assert result == "CN data"
        mock_cn_func.assert_called_once_with("000001", "2024-01-01", "2024-01-05")

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_us')
    def test_explicit_market_us(self, mock_us_func):
        """Test that market='us' forces US function."""
        mock_us_func.return_value = "US data"
        result = get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-05", market="us")
        assert result == "US data"
        mock_us_func.assert_called_once()

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_cn')
    def test_explicit_market_cn(self, mock_cn_func):
        """Test that market='cn' forces Chinese function."""
        mock_cn_func.return_value = "CN data"
        result = get_akshare_stock_data("000001", "2024-01-01", "2024-01-05", market="cn")
        assert result == "CN data"
        mock_cn_func.assert_called_once()

    def test_returns_standardized_dataframe(self):
        """Test that output has consistent schema regardless of market."""
        with patch('tradingagents.dataflows.akshare.get_akshare_stock_data_us') as mock_us:
            mock_us.return_value = "Date,Open,High,Low,Close,Volume\n2024-01-01,150,152,149,151,1000000"
            result_us = get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-05")
        with patch('tradingagents.dataflows.akshare.get_akshare_stock_data_cn') as mock_cn:
            mock_cn.return_value = "Date,Open,High,Low,Close,Volume\n2024-01-01,10,10.2,9.9,10.1,500000"
            result_cn = get_akshare_stock_data("000001", "2024-01-01", "2024-01-05")
        # Both should have similar structure
        assert isinstance(result_us, str)
        assert isinstance(result_cn, str)
        # Both stubs return the same header row, so the dispatcher must hand
        # back matching headers for either market.
        assert result_us.splitlines()[0] == result_cn.splitlines()[0]

    def test_invalid_market_parameter(self):
        """Test handling of invalid market parameter."""
        with pytest.raises(ValueError):
            get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-05", market="invalid")

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_us')
    def test_market_auto_default_behavior(self, mock_us_func):
        """Test that market='auto' is the default behavior."""
        mock_us_func.return_value = "US data"
        # Call without market parameter
        result1 = get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-05")
        # Call with explicit market='auto'
        result2 = get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-05", market="auto")
        assert result1 == result2
        # Equal results alone do not prove routing; both calls must have
        # reached the patched US function.
        assert mock_us_func.call_count == 2
# ============================================================================
# Test AKShareRateLimitError Exception
# ============================================================================
class TestAKShareRateLimitError:
    """Checks on the AKShareRateLimitError exception type itself."""

    def test_is_exception_subclass(self):
        """The error type derives from the built-in Exception."""
        assert issubclass(AKShareRateLimitError, Exception)

    def test_can_be_raised_and_caught(self):
        """Raising the error is caught by a handler for its own type."""
        with pytest.raises(AKShareRateLimitError):
            raise AKShareRateLimitError("Rate limit exceeded")

    def test_message_included(self):
        """The message passed at construction is what str() returns."""
        message = "API rate limit exceeded: 5 calls per minute"
        with pytest.raises(AKShareRateLimitError) as caught:
            raise AKShareRateLimitError(message)
        assert str(caught.value) == message

    def test_can_be_caught_as_generic_exception(self):
        """A handler for plain Exception also catches the error."""
        with pytest.raises(Exception):
            raise AKShareRateLimitError("Rate limit")

    def test_distinct_from_other_exceptions(self):
        """A ValueError handler must not catch the error."""
        caught_as_value_error = False
        try:
            raise AKShareRateLimitError("Rate limit")
        except ValueError:
            caught_as_value_error = True
        except AKShareRateLimitError:
            pass  # Expected path
        if caught_as_value_error:
            pytest.fail("Should not be caught as ValueError")
# ============================================================================
# Test Vendor Integration (interface.py modifications)
# ============================================================================
class TestVendorIntegration:
    """Test integration with the vendor routing system in interface.py.

    The routing tests install stub callables directly into ``VENDOR_METHODS``.
    Patching ``tradingagents.dataflows.akshare.get_akshare_stock_data`` only
    rebinds the module attribute; the registry already holds the original
    function object (see ``test_akshare_vendor_function_mapping``), so a
    module-level patch is never seen by a lookup through the registry.

    NOTE(review): ``get_config`` is still patched on
    ``tradingagents.dataflows.config``. If interface.py binds it with
    ``from .config import get_config``, the target must instead be
    ``tradingagents.dataflows.interface.get_config`` — confirm against
    interface.py, along with the ``"data_vendor"`` key used in the fake config.
    """

    @staticmethod
    def _vendor_stub(name, **mock_kwargs):
        """Build a Mock vendor implementation that carries a ``__name__``."""
        stub = Mock(**mock_kwargs)
        # Mock objects have no __name__ by default; real vendor functions do.
        stub.__name__ = name
        return stub

    def test_akshare_in_vendor_methods(self):
        """Test that akshare is registered in VENDOR_METHODS."""
        assert "get_stock_data" in VENDOR_METHODS
        assert "akshare" in VENDOR_METHODS["get_stock_data"]

    def test_akshare_vendor_function_mapping(self):
        """Test that akshare maps to correct function."""
        from tradingagents.dataflows.akshare import get_akshare_stock_data
        akshare_impl = VENDOR_METHODS["get_stock_data"]["akshare"]
        assert akshare_impl == get_akshare_stock_data

    @patch('tradingagents.dataflows.config.get_config')
    def test_route_to_vendor_uses_akshare(self, mock_config):
        """Test that route_to_vendor calls akshare when configured."""
        mock_config.return_value = {"data_vendor": "akshare"}
        akshare_stub = self._vendor_stub(
            "get_akshare_stock_data", return_value="AKShare data"
        )
        with patch.dict(VENDOR_METHODS["get_stock_data"], {"akshare": akshare_stub}):
            result = route_to_vendor("get_stock_data", "AAPL", "2024-01-01", "2024-01-05")
        assert result == "AKShare data"
        akshare_stub.assert_called_once_with("AAPL", "2024-01-01", "2024-01-05")

    @patch('tradingagents.dataflows.config.get_config')
    def test_fallback_on_rate_limit(self, mock_config):
        """Test that AKShareRateLimitError triggers fallback to next vendor."""
        mock_config.return_value = {"data_vendor": "akshare,yfinance"}
        akshare_stub = self._vendor_stub(
            "get_akshare_stock_data",
            side_effect=AKShareRateLimitError("Rate limit exceeded"),
        )
        yfinance_stub = self._vendor_stub(
            "get_YFin_data_online", return_value="YFinance data"
        )
        stubs = {"akshare": akshare_stub, "yfinance": yfinance_stub}
        with patch.dict(VENDOR_METHODS["get_stock_data"], stubs):
            result = route_to_vendor("get_stock_data", "AAPL", "2024-01-01", "2024-01-05")
        assert result == "YFinance data"
        akshare_stub.assert_called_once()
        yfinance_stub.assert_called_once()

    @patch('tradingagents.dataflows.config.get_config')
    def test_fallback_chain_akshare_yfinance(self, mock_config):
        """Test multi-vendor fallback chain works correctly."""
        mock_config.return_value = {"data_vendor": "akshare,yfinance,local"}
        akshare_stub = self._vendor_stub(
            "get_akshare_stock_data",
            side_effect=AKShareRateLimitError("Rate limit"),
        )
        yfinance_stub = self._vendor_stub(
            "get_YFin_data_online", return_value="YFinance fallback data"
        )
        stubs = {"akshare": akshare_stub, "yfinance": yfinance_stub}
        with patch.dict(VENDOR_METHODS["get_stock_data"], stubs):
            result = route_to_vendor("get_stock_data", "AAPL", "2024-01-01", "2024-01-05")
        assert "YFinance" in result
        # Verify akshare was tried first, then yfinance succeeded
        assert akshare_stub.call_count == 1
        assert yfinance_stub.call_count == 1

    @patch('tradingagents.dataflows.config.get_config')
    def test_akshare_error_string_not_triggers_fallback(self, mock_config):
        """Test that error strings (not exceptions) don't trigger fallback."""
        mock_config.return_value = {"data_vendor": "akshare"}
        # Return error string, not exception
        akshare_stub = self._vendor_stub(
            "get_akshare_stock_data", return_value="Error: No data found"
        )
        with patch.dict(VENDOR_METHODS["get_stock_data"], {"akshare": akshare_stub}):
            result = route_to_vendor("get_stock_data", "INVALID", "2024-01-01", "2024-01-05")
        # Should return the error string, not attempt fallback
        assert "Error" in result
        assert akshare_stub.call_count == 1

    def test_akshare_in_vendor_list(self):
        """Test that akshare is in the global VENDOR_LIST."""
        from tradingagents.dataflows.interface import VENDOR_LIST
        assert "akshare" in VENDOR_LIST
# ============================================================================
# Integration Tests
# ============================================================================
class TestAKShareIntegration:
    """End-to-end style tests that run the wrapper against a patched ``ak``."""

    @staticmethod
    def _us_frame(rows):
        """Build the first ``rows`` days of the US sample data (at most 3)."""
        return pd.DataFrame({
            'date': pd.date_range('2024-01-01', periods=rows, freq='D'),
            'open': [150.0, 151.0, 152.0][:rows],
            'high': [152.0, 153.0, 154.0][:rows],
            'low': [149.0, 150.0, 151.0][:rows],
            'close': [151.0, 152.0, 153.0][:rows],
            'volume': [1000000, 1100000, 1200000][:rows],
        })

    @staticmethod
    def _cn_frame(rows):
        """Build the first ``rows`` days of the Chinese sample data (at most 3)."""
        return pd.DataFrame({
            '日期': ['2024-01-01', '2024-01-02', '2024-01-03'][:rows],
            '开盘': [10.0, 10.1, 10.2][:rows],
            '最高': [10.2, 10.3, 10.4][:rows],
            '最低': [9.9, 10.0, 10.1][:rows],
            '收盘': [10.1, 10.2, 10.3][:rows],
            '成交量': [500000, 550000, 600000][:rows],
        })

    @patch('tradingagents.dataflows.akshare.ak')
    def test_end_to_end_us_stock_retrieval(self, mock_ak):
        """US path: dispatcher with market='us' returns a string with the data."""
        mock_ak.stock_us_hist.return_value = self._us_frame(3)
        result = get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-03", market="us")
        assert isinstance(result, str)
        assert "AAPL" in result or "150" in result

    @patch('tradingagents.dataflows.akshare.ak')
    def test_end_to_end_cn_stock_retrieval(self, mock_ak):
        """CN path: dispatcher with market='cn' returns a string with the data."""
        mock_ak.stock_zh_a_hist.return_value = self._cn_frame(3)
        result = get_akshare_stock_data("000001", "2024-01-01", "2024-01-03", market="cn")
        assert isinstance(result, str)
        assert "000001" in result or "10" in result

    @patch('tradingagents.dataflows.akshare.ak')
    @patch('tradingagents.dataflows.akshare.time.sleep')
    def test_retry_mechanism_with_transient_failure(self, mock_sleep, mock_ak):
        """One failed API call followed by a good one succeeds after one sleep."""
        recovered_frame = self._us_frame(2)
        mock_ak.stock_us_hist.side_effect = [
            Exception("Transient network error"),
            recovered_frame,
        ]
        result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-02")
        assert isinstance(result, str)
        # Two API attempts, separated by exactly one backoff delay
        assert mock_ak.stock_us_hist.call_count == 2
        assert mock_sleep.call_count == 1

    def test_column_name_standardization(self):
        """Both markets' outputs mention at least one standard English column."""
        standard_columns = ('date', 'open', 'high', 'low', 'close', 'volume')
        with patch('tradingagents.dataflows.akshare.ak') as mock_ak:
            mock_ak.stock_us_hist.return_value = self._us_frame(2)
            us_result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-02")
            mock_ak.stock_zh_a_hist.return_value = self._cn_frame(2)
            cn_result = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-02")
            for output in (us_result, cn_result):
                lowered = output.lower()
                assert any(column in lowered for column in standard_columns)
# ============================================================================
# Edge Cases and Error Handling
# ============================================================================
class TestEdgeCases:
    """Test edge cases and error conditions."""

    def test_empty_symbol(self):
        """Test that an empty symbol string yields a string result."""
        result = get_akshare_stock_data("", "2024-01-01", "2024-01-05")
        assert isinstance(result, str)

    def test_future_dates(self):
        """Test handling of future dates."""
        with patch('tradingagents.dataflows.akshare.ak') as mock_ak:
            mock_ak.stock_us_hist.return_value = pd.DataFrame()
            result = get_akshare_stock_data_us("AAPL", "2030-01-01", "2030-01-05")
        assert isinstance(result, str)

    def test_start_date_after_end_date(self):
        """Test handling when start_date > end_date."""
        result = get_akshare_stock_data("AAPL", "2024-01-05", "2024-01-01")
        assert isinstance(result, str)

    def test_very_long_date_range(self):
        """Test handling of very long date ranges (years of data)."""
        with patch('tradingagents.dataflows.akshare.ak') as mock_ak:
            # Simulate large dataset: 1000 identical daily rows
            large_df = pd.DataFrame({
                'date': pd.date_range('2020-01-01', periods=1000, freq='D'),
                'open': [150.0] * 1000,
                'high': [152.0] * 1000,
                'low': [149.0] * 1000,
                'close': [151.0] * 1000,
                'volume': [1000000] * 1000,
            })
            mock_ak.stock_us_hist.return_value = large_df
            result = get_akshare_stock_data_us("AAPL", "2020-01-01", "2024-01-01")
        assert isinstance(result, str)

    def test_special_characters_in_symbol(self):
        """Test handling of symbols with special characters."""
        with patch('tradingagents.dataflows.akshare.ak') as mock_ak:
            mock_ak.stock_us_hist.return_value = pd.DataFrame()
            result = get_akshare_stock_data_us("BRK.B", "2024-01-01", "2024-01-05")
        assert isinstance(result, str)

    # time.sleep is patched as well: the API mock always raises, and the US
    # wrapper is shown elsewhere in this file to sleep between retries.
    @patch('tradingagents.dataflows.akshare.time.sleep')
    @patch('tradingagents.dataflows.akshare.ak')
    def test_unicode_in_error_messages(self, mock_ak, mock_sleep):
        """Test handling of Unicode characters in error messages."""
        # NOTE(review): another test in this file treats "访问频率过快" as a
        # rate-limit message that surfaces as AKShareRateLimitError from the
        # retry helper; confirm the CN wrapper is meant to turn it into a
        # string here rather than let it propagate.
        mock_ak.stock_zh_a_hist.side_effect = Exception("访问频率过快,请稍后重试")
        result = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-05")
        assert isinstance(result, str)

    def test_none_parameters(self):
        """Test that a None symbol raises TypeError or AttributeError."""
        with pytest.raises((TypeError, AttributeError)):
            get_akshare_stock_data(None, "2024-01-01", "2024-01-05")