833 lines
34 KiB
Python
833 lines
34 KiB
Python
"""
|
|
Test suite for AKShare data vendor integration.
|
|
|
|
This module tests:
|
|
1. Date format conversion helper (_convert_date_format)
|
|
2. Exponential backoff retry mechanism (_exponential_backoff_retry)
|
|
3. US stock data retrieval (get_akshare_stock_data_us)
|
|
4. Chinese stock data retrieval (get_akshare_stock_data_cn)
|
|
5. Auto-market detection (get_akshare_stock_data)
|
|
6. AKShareRateLimitError exception handling
|
|
7. Integration with vendor routing system (interface.py)
|
|
|
|
Test Coverage:
|
|
- Unit tests for individual helper functions
|
|
- Integration tests for stock data retrieval functions
|
|
- Edge cases (empty data, network errors, rate limits)
|
|
- Vendor fallback behavior with rate limit errors
|
|
"""
|
|
|
|
import pytest
|
|
import pandas as pd
|
|
import time
|
|
import sys
|
|
from unittest.mock import Mock, patch, MagicMock, call
|
|
from datetime import datetime
|
|
from typing import Callable, Any
|
|
|
|
# Stub out the third-party ``akshare`` package *before* importing the module
# under test, so the suite runs without akshare installed and never reaches
# the network. Cached copies are evicted first so the module under test is
# re-imported against the stub.
sys.modules.pop('tradingagents.dataflows.akshare', None)
sys.modules.pop('akshare', None)

# Deliberately NOT named ``mock_akshare``: that name belongs to the pytest
# fixture defined below, which would otherwise silently rebind this global.
_akshare_stub = MagicMock()
sys.modules['akshare'] = _akshare_stub
|
|
|
|
# Import modules under test
|
|
from tradingagents.dataflows.akshare import (
|
|
AKShareRateLimitError,
|
|
_convert_date_format,
|
|
_exponential_backoff_retry,
|
|
get_akshare_stock_data_us,
|
|
get_akshare_stock_data_cn,
|
|
get_akshare_stock_data,
|
|
)
|
|
from tradingagents.dataflows.interface import route_to_vendor, VENDOR_METHODS
|
|
|
|
|
|
# ============================================================================
# Fixtures
# ============================================================================
|
|
|
|
@pytest.fixture
def sample_us_dataframe():
    """Return five daily OHLCV rows using lowercase column names.

    This is the layout these tests treat as akshare's US-market output.
    """
    columns = ('date', 'open', 'high', 'low', 'close', 'volume')
    values = (
        pd.date_range('2024-01-01', periods=5, freq='D'),
        [150.0, 151.0, 152.0, 153.0, 154.0],
        [152.0, 153.0, 154.0, 155.0, 156.0],
        [149.0, 150.0, 151.0, 152.0, 153.0],
        [151.0, 152.0, 153.0, 154.0, 155.0],
        [1000000, 1100000, 1200000, 1300000, 1400000],
    )
    return pd.DataFrame(dict(zip(columns, values)))
|
|
|
|
|
|
@pytest.fixture
def sample_cn_dataframe():
    """Return five daily A-share rows keyed by akshare's Chinese column names.

    Column meanings: date, open, high, low, close, volume.
    """
    header = ['日期', '开盘', '最高', '最低', '收盘', '成交量']
    rows = [
        ('2024-01-01', 10.0, 10.2, 9.9, 10.1, 500000),
        ('2024-01-02', 10.1, 10.3, 10.0, 10.2, 550000),
        ('2024-01-03', 10.2, 10.4, 10.1, 10.3, 600000),
        ('2024-01-04', 10.3, 10.5, 10.2, 10.4, 650000),
        ('2024-01-05', 10.4, 10.6, 10.3, 10.5, 700000),
    ]
    return pd.DataFrame(rows, columns=header)
|
|
|
|
|
|
@pytest.fixture
def sample_standardized_dataframe():
    """Return five daily OHLCV rows with capitalized English column names.

    NOTE(review): no test in this module appears to request this fixture;
    confirm before removing.
    """
    steps = range(5)
    opens = [150.0 + step for step in steps]
    frame = {
        'Date': pd.date_range('2024-01-01', periods=5, freq='D'),
        'Open': opens,
        'High': [price + 2.0 for price in opens],
        'Low': [price - 1.0 for price in opens],
        'Close': [price + 1.0 for price in opens],
        'Volume': [1000000 + 100000 * step for step in steps],
    }
    return pd.DataFrame(frame)
|
|
|
|
|
|
@pytest.fixture
def mock_akshare():
    """Yield a fresh MagicMock installed as ``ak`` in the akshare dataflow module.

    The original attribute is restored when the test finishes.
    """
    patcher = patch('tradingagents.dataflows.akshare.ak')
    fake_ak = patcher.start()
    try:
        yield fake_ak
    finally:
        patcher.stop()
|
|
|
|
|
|
@pytest.fixture
def mock_time_sleep():
    """Yield a MagicMock replacing ``time.sleep`` so retry tests do not wait.

    NOTE(review): the target is reached through the akshare dataflow module's
    ``time`` attribute. Assuming that is the stdlib ``time`` module, the patch
    affects every ``time.sleep`` call for the duration of the test.
    """
    patcher = patch('tradingagents.dataflows.akshare.time.sleep')
    fake_sleep = patcher.start()
    try:
        yield fake_sleep
    finally:
        patcher.stop()
|
|
|
|
|
|
# ============================================================================
# Test Date Format Conversion
# ============================================================================
|
|
|
|
class TestConvertDateFormat:
    """Test the _convert_date_format helper function."""

    def test_standard_date_format_with_hyphen(self):
        """Test conversion from YYYY-MM-DD to YYYYMMDD format."""
        assert _convert_date_format("2024-01-15") == "20240115"

    def test_standard_date_format_with_single_digits(self):
        """Test that single-digit months and days are zero-padded.

        AKShare expects a fixed-width YYYYMMDD string. Simply stripping the
        separators ("202415") would yield an invalid, ambiguous date.
        """
        assert _convert_date_format("2024-1-5") == "20240105"

    def test_handles_slash_separator(self):
        """Test conversion from YYYY/MM/DD format."""
        assert _convert_date_format("2024/01/15") == "20240115"

    def test_preserves_yyyymmdd_format(self):
        """Test that an already-correct format passes through unchanged."""
        assert _convert_date_format("20240115") == "20240115"

    @pytest.mark.parametrize(
        "input_date, expected",
        [
            ("2024-12-31", "20241231"),
            ("2024-01-01", "20240101"),
            ("2023-06-15", "20230615"),
        ],
    )
    def test_handles_various_date_formats(self, input_date, expected):
        """Test multiple valid date format variations."""
        assert _convert_date_format(input_date) == expected

    def test_empty_string_raises_error(self):
        """Test that an empty string raises an appropriate error."""
        with pytest.raises((ValueError, IndexError)):
            _convert_date_format("")

    def test_invalid_format_raises_error(self):
        """Test that an invalid format raises an appropriate error."""
        with pytest.raises((ValueError, IndexError)):
            _convert_date_format("not-a-date")
|
|
|
|
|
|
# ============================================================================
# Test Exponential Backoff Retry
# ============================================================================
|
|
|
|
class TestExponentialBackoffRetry:
    """Test the _exponential_backoff_retry helper function.

    Every test requests ``mock_time_sleep`` so no real backoff delay is ever
    incurred.
    """

    def test_returns_on_first_success(self, mock_time_sleep):
        """Test that a successful function returns immediately without retries."""
        mock_func = Mock(return_value="success")

        result = _exponential_backoff_retry(mock_func, max_retries=3)

        assert result == "success"
        assert mock_func.call_count == 1
        assert mock_time_sleep.call_count == 0

    def test_retries_on_failure(self, mock_time_sleep):
        """Test that the function is retried on failure up to max_retries."""
        mock_func = Mock(side_effect=[
            Exception("First failure"),
            Exception("Second failure"),
            "success",
        ])

        result = _exponential_backoff_retry(mock_func, max_retries=3)

        assert result == "success"
        assert mock_func.call_count == 3
        # Should sleep after the 1st and 2nd failures.
        assert mock_time_sleep.call_count == 2

    def test_exponential_delay(self, mock_time_sleep):
        """Test that delays increase exponentially (2**attempt seconds)."""
        mock_func = Mock(side_effect=[
            Exception("Failure 1"),
            Exception("Failure 2"),
            "success",
        ])

        _exponential_backoff_retry(mock_func, max_retries=3)

        delays = [sleep_call.args[0] for sleep_call in mock_time_sleep.call_args_list]
        assert delays == [1, 2]  # 2**0, 2**1

    def test_raises_after_max_retries(self, mock_time_sleep):
        """Test that the original error is raised after exhausting retries."""
        error_msg = "Persistent failure"
        mock_func = Mock(side_effect=Exception(error_msg))

        with pytest.raises(Exception, match=error_msg):
            _exponential_backoff_retry(mock_func, max_retries=3)

        assert mock_func.call_count == 4  # Initial call + 3 retries
        assert mock_time_sleep.call_count == 3

    def test_raises_rate_limit_error(self, mock_time_sleep):
        """Test that rate-limit errors surface as AKShareRateLimitError."""
        mock_func = Mock(side_effect=Exception("Rate limit exceeded"))

        with pytest.raises(AKShareRateLimitError):
            _exponential_backoff_retry(mock_func, max_retries=2)

    def test_handles_timeout_errors(self, mock_time_sleep):
        """Test that a transient timeout is retried."""
        # ``requests`` may be absent when akshare itself is stubbed out.
        requests_exceptions = pytest.importorskip("requests.exceptions")
        mock_func = Mock(side_effect=[
            requests_exceptions.Timeout("Network timeout"),
            "success",
        ])

        result = _exponential_backoff_retry(mock_func, max_retries=3)

        assert result == "success"
        assert mock_func.call_count == 2

    def test_max_retries_zero(self, mock_time_sleep):
        """Test that max_retries=0 makes exactly one attempt and never sleeps."""
        mock_func = Mock(side_effect=Exception("Failure"))

        with pytest.raises(Exception, match="Failure"):
            _exponential_backoff_retry(mock_func, max_retries=0)

        assert mock_func.call_count == 1  # Only the initial call, no retries
        assert mock_time_sleep.call_count == 0

    def test_preserves_function_arguments(self, mock_time_sleep):
        """Test that function arguments are preserved across retries."""
        mock_func = Mock(side_effect=[Exception("Fail"), "success"])

        result = _exponential_backoff_retry(
            lambda: mock_func("arg1", kwarg1="value1"),
            max_retries=2,
        )

        assert result == "success"
        # Guard against a vacuous all() over an empty call list.
        assert mock_func.call_count == 2
        assert all(
            call_args == call("arg1", kwarg1="value1")
            for call_args in mock_func.call_args_list
        )
|
|
|
|
|
|
# ============================================================================
# Test US Stock Data Retrieval
# ============================================================================
|
|
|
|
class TestGetAkshareStockDataUs:
    """Test the get_akshare_stock_data_us function.

    Tests that drive the mocked API into failure also request
    ``mock_time_sleep``, so the retry path does not sleep for real.
    """

    def test_returns_dataframe_on_success(self, mock_akshare, sample_us_dataframe):
        """Test that successful retrieval returns a CSV-style string."""
        mock_akshare.stock_us_hist.return_value = sample_us_dataframe

        result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-05")

        assert isinstance(result, str)  # Returns CSV text, not a DataFrame
        assert "AAPL" in result
        assert "2024-01-01" in result
        mock_akshare.stock_us_hist.assert_called_once_with(
            symbol="AAPL",
            period="daily",
            adjust="",
        )

    def test_filters_data_by_date_range(self, mock_akshare):
        """Test that data is filtered to the requested date range."""
        # Wider than the requested window on both sides.
        full_df = pd.DataFrame({
            'date': pd.to_datetime(['2023-12-28', '2024-01-02', '2024-01-03', '2024-01-04', '2024-01-08']),
            'open': [145.0, 150.0, 151.0, 152.0, 157.0],
            'high': [147.0, 152.0, 153.0, 154.0, 159.0],
            'low': [144.0, 149.0, 150.0, 151.0, 156.0],
            'close': [146.0, 151.0, 152.0, 153.0, 158.0],
            'volume': [900000, 1000000, 1100000, 1200000, 1500000],
        })
        mock_akshare.stock_us_hist.return_value = full_df

        result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-05")

        # Rows outside the window must be dropped...
        assert "2023-12-28" not in result
        assert "2024-01-08" not in result
        # ...and every row strictly inside it must be kept.
        for kept_date in ("2024-01-02", "2024-01-03", "2024-01-04"):
            assert kept_date in result

    def test_returns_error_string_on_failure(self, mock_akshare, mock_time_sleep):
        """Test that exceptions are reported as an error string instead of raised."""
        mock_akshare.stock_us_hist.side_effect = Exception("API error")

        result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-05")

        assert isinstance(result, str)
        assert "error" in result.lower() or "failed" in result.lower()

    def test_handles_empty_data(self, mock_akshare):
        """Test handling of an empty DataFrame from the API."""
        mock_akshare.stock_us_hist.return_value = pd.DataFrame()

        result = get_akshare_stock_data_us("INVALID", "2024-01-01", "2024-01-05")

        assert isinstance(result, str)
        assert "no data" in result.lower() or "empty" in result.lower()

    def test_handles_network_timeout(self, mock_akshare, mock_time_sleep):
        """Test handling of network timeout errors."""
        # ``requests`` may be absent when akshare itself is stubbed out.
        requests_exceptions = pytest.importorskip("requests.exceptions")
        mock_akshare.stock_us_hist.side_effect = requests_exceptions.Timeout("Connection timeout")

        result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-05")

        assert isinstance(result, str)
        assert "timeout" in result.lower() or "error" in result.lower()

    def test_standardizes_output_format(self, mock_akshare, sample_us_dataframe):
        """Test that the output matches the expected CSV structure."""
        mock_akshare.stock_us_hist.return_value = sample_us_dataframe

        result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-05")

        # Should contain header information
        assert "Stock data for" in result or "AAPL" in result
        # Should contain a header line plus data rows
        lines = result.split('\n')
        assert len(lines) > 1

    def test_handles_malformed_dates(self, mock_akshare, mock_time_sleep):
        """Test that an invalid date yields a string rather than an exception."""
        result = get_akshare_stock_data_us("AAPL", "invalid-date", "2024-01-05")

        assert isinstance(result, str)

    def test_symbol_case_handling(self, mock_akshare, sample_us_dataframe):
        """Test that a lowercase symbol is accepted and forwarded to the API.

        Forwarding it unchanged or upper-cased are both accepted here.
        """
        mock_akshare.stock_us_hist.return_value = sample_us_dataframe

        get_akshare_stock_data_us("aapl", "2024-01-01", "2024-01-05")

        mock_akshare.stock_us_hist.assert_called_once()
        forwarded_symbol = mock_akshare.stock_us_hist.call_args.kwargs['symbol']
        assert forwarded_symbol in ("AAPL", "aapl")
|
|
|
|
|
|
# ============================================================================
# Test Chinese Stock Data Retrieval
# ============================================================================
|
|
|
|
class TestGetAkshareStockDataCn:
    """Test the get_akshare_stock_data_cn function.

    Tests that drive the mocked API into failure also request
    ``mock_time_sleep``, so retry backoff never sleeps for real.
    """

    def test_returns_dataframe_on_success(self, mock_akshare, sample_cn_dataframe):
        """Test that successful retrieval returns a CSV-style string."""
        mock_akshare.stock_zh_a_hist.return_value = sample_cn_dataframe

        result = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-05")

        assert isinstance(result, str)
        assert "000001" in result
        mock_akshare.stock_zh_a_hist.assert_called_once()

    def test_converts_date_format(self, mock_akshare, sample_cn_dataframe):
        """Test that dates are passed to the API in YYYYMMDD format."""
        mock_akshare.stock_zh_a_hist.return_value = sample_cn_dataframe

        get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-05")

        call_args = mock_akshare.stock_zh_a_hist.call_args
        assert call_args is not None
        # Accept the dates positionally or as keywords. Only string arguments
        # are compared, so no element-wise DataFrame comparison can occur.
        passed_strings = {
            value
            for value in (*call_args.args, *call_args.kwargs.values())
            if isinstance(value, str)
        }
        assert "20240101" in passed_strings
        assert "20240105" in passed_strings

    def test_returns_error_string_on_failure(self, mock_akshare, mock_time_sleep):
        """Test that exceptions are reported as an error string."""
        mock_akshare.stock_zh_a_hist.side_effect = Exception("API error")

        result = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-05")

        assert isinstance(result, str)
        assert "error" in result.lower() or "failed" in result.lower()

    def test_standardizes_column_names(self, mock_akshare, sample_cn_dataframe):
        """Test that Chinese column names are mapped to English."""
        mock_akshare.stock_zh_a_hist.return_value = sample_cn_dataframe

        result = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-05")

        # Output should contain English column names
        result_lower = result.lower()
        assert "date" in result_lower or "open" in result_lower or "close" in result_lower

    def test_handles_empty_data(self, mock_akshare):
        """Test handling of an empty DataFrame."""
        mock_akshare.stock_zh_a_hist.return_value = pd.DataFrame()

        result = get_akshare_stock_data_cn("INVALID", "2024-01-01", "2024-01-05")

        assert isinstance(result, str)
        assert "no data" in result.lower() or "empty" in result.lower()

    def test_handles_symbol_with_suffix(self, mock_akshare, sample_cn_dataframe):
        """Test handling of symbols with .SZ or .SH suffixes."""
        mock_akshare.stock_zh_a_hist.return_value = sample_cn_dataframe

        result = get_akshare_stock_data_cn("000001.SZ", "2024-01-01", "2024-01-05")

        assert isinstance(result, str)

    def test_handles_rate_limit_error(self, mock_akshare, mock_time_sleep):
        """Test that a Chinese rate-limit message surfaces as AKShareRateLimitError."""
        # Chinese rate-limit message (roughly "requests too frequent").
        mock_akshare.stock_zh_a_hist.side_effect = Exception("访问频率过快")

        # Should raise AKShareRateLimitError when wrapped in the retry mechanism
        with pytest.raises(AKShareRateLimitError):
            _exponential_backoff_retry(
                lambda: mock_akshare.stock_zh_a_hist(),
                max_retries=1,
            )

    def test_standardizes_output_format(self, mock_akshare, sample_cn_dataframe):
        """Test that the output is CSV-like (header plus rows)."""
        mock_akshare.stock_zh_a_hist.return_value = sample_cn_dataframe

        result = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-05")

        lines = result.split('\n')
        assert len(lines) > 1
|
|
|
|
|
|
# ============================================================================
# Test Auto-Market Detection
# ============================================================================
|
|
|
|
class TestGetAkshareStockData:
    """Test the get_akshare_stock_data function with auto-market detection."""

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_us')
    def test_auto_detects_us_market(self, mock_us_func):
        """Test that US symbols are detected automatically."""
        mock_us_func.return_value = "US data"

        result = get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-05")

        assert result == "US data"
        mock_us_func.assert_called_once_with("AAPL", "2024-01-01", "2024-01-05")

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_cn')
    def test_auto_detects_cn_market_with_sz_suffix(self, mock_cn_func):
        """Test that Chinese symbols with a .SZ suffix are detected."""
        mock_cn_func.return_value = "CN data"

        result = get_akshare_stock_data("000001.SZ", "2024-01-01", "2024-01-05")

        assert result == "CN data"
        mock_cn_func.assert_called_once_with("000001.SZ", "2024-01-01", "2024-01-05")

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_cn')
    def test_auto_detects_cn_market_with_sh_suffix(self, mock_cn_func):
        """Test that Chinese symbols with a .SH suffix are detected."""
        mock_cn_func.return_value = "CN data"

        result = get_akshare_stock_data("600000.SH", "2024-01-01", "2024-01-05")

        assert result == "CN data"
        mock_cn_func.assert_called_once_with("600000.SH", "2024-01-01", "2024-01-05")

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_cn')
    def test_auto_detects_cn_market_numeric_only(self, mock_cn_func):
        """Test that numeric-only symbols default to the Chinese market."""
        mock_cn_func.return_value = "CN data"

        result = get_akshare_stock_data("000001", "2024-01-01", "2024-01-05")

        assert result == "CN data"
        mock_cn_func.assert_called_once_with("000001", "2024-01-01", "2024-01-05")

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_us')
    def test_explicit_market_us(self, mock_us_func):
        """Test that market='us' forces the US function."""
        mock_us_func.return_value = "US data"

        result = get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-05", market="us")

        assert result == "US data"
        mock_us_func.assert_called_once()

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_cn')
    def test_explicit_market_cn(self, mock_cn_func):
        """Test that market='cn' forces the Chinese function."""
        mock_cn_func.return_value = "CN data"

        result = get_akshare_stock_data("000001", "2024-01-01", "2024-01-05", market="cn")

        assert result == "CN data"
        mock_cn_func.assert_called_once()

    def test_returns_standardized_dataframe(self):
        """Test that the output has a consistent schema regardless of market.

        Both market-specific functions are mocked to return CSV text with the
        same header. The dispatcher must pass each one through unchanged.
        """
        us_csv = "Date,Open,High,Low,Close,Volume\n2024-01-01,150,152,149,151,1000000"
        cn_csv = "Date,Open,High,Low,Close,Volume\n2024-01-01,10,10.2,9.9,10.1,500000"

        with patch('tradingagents.dataflows.akshare.get_akshare_stock_data_us') as mock_us:
            mock_us.return_value = us_csv
            result_us = get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-05")

        with patch('tradingagents.dataflows.akshare.get_akshare_stock_data_cn') as mock_cn:
            mock_cn.return_value = cn_csv
            result_cn = get_akshare_stock_data("000001", "2024-01-01", "2024-01-05")

        assert isinstance(result_us, str)
        assert isinstance(result_cn, str)
        assert result_us == us_csv
        assert result_cn == cn_csv
        assert result_us.splitlines()[0] == result_cn.splitlines()[0]

    def test_invalid_market_parameter(self):
        """Test that an invalid market parameter raises ValueError."""
        with pytest.raises(ValueError):
            get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-05", market="invalid")

    @patch('tradingagents.dataflows.akshare.get_akshare_stock_data_us')
    def test_market_auto_default_behavior(self, mock_us_func):
        """Test that market='auto' is the default behavior."""
        mock_us_func.return_value = "US data"

        # Call without the market parameter
        result1 = get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-05")
        # Call with explicit market='auto'
        result2 = get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-05", market="auto")

        # The mock returns the same value for every call, so equality alone
        # is tautological. Also check that both calls reached the US function.
        assert result1 == result2
        assert mock_us_func.call_count == 2
|
|
|
|
|
|
# ============================================================================
# Test AKShareRateLimitError Exception
# ============================================================================
|
|
|
|
class TestAKShareRateLimitError:
    """Behavioral checks for the AKShareRateLimitError exception type."""

    def test_is_exception_subclass(self):
        """AKShareRateLimitError must derive from Exception."""
        base = Exception
        assert issubclass(AKShareRateLimitError, base)

    def test_can_be_raised_and_caught(self):
        """Raising the error is intercepted by a matching pytest.raises."""
        def _raise_it():
            raise AKShareRateLimitError("Rate limit exceeded")

        with pytest.raises(AKShareRateLimitError):
            _raise_it()

    def test_message_included(self):
        """str() of the raised error is exactly the message it was built with."""
        text = "API rate limit exceeded: 5 calls per minute"
        with pytest.raises(AKShareRateLimitError) as excinfo:
            raise AKShareRateLimitError(text)
        assert str(excinfo.value) == text

    def test_can_be_caught_as_generic_exception(self):
        """A handler for plain Exception also intercepts this error."""
        def _raise_it():
            raise AKShareRateLimitError("Rate limit")

        with pytest.raises(Exception):
            _raise_it()

    def test_distinct_from_other_exceptions(self):
        """A ValueError handler must not intercept this error."""
        assert not issubclass(AKShareRateLimitError, ValueError), "Should not be caught as ValueError"
        with pytest.raises(AKShareRateLimitError):
            raise AKShareRateLimitError("Rate limit")
|
|
|
|
|
|
# ============================================================================
# Test Vendor Integration (interface.py modifications)
# ============================================================================
|
|
|
|
class TestVendorIntegration:
    """Test integration with the vendor routing system in interface.py.

    ``route_to_vendor`` dispatches through the function objects stored in
    ``VENDOR_METHODS``. ``test_akshare_vendor_function_mapping`` checks that
    the dict holds ``get_akshare_stock_data`` itself. Patching the attribute on
    the defining module therefore leaves the dict pointing at the real
    implementation, so these tests swap the dict entries with ``patch.dict``.
    ``get_config`` is likewise patched where interface.py looks it up.
    """

    # NOTE(review): assumes interface.py binds the name via
    # ``from .config import get_config`` -- confirm against interface.py.
    CONFIG_TARGET = 'tradingagents.dataflows.interface.get_config'

    @staticmethod
    def _vendor_mock(name, **kwargs):
        """Build a MagicMock to stand in for a vendor function in VENDOR_METHODS.

        ``__name__`` is set because the real entries are plain functions and
        the router may read it (e.g. for logging).
        """
        fake = MagicMock(**kwargs)
        fake.__name__ = name
        return fake

    def _route_stock_data(self, vendor_config, vendor_impls, *args):
        """Call route_to_vendor("get_stock_data", *args) with config and vendors patched."""
        config_patch = patch(self.CONFIG_TARGET, return_value={"data_vendor": vendor_config})
        vendors_patch = patch.dict(VENDOR_METHODS["get_stock_data"], vendor_impls)
        with config_patch, vendors_patch:
            return route_to_vendor("get_stock_data", *args)

    def test_akshare_in_vendor_methods(self):
        """Test that akshare is registered in VENDOR_METHODS."""
        assert "get_stock_data" in VENDOR_METHODS
        assert "akshare" in VENDOR_METHODS["get_stock_data"]

    def test_akshare_vendor_function_mapping(self):
        """Test that akshare maps to the correct function."""
        from tradingagents.dataflows.akshare import get_akshare_stock_data

        akshare_impl = VENDOR_METHODS["get_stock_data"]["akshare"]
        assert akshare_impl == get_akshare_stock_data

    def test_route_to_vendor_uses_akshare(self):
        """Test that route_to_vendor calls akshare when configured."""
        akshare_impl = self._vendor_mock("get_akshare_stock_data", return_value="AKShare data")

        result = self._route_stock_data(
            "akshare",
            {"akshare": akshare_impl},
            "AAPL", "2024-01-01", "2024-01-05",
        )

        assert result == "AKShare data"
        akshare_impl.assert_called_once_with("AAPL", "2024-01-01", "2024-01-05")

    def test_fallback_on_rate_limit(self):
        """Test that AKShareRateLimitError triggers fallback to the next vendor."""
        akshare_impl = self._vendor_mock(
            "get_akshare_stock_data",
            side_effect=AKShareRateLimitError("Rate limit exceeded"),
        )
        yfinance_impl = self._vendor_mock("get_YFin_data_online", return_value="YFinance data")

        result = self._route_stock_data(
            "akshare,yfinance",
            {"akshare": akshare_impl, "yfinance": yfinance_impl},
            "AAPL", "2024-01-01", "2024-01-05",
        )

        assert result == "YFinance data"
        akshare_impl.assert_called_once()
        yfinance_impl.assert_called_once()

    def test_fallback_chain_akshare_yfinance(self):
        """Test that the multi-vendor fallback chain works correctly."""
        akshare_impl = self._vendor_mock(
            "get_akshare_stock_data",
            side_effect=AKShareRateLimitError("Rate limit"),
        )
        yfinance_impl = self._vendor_mock("get_YFin_data_online", return_value="YFinance fallback data")

        result = self._route_stock_data(
            "akshare,yfinance,local",
            {"akshare": akshare_impl, "yfinance": yfinance_impl},
            "AAPL", "2024-01-01", "2024-01-05",
        )

        assert "YFinance" in result
        # akshare was tried first, then yfinance succeeded
        assert akshare_impl.call_count == 1
        assert yfinance_impl.call_count == 1

    def test_akshare_error_string_not_triggers_fallback(self):
        """Test that error strings (not exceptions) don't trigger fallback."""
        # Return an error string, not an exception
        akshare_impl = self._vendor_mock("get_akshare_stock_data", return_value="Error: No data found")

        result = self._route_stock_data(
            "akshare",
            {"akshare": akshare_impl},
            "INVALID", "2024-01-01", "2024-01-05",
        )

        # Should return the error string, not attempt fallback
        assert "Error" in result
        assert akshare_impl.call_count == 1

    def test_akshare_in_vendor_list(self):
        """Test that akshare is in the global VENDOR_LIST."""
        from tradingagents.dataflows.interface import VENDOR_LIST
        assert "akshare" in VENDOR_LIST
|
|
|
|
|
|
# ============================================================================
# Integration Tests
# ============================================================================
|
|
|
|
class TestAKShareIntegration:
    """Integration tests combining multiple components."""

    @patch('tradingagents.dataflows.akshare.ak')
    def test_end_to_end_us_stock_retrieval(self, mock_ak):
        """Test the complete flow of US stock data retrieval."""
        mock_df = pd.DataFrame({
            'date': pd.date_range('2024-01-01', periods=3, freq='D'),
            'open': [150.0, 151.0, 152.0],
            'high': [152.0, 153.0, 154.0],
            'low': [149.0, 150.0, 151.0],
            'close': [151.0, 152.0, 153.0],
            'volume': [1000000, 1100000, 1200000],
        })
        mock_ak.stock_us_hist.return_value = mock_df

        result = get_akshare_stock_data("AAPL", "2024-01-01", "2024-01-03", market="us")

        assert isinstance(result, str)
        assert "AAPL" in result or "150" in result

    @patch('tradingagents.dataflows.akshare.ak')
    def test_end_to_end_cn_stock_retrieval(self, mock_ak):
        """Test the complete flow of Chinese stock data retrieval."""
        # Mock frame uses akshare's Chinese column names
        # (date, open, high, low, close, volume).
        mock_df = pd.DataFrame({
            '日期': ['2024-01-01', '2024-01-02', '2024-01-03'],
            '开盘': [10.0, 10.1, 10.2],
            '最高': [10.2, 10.3, 10.4],
            '最低': [9.9, 10.0, 10.1],
            '收盘': [10.1, 10.2, 10.3],
            '成交量': [500000, 550000, 600000],
        })
        mock_ak.stock_zh_a_hist.return_value = mock_df

        result = get_akshare_stock_data("000001", "2024-01-01", "2024-01-03", market="cn")

        assert isinstance(result, str)
        assert "000001" in result or "10" in result

    @patch('tradingagents.dataflows.akshare.ak')
    @patch('tradingagents.dataflows.akshare.time.sleep')
    def test_retry_mechanism_with_transient_failure(self, mock_sleep, mock_ak):
        """Test that the retry mechanism recovers from a transient failure."""
        mock_df = pd.DataFrame({
            'date': pd.date_range('2024-01-01', periods=2, freq='D'),
            'open': [150.0, 151.0],
            'high': [152.0, 153.0],
            'low': [149.0, 150.0],
            'close': [151.0, 152.0],
            'volume': [1000000, 1100000],
        })
        # The first call fails, the second succeeds.
        mock_ak.stock_us_hist.side_effect = [
            Exception("Transient network error"),
            mock_df,
        ]

        result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-02")

        assert isinstance(result, str)
        assert mock_ak.stock_us_hist.call_count == 2
        assert mock_sleep.call_count == 1  # One retry delay

    def test_column_name_standardization(self, mock_akshare):
        """Test that Chinese and US data both come out with standard English columns."""
        us_df = pd.DataFrame({
            'date': pd.date_range('2024-01-01', periods=2, freq='D'),
            'open': [150.0, 151.0],
            'high': [152.0, 153.0],
            'low': [149.0, 150.0],
            'close': [151.0, 152.0],
            'volume': [1000000, 1100000],
        })
        mock_akshare.stock_us_hist.return_value = us_df
        us_result = get_akshare_stock_data_us("AAPL", "2024-01-01", "2024-01-02")

        cn_df = pd.DataFrame({
            '日期': ['2024-01-01', '2024-01-02'],
            '开盘': [10.0, 10.1],
            '最高': [10.2, 10.3],
            '最低': [9.9, 10.0],
            '收盘': [10.1, 10.2],
            '成交量': [500000, 550000],
        })
        mock_akshare.stock_zh_a_hist.return_value = cn_df
        cn_result = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-02")

        # Every standard column must be present. A single match, as with any(),
        # would pass even for partial or failed output.
        standard_cols = ('date', 'open', 'high', 'low', 'close', 'volume')
        for result in (us_result, cn_result):
            result_lower = result.lower()
            missing = [col for col in standard_cols if col not in result_lower]
            assert not missing, f"missing standardized columns: {missing}"

        # The raw Chinese open/close/volume headers must not leak into the output.
        for raw_header in ('开盘', '收盘', '成交量'):
            assert raw_header not in cn_result
|
|
|
|
|
|
# ============================================================================
# Edge Cases and Error Handling
# ============================================================================
|
|
|
|
class TestEdgeCases:
    """Test edge cases and error conditions.

    Every test uses the ``mock_akshare`` fixture so it gets a fresh mock
    instead of the shared module-level ``akshare`` stub, whose state would
    otherwise leak between tests. Tests that may hit the failure/retry path
    also take ``mock_time_sleep`` so backoff never sleeps for real.
    """

    def test_empty_symbol(self, mock_akshare, mock_time_sleep):
        """Test handling of an empty symbol string."""
        mock_akshare.stock_us_hist.return_value = pd.DataFrame()
        mock_akshare.stock_zh_a_hist.return_value = pd.DataFrame()

        result = get_akshare_stock_data("", "2024-01-01", "2024-01-05")

        assert isinstance(result, str)

    def test_future_dates(self, mock_akshare):
        """Test handling of future dates."""
        mock_akshare.stock_us_hist.return_value = pd.DataFrame()

        result = get_akshare_stock_data_us("AAPL", "2030-01-01", "2030-01-05")

        assert isinstance(result, str)

    def test_start_date_after_end_date(self, mock_akshare, mock_time_sleep):
        """Test handling when start_date > end_date."""
        mock_akshare.stock_us_hist.return_value = pd.DataFrame()

        result = get_akshare_stock_data("AAPL", "2024-01-05", "2024-01-01")

        assert isinstance(result, str)

    def test_very_long_date_range(self, mock_akshare):
        """Test handling of very long date ranges (years of data)."""
        n_days = 1000
        large_df = pd.DataFrame({
            'date': pd.date_range('2020-01-01', periods=n_days, freq='D'),
            'open': [150.0] * n_days,
            'high': [152.0] * n_days,
            'low': [149.0] * n_days,
            'close': [151.0] * n_days,
            'volume': [1000000] * n_days,
        })
        mock_akshare.stock_us_hist.return_value = large_df

        result = get_akshare_stock_data_us("AAPL", "2020-01-01", "2024-01-01")

        assert isinstance(result, str)

    def test_special_characters_in_symbol(self, mock_akshare):
        """Test handling of symbols with special characters."""
        mock_akshare.stock_us_hist.return_value = pd.DataFrame()

        result = get_akshare_stock_data_us("BRK.B", "2024-01-01", "2024-01-05")

        assert isinstance(result, str)

    def test_unicode_in_error_messages(self, mock_akshare, mock_time_sleep):
        """Test handling of Unicode characters in error messages."""
        # Chinese rate-limit message, roughly "requests too frequent, retry later".
        mock_akshare.stock_zh_a_hist.side_effect = Exception("访问频率过快,请稍后重试")

        result = get_akshare_stock_data_cn("000001", "2024-01-01", "2024-01-05")

        assert isinstance(result, str)

    def test_none_parameters(self, mock_akshare):
        """Test that a None symbol raises TypeError or AttributeError."""
        with pytest.raises((TypeError, AttributeError)):
            get_akshare_stock_data(None, "2024-01-01", "2024-01-05")
|