feat(dataflows): add benchmark data module with SPY, sector ETFs, RS, correlation, beta - Fixes #10

parent 19171a4b31
commit bbd85c91b6

CHANGELOG.md (25 additions)

@@ -89,6 +89,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
- Test coverage including rate limit handling, caching behavior, and date range filtering
- Total: 108 tests added for FRED API feature
- Benchmark data retrieval and analysis (Issue #10)
  - Benchmark data module for SPY, sector ETF, and analysis functions [file:tradingagents/dataflows/benchmark.py](tradingagents/dataflows/benchmark.py) (441 lines)
  - Sector ETF mappings for all 11 SPDR sector funds (communication, consumer discretionary/staples, energy, financials, healthcare, industrials, materials, real estate, technology, utilities) [file:tradingagents/dataflows/benchmark.py:48-59](tradingagents/dataflows/benchmark.py)
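Matching the 11 sector names above to the ETF symbols enumerated in the test-suite docstrings, the `SECTOR_ETFS` constant presumably looks like this sketch (the actual definition lives in `benchmark.py:48-59`):

```python
# Sector name -> SPDR sector ETF symbol, as enumerated in the test docstrings.
SECTOR_ETFS = {
    'communication': 'XLC',
    'consumer_discretionary': 'XLY',
    'consumer_staples': 'XLP',
    'energy': 'XLE',
    'financials': 'XLF',
    'healthcare': 'XLV',
    'industrials': 'XLI',
    'materials': 'XLB',
    'real_estate': 'XLRE',
    'technology': 'XLK',
    'utilities': 'XLU',
}
```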
  - get_benchmark_data() function for fetching OHLCV data via yfinance with date validation [file:tradingagents/dataflows/benchmark.py:67-115](tradingagents/dataflows/benchmark.py)
  - get_spy_data() convenience wrapper for S&P 500 benchmark data [file:tradingagents/dataflows/benchmark.py:117-136](tradingagents/dataflows/benchmark.py)
  - get_sector_etf_data() function for retrieving sector-specific benchmark data with sector validation [file:tradingagents/dataflows/benchmark.py:138-186](tradingagents/dataflows/benchmark.py)
  - calculate_relative_strength() function with IBD-style weighted ROC formula [file:tradingagents/dataflows/benchmark.py:188-285](tradingagents/dataflows/benchmark.py)
    - Relative strength calculation using weighted periods (40% 63-day, 20% 126-day, 20% 189-day, 20% 252-day ROC)
    - Customizable ROC periods with default IBD-style weighting
    - Data alignment via inner join with validation for overlapping dates
  - calculate_rolling_correlation() function for time-series correlation analysis [file:tradingagents/dataflows/benchmark.py:287-349](tradingagents/dataflows/benchmark.py)
    - Configurable rolling window sizes with default 60-day window
    - Comprehensive validation for data alignment and minimum data requirements
  - calculate_beta() function for volatility and systematic risk measurement [file:tradingagents/dataflows/benchmark.py:351-441](tradingagents/dataflows/benchmark.py)
    - Beta calculation using covariance-variance approach with optional smoothing
    - Optional rolling beta calculation with customizable window (default 252 days)
    - Rolling window implementation for efficient computation
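The covariance-variance approach can be sketched as below: beta is the covariance of the stock's daily returns with the market's, divided by the variance of the market's returns. This is a minimal illustration; the real `calculate_beta()` adds validation, optional smoothing, and a rolling mode:

```python
import pandas as pd

def beta(stock: pd.DataFrame, benchmark: pd.DataFrame) -> float:
    """Beta = Cov(stock returns, market returns) / Var(market returns)."""
    stock_ret = stock['Close'].pct_change().dropna()
    bench_ret = benchmark['Close'].pct_change().dropna()
    stock_ret, bench_ret = stock_ret.align(bench_ret, join='inner')
    return float(stock_ret.cov(bench_ret) / bench_ret.var())
```

A sanity check on the definition: a series whose prices are an exact multiple of the benchmark's has identical percentage returns, so its beta is exactly 1.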
  - All functions return DataFrames/Series/floats on success, error strings on failure
  - Comprehensive error handling with descriptive messages and validation logic
  - Comprehensive docstrings with examples for all public functions
  - Unit test suite for benchmark functions [file:tests/unit/dataflows/test_benchmark.py](tests/unit/dataflows/test_benchmark.py) (753 lines, 28 tests)
  - Integration test suite for benchmark workflows [file:tests/integration/dataflows/test_benchmark_integration.py](tests/integration/dataflows/test_benchmark_integration.py) (593 lines, 7 tests)
  - Test coverage includes data fetching, sector validation, relative strength calculation, correlation analysis, and beta calculation
  - Total: 35 tests added for benchmark data feature
- Multi-timeframe OHLCV aggregation functions (Issue #9)
  - Multi-timeframe aggregation module for daily to weekly/monthly resampling [file:tradingagents/dataflows/multi_timeframe.py](tradingagents/dataflows/multi_timeframe.py) (320 lines)
  - Core OHLCV aggregation validation function _validate_ohlcv_dataframe() [file:tradingagents/dataflows/multi_timeframe.py:38-75](tradingagents/dataflows/multi_timeframe.py)
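Daily-to-weekly/monthly OHLCV resampling as mentioned above can be approximated with pandas `resample`/`agg`. This is a sketch under assumed aggregation semantics (first/max/min/last/sum), not the actual `multi_timeframe.py` implementation:

```python
import pandas as pd

# Standard OHLCV aggregation rules for downsampling bars.
OHLCV_AGG = {
    'Open': 'first',   # first trade of the period
    'High': 'max',
    'Low': 'min',
    'Close': 'last',   # last trade of the period
    'Volume': 'sum',
}

def resample_ohlcv(daily: pd.DataFrame, rule: str = 'W') -> pd.DataFrame:
    """Aggregate daily OHLCV bars to a coarser timeframe, e.g. weekly ('W')."""
    return daily.resample(rule).agg(OHLCV_AGG).dropna()
```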
tests/integration/dataflows/test_benchmark_integration.py (new file)

@@ -0,0 +1,593 @@

"""
Test suite for Benchmark Integration Tests.

This module tests:
1. End-to-end workflows with benchmark data
2. Multi-sector comparison analysis
3. Real-world data format handling (yfinance compatibility)
4. Combined analytics (RS + correlation + beta)
5. All sector ETFs availability

Test Coverage:
- Integration with yfinance data formats
- Complete benchmark analysis workflow
- Multi-sector relative strength comparison
- Portfolio-level analytics
- Date alignment across multiple datasets
- All 11 sector ETFs (XLC, XLY, XLP, XLE, XLF, XLV, XLI, XLB, XLRE, XLK, XLU)

Workflow:
1. Fetch benchmark data (SPY)
2. Fetch stock data
3. Calculate RS, correlation, beta
4. Compare across sectors
"""

import pytest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock

pytestmark = pytest.mark.integration


# ============================================================================
# Fixtures
# ============================================================================

@pytest.fixture
def yfinance_spy_data():
    """
    Create SPY data in yfinance format.

    yfinance returns:
    - DatetimeIndex (timezone-aware or naive)
    - Capitalized column names
    - Business day frequency
    """
    dates = pd.date_range('2024-01-01', periods=300, freq='D')

    data = pd.DataFrame({
        'Open': [450.0 + i * 0.3 for i in range(300)],
        'High': [452.0 + i * 0.3 for i in range(300)],
        'Low': [449.0 + i * 0.3 for i in range(300)],
        'Close': [451.0 + i * 0.3 for i in range(300)],
        'Volume': [80000000 + i * 100000 for i in range(300)],
    }, index=dates)

    return data


@pytest.fixture
def yfinance_stock_data():
    """Create stock data in yfinance format (AAPL-like)."""
    dates = pd.date_range('2024-01-01', periods=300, freq='D')

    data = pd.DataFrame({
        'Open': [180.0 + i * 0.4 for i in range(300)],
        'High': [182.0 + i * 0.4 for i in range(300)],
        'Low': [179.0 + i * 0.4 for i in range(300)],
        'Close': [181.0 + i * 0.4 for i in range(300)],
        'Volume': [50000000 + i * 80000 for i in range(300)],
    }, index=dates)

    return data


@pytest.fixture
def yfinance_sector_data_xlk():
    """Create XLK sector ETF data in yfinance format."""
    dates = pd.date_range('2024-01-01', periods=300, freq='D')

    data = pd.DataFrame({
        'Open': [200.0 + i * 0.35 for i in range(300)],
        'High': [202.0 + i * 0.35 for i in range(300)],
        'Low': [199.0 + i * 0.35 for i in range(300)],
        'Close': [201.0 + i * 0.35 for i in range(300)],
        'Volume': [10000000 + i * 50000 for i in range(300)],
    }, index=dates)

    return data


@pytest.fixture
def yfinance_sector_data_xlf():
    """Create XLF sector ETF data in yfinance format."""
    dates = pd.date_range('2024-01-01', periods=300, freq='D')

    data = pd.DataFrame({
        'Open': [38.0 + i * 0.02 for i in range(300)],
        'High': [38.5 + i * 0.02 for i in range(300)],
        'Low': [37.5 + i * 0.02 for i in range(300)],
        'Close': [38.2 + i * 0.02 for i in range(300)],
        'Volume': [60000000 + i * 200000 for i in range(300)],
    }, index=dates)

    return data


@pytest.fixture
def yfinance_sector_data_xle():
    """Create XLE sector ETF data in yfinance format."""
    dates = pd.date_range('2024-01-01', periods=300, freq='D')

    data = pd.DataFrame({
        'Open': [85.0 + i * 0.1 for i in range(300)],
        'High': [86.0 + i * 0.1 for i in range(300)],
        'Low': [84.0 + i * 0.1 for i in range(300)],
        'Close': [85.5 + i * 0.1 for i in range(300)],
        'Volume': [25000000 + i * 100000 for i in range(300)],
    }, index=dates)

    return data


@pytest.fixture
def all_sector_etf_data():
    """
    Create data for all 11 sector ETFs.

    Returns dict mapping sector names to DataFrames.
    """
    sectors_data = {}
    sector_configs = {
        'communication': {'base': 75.0, 'increment': 0.08},
        'consumer_discretionary': {'base': 180.0, 'increment': 0.15},
        'consumer_staples': {'base': 75.0, 'increment': 0.05},
        'energy': {'base': 85.0, 'increment': 0.1},
        'financials': {'base': 38.0, 'increment': 0.02},
        'healthcare': {'base': 130.0, 'increment': 0.12},
        'industrials': {'base': 105.0, 'increment': 0.09},
        'materials': {'base': 85.0, 'increment': 0.07},
        'real_estate': {'base': 40.0, 'increment': 0.03},
        'technology': {'base': 200.0, 'increment': 0.35},
        'utilities': {'base': 65.0, 'increment': 0.04},
    }

    dates = pd.date_range('2024-01-01', periods=300, freq='D')

    for sector, config in sector_configs.items():
        base = config['base']
        inc = config['increment']

        data = pd.DataFrame({
            'Open': [base + i * inc for i in range(300)],
            'High': [base + 1.0 + i * inc for i in range(300)],
            'Low': [base - 0.5 + i * inc for i in range(300)],
            'Close': [base + 0.5 + i * inc for i in range(300)],
            'Volume': [15000000 + i * 50000 for i in range(300)],
        }, index=dates)

        sectors_data[sector] = data

    return sectors_data


# ============================================================================
# Test Class: Benchmark Integration
# ============================================================================

class TestBenchmarkIntegration:
    """
    Test suite for end-to-end benchmark workflows.

    Tests:
    - Complete analysis workflow (fetch + RS + correlation + beta)
    - Multi-sector comparison
    - All sector ETFs availability
    - Combined analytics
    """

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_end_to_end_benchmark_analysis(
        self,
        mock_yf,
        yfinance_stock_data,
        yfinance_spy_data
    ):
        """
        Test complete benchmark analysis workflow.

        Workflow:
        1. Fetch SPY benchmark data
        2. Fetch stock data
        3. Calculate relative strength
        4. Calculate rolling correlation
        5. Calculate beta
        """
        from tradingagents.dataflows.benchmark import (
            get_spy_data,
            get_benchmark_data,
            calculate_relative_strength,
            calculate_rolling_correlation,
            calculate_beta
        )

        # Setup mocks
        def ticker_side_effect(symbol):
            mock_ticker_instance = MagicMock()
            if symbol == 'SPY':
                mock_ticker_instance.history.return_value = yfinance_spy_data
            else:  # AAPL
                mock_ticker_instance.history.return_value = yfinance_stock_data
            return mock_ticker_instance

        mock_yf.Ticker.side_effect = ticker_side_effect

        # Step 1: Fetch SPY benchmark
        spy_data = get_spy_data('2024-01-01', '2024-10-31')
        assert isinstance(spy_data, pd.DataFrame)
        assert len(spy_data) > 0

        # Step 2: Fetch stock data
        stock_data = get_benchmark_data('AAPL', '2024-01-01', '2024-10-31')
        assert isinstance(stock_data, pd.DataFrame)
        assert len(stock_data) > 0

        # Step 3: Calculate relative strength
        rs = calculate_relative_strength(stock_data, spy_data)
        assert isinstance(rs, float)
        assert not np.isnan(rs)

        # Step 4: Calculate rolling correlation
        correlation = calculate_rolling_correlation(stock_data, spy_data, window=63)
        assert isinstance(correlation, pd.Series)
        assert len(correlation.dropna()) > 0

        # Step 5: Calculate beta
        beta = calculate_beta(stock_data, spy_data, window=252)
        assert isinstance(beta, float)
        assert not np.isnan(beta)

        # Verify reasonable values
        assert -200 < rs < 200
        assert (correlation.dropna() >= -1.0).all()
        assert (correlation.dropna() <= 1.0).all()
        # Beta can be high for synthetic test data with varying volatility
        assert -10 < beta < 10

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_multi_sector_comparison(
        self,
        mock_yf,
        yfinance_stock_data,
        yfinance_spy_data,
        yfinance_sector_data_xlk,
        yfinance_sector_data_xlf,
        yfinance_sector_data_xle
    ):
        """
        Test comparing stock performance against multiple sector ETFs.

        Workflow:
        1. Fetch stock data
        2. Fetch SPY and multiple sector ETFs
        3. Calculate RS against each benchmark
        4. Compare results
        """
        from tradingagents.dataflows.benchmark import (
            get_benchmark_data,
            get_sector_etf_data,
            calculate_relative_strength
        )

        # Setup mocks
        def ticker_side_effect(symbol):
            mock_ticker_instance = MagicMock()
            data_map = {
                'AAPL': yfinance_stock_data,
                'SPY': yfinance_spy_data,
                'XLK': yfinance_sector_data_xlk,
                'XLF': yfinance_sector_data_xlf,
                'XLE': yfinance_sector_data_xle,
            }
            mock_ticker_instance.history.return_value = data_map.get(
                symbol,
                pd.DataFrame()
            )
            return mock_ticker_instance

        mock_yf.Ticker.side_effect = ticker_side_effect

        # Fetch stock data
        stock_data = get_benchmark_data('AAPL', '2024-01-01', '2024-10-31')
        assert isinstance(stock_data, pd.DataFrame)

        # Calculate RS against multiple benchmarks
        rs_results = {}

        # vs SPY
        spy_data = get_benchmark_data('SPY', '2024-01-01', '2024-10-31')
        rs_results['SPY'] = calculate_relative_strength(stock_data, spy_data)

        # vs Technology (XLK)
        tech_data = get_sector_etf_data('technology', '2024-01-01', '2024-10-31')
        rs_results['XLK'] = calculate_relative_strength(stock_data, tech_data)

        # vs Financials (XLF)
        finance_data = get_sector_etf_data('financials', '2024-01-01', '2024-10-31')
        rs_results['XLF'] = calculate_relative_strength(stock_data, finance_data)

        # vs Energy (XLE)
        energy_data = get_sector_etf_data('energy', '2024-01-01', '2024-10-31')
        rs_results['XLE'] = calculate_relative_strength(stock_data, energy_data)

        # Assert all RS calculations succeeded
        for benchmark, rs in rs_results.items():
            assert isinstance(rs, float), f"RS vs {benchmark} failed"
            assert not np.isnan(rs), f"RS vs {benchmark} is NaN"
            assert -200 < rs < 200, f"RS vs {benchmark} out of range"

        # AAPL should have different RS against different sectors
        unique_values = len(set(rs_results.values()))
        assert unique_values > 1, "RS should differ across sectors"

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_all_sector_etfs_available(self, mock_yf, all_sector_etf_data):
        """
        Test that all 11 sector ETFs can be fetched.

        Sectors:
        - communication (XLC)
        - consumer_discretionary (XLY)
        - consumer_staples (XLP)
        - energy (XLE)
        - financials (XLF)
        - healthcare (XLV)
        - industrials (XLI)
        - materials (XLB)
        - real_estate (XLRE)
        - technology (XLK)
        - utilities (XLU)
        """
        from tradingagents.dataflows.benchmark import get_sector_etf_data, SECTOR_ETFS

        # Setup mocks
        def ticker_side_effect(symbol):
            mock_ticker_instance = MagicMock()
            # Find which sector this symbol belongs to
            for sector, etf_symbol in SECTOR_ETFS.items():
                if etf_symbol == symbol:
                    mock_ticker_instance.history.return_value = all_sector_etf_data[sector]
                    return mock_ticker_instance
            # Default empty
            mock_ticker_instance.history.return_value = pd.DataFrame()
            return mock_ticker_instance

        mock_yf.Ticker.side_effect = ticker_side_effect

        # Test each sector
        sectors = [
            'communication',
            'consumer_discretionary',
            'consumer_staples',
            'energy',
            'financials',
            'healthcare',
            'industrials',
            'materials',
            'real_estate',
            'technology',
            'utilities'
        ]

        for sector in sectors:
            result = get_sector_etf_data(sector, '2024-01-01', '2024-10-31')
            assert isinstance(result, pd.DataFrame), f"Sector {sector} failed"
            assert len(result) > 0, f"Sector {sector} returned empty data"
            assert 'Close' in result.columns, f"Sector {sector} missing Close column"

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_portfolio_level_analytics(
        self,
        mock_yf,
        yfinance_spy_data,
        all_sector_etf_data
    ):
        """
        Test portfolio-level analytics across all sectors.

        Workflow:
        1. Fetch all sector ETFs
        2. Calculate correlation matrix with SPY
        3. Calculate beta for each sector
        4. Identify high/low correlation sectors
        """
        from tradingagents.dataflows.benchmark import (
            get_spy_data,
            get_sector_etf_data,
            calculate_rolling_correlation,
            calculate_beta,
            SECTOR_ETFS
        )

        # Setup mocks
        def ticker_side_effect(symbol):
            mock_ticker_instance = MagicMock()
            if symbol == 'SPY':
                mock_ticker_instance.history.return_value = yfinance_spy_data
            else:
                # Find sector for this symbol
                for sector, etf_symbol in SECTOR_ETFS.items():
                    if etf_symbol == symbol:
                        mock_ticker_instance.history.return_value = all_sector_etf_data[sector]
                        break
            return mock_ticker_instance

        mock_yf.Ticker.side_effect = ticker_side_effect

        # Fetch SPY
        spy_data = get_spy_data('2024-01-01', '2024-10-31')
        assert isinstance(spy_data, pd.DataFrame)

        # Calculate analytics for each sector
        sector_analytics = {}

        for sector in all_sector_etf_data.keys():
            sector_data = get_sector_etf_data(sector, '2024-01-01', '2024-10-31')

            if isinstance(sector_data, pd.DataFrame) and len(sector_data) > 0:
                # Calculate correlation
                correlation = calculate_rolling_correlation(
                    sector_data,
                    spy_data,
                    window=63
                )

                # Calculate beta
                beta = calculate_beta(sector_data, spy_data, window=252)

                sector_analytics[sector] = {
                    'avg_correlation': correlation.dropna().mean() if isinstance(correlation, pd.Series) else None,
                    'beta': beta if isinstance(beta, float) else None
                }

        # Assert we got analytics for all sectors
        assert len(sector_analytics) == 11, "Should have analytics for all 11 sectors"

        # Assert all analytics are valid
        for sector, analytics in sector_analytics.items():
            if analytics['avg_correlation'] is not None:
                assert -1.0 <= analytics['avg_correlation'] <= 1.0, \
                    f"Correlation for {sector} out of range"

            if analytics['beta'] is not None:
                assert not np.isnan(analytics['beta']), \
                    f"Beta for {sector} is NaN"
                # Beta can be high for synthetic test data with varying volatility
                assert -10 < analytics['beta'] < 10, \
                    f"Beta for {sector} out of reasonable range"

        # Identify high correlation sectors (should correlate well with SPY)
        high_corr_sectors = [
            sector for sector, analytics in sector_analytics.items()
            if analytics['avg_correlation'] is not None and analytics['avg_correlation'] > 0.7
        ]

        # Most sectors should have positive correlation with market
        assert len(high_corr_sectors) >= 1, "At least one sector should correlate with SPY"


# ============================================================================
# Test Class: Real-World Data Format Handling
# ============================================================================

class TestRealWorldDataFormat:
    """
    Test suite for handling real-world data format quirks.

    Tests:
    - Timezone-aware DatetimeIndex
    - Column name variations
    - Missing data handling
    - Date range alignment
    """

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_timezone_aware_data(self, mock_yf):
        """Test handling of timezone-aware yfinance data."""
        from tradingagents.dataflows.benchmark import get_benchmark_data

        # Create timezone-aware data
        dates = pd.date_range('2024-01-01', periods=300, freq='D', tz='America/New_York')
        tz_data = pd.DataFrame({
            'Open': [100.0 + i * 0.1 for i in range(300)],
            'High': [101.0 + i * 0.1 for i in range(300)],
            'Low': [99.0 + i * 0.1 for i in range(300)],
            'Close': [100.5 + i * 0.1 for i in range(300)],
            'Volume': [1000000] * 300,
        }, index=dates)

        # Setup mock
        mock_ticker_instance = MagicMock()
        mock_yf.Ticker.return_value = mock_ticker_instance
        mock_ticker_instance.history.return_value = tz_data

        # Execute
        result = get_benchmark_data('SPY', '2024-01-01', '2024-10-31')

        # Assert - should handle timezone-aware data
        assert isinstance(result, pd.DataFrame)
        assert len(result) > 0

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_business_day_frequency(self, mock_yf):
        """Test handling of business day frequency data (no weekends)."""
        from tradingagents.dataflows.benchmark import get_benchmark_data, calculate_relative_strength

        # Create business day data
        dates = pd.bdate_range('2024-01-01', periods=250, freq='B')

        spy_data = pd.DataFrame({
            'Open': [450.0 + i * 0.3 for i in range(250)],
            'High': [452.0 + i * 0.3 for i in range(250)],
            'Low': [449.0 + i * 0.3 for i in range(250)],
            'Close': [451.0 + i * 0.3 for i in range(250)],
            'Volume': [80000000] * 250,
        }, index=dates)

        stock_data = pd.DataFrame({
            'Open': [180.0 + i * 0.4 for i in range(250)],
            'High': [182.0 + i * 0.4 for i in range(250)],
            'Low': [179.0 + i * 0.4 for i in range(250)],
            'Close': [181.0 + i * 0.4 for i in range(250)],
            'Volume': [50000000] * 250,
        }, index=dates)

        # Setup mock
        def ticker_side_effect(symbol):
            mock_ticker_instance = MagicMock()
            if symbol == 'SPY':
                mock_ticker_instance.history.return_value = spy_data
            else:
                mock_ticker_instance.history.return_value = stock_data
            return mock_ticker_instance

        mock_yf.Ticker.side_effect = ticker_side_effect

        # Fetch data
        result_spy = get_benchmark_data('SPY', '2024-01-01', '2024-12-31')
        result_stock = get_benchmark_data('AAPL', '2024-01-01', '2024-12-31')

        # Calculate RS
        rs = calculate_relative_strength(result_stock, result_spy)

        # Assert - should handle business days correctly
        assert isinstance(rs, float)
        assert not np.isnan(rs)

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_date_range_alignment(self, mock_yf):
        """Test automatic date range alignment between stock and benchmark."""
        from tradingagents.dataflows.benchmark import calculate_relative_strength

        # Create overlapping but not identical date ranges
        spy_dates = pd.date_range('2024-01-01', periods=300, freq='D')
        stock_dates = pd.date_range('2024-01-15', periods=280, freq='D')  # Starts 14 days later

        spy_data = pd.DataFrame({
            'Close': [450.0 + i * 0.3 for i in range(300)],
            'Volume': [80000000] * 300,
        }, index=spy_dates)

        stock_data = pd.DataFrame({
            'Close': [180.0 + i * 0.4 for i in range(280)],
            'Volume': [50000000] * 280,
        }, index=stock_dates)

        # Add other required columns
        for df in [spy_data, stock_data]:
            df['Open'] = df['Close'] - 0.5
            df['High'] = df['Close'] + 1.0
            df['Low'] = df['Close'] - 1.0

        # Execute RS calculation - should align dates internally
        result = calculate_relative_strength(stock_data, spy_data)

        # Assert - should handle date alignment
        # Either returns valid RS or error message
        if isinstance(result, float):
            assert not np.isnan(result)
        else:
            assert isinstance(result, str)
tests/unit/dataflows/test_benchmark.py (new file)

@@ -0,0 +1,753 @@

"""
Test suite for Benchmark Data Functions (benchmark.py).

This module tests:
1. get_benchmark_data() - Generic benchmark data fetcher via yfinance
2. get_spy_data() - Convenience wrapper for SPY
3. get_sector_etf_data() - Sector ETF data (XLF, XLK, XLE, etc.)
4. calculate_relative_strength() - IBD-style RS calculation
5. calculate_rolling_correlation() - Rolling correlation between stock and benchmark
6. calculate_beta() - Beta calculation (Cov/Var)

Test Coverage:
- Unit tests for each function
- Valid data fetching (SPY, sector ETFs)
- Invalid inputs (bad symbols, dates, sectors)
- RS calculation with IBD formula: 0.4*ROC(63) + 0.2*ROC(126) + 0.2*ROC(189) + 0.2*ROC(252)
- Rolling correlation with configurable window
- Beta calculation with market variance
- Edge cases (empty data, insufficient data, missing columns, date misalignment)
- Zero returns handling
- Extreme values

SECTOR_ETFS Constants:
- communication: XLC
- consumer_discretionary: XLY
- consumer_staples: XLP
- energy: XLE
- financials: XLF
- healthcare: XLV
- industrials: XLI
- materials: XLB
- real_estate: XLRE
- technology: XLK
- utilities: XLU
"""

import pytest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock

pytestmark = pytest.mark.unit


# ============================================================================
# Fixtures
# ============================================================================

@pytest.fixture
def sample_spy_data():
    """
    Create 300 days of sample SPY OHLCV data.

    Returns a DataFrame with DatetimeIndex and columns: Open, High, Low, Close, Volume.
    Simulates realistic market data with upward trend.
    """
    dates = pd.date_range('2024-01-01', periods=300, freq='D')

    data = []
    base_price = 450.0

    for i, date in enumerate(dates):
        # Simulate gradual upward trend with some volatility
        trend = i * 0.3
        volatility = np.sin(i / 10) * 2

        open_price = base_price + trend + volatility
        high_price = open_price + 1.5 + abs(np.cos(i / 5))
        low_price = open_price - 1.5 - abs(np.sin(i / 7))
        close_price = open_price + 0.5 + np.sin(i / 3) * 0.5
        volume = 80000000 + i * 100000

        data.append({
            'Open': round(open_price, 2),
            'High': round(high_price, 2),
            'Low': round(low_price, 2),
            'Close': round(close_price, 2),
            'Volume': volume
        })

    df = pd.DataFrame(data, index=dates)
    return df


@pytest.fixture
def sample_stock_data():
    """
    Create 300 days of sample stock OHLCV data (AAPL-like pattern).

    Returns a DataFrame with DatetimeIndex and columns: Open, High, Low, Close, Volume.
    Simulates tech stock with higher volatility than SPY.
    """
    dates = pd.date_range('2024-01-01', periods=300, freq='D')

    data = []
    base_price = 180.0

    for i, date in enumerate(dates):
        # Higher volatility and stronger trend than SPY
        trend = i * 0.4
        volatility = np.sin(i / 8) * 3

        open_price = base_price + trend + volatility
        high_price = open_price + 2.0 + abs(np.cos(i / 4))
        low_price = open_price - 2.0 - abs(np.sin(i / 6))
        close_price = open_price + 0.8 + np.sin(i / 2.5) * 0.8
        volume = 50000000 + i * 80000

        data.append({
            'Open': round(open_price, 2),
            'High': round(high_price, 2),
            'Low': round(low_price, 2),
            'Close': round(close_price, 2),
            'Volume': volume
        })

    df = pd.DataFrame(data, index=dates)
    return df


@pytest.fixture
def sample_sector_data():
    """
    Create 300 days of sample XLK (technology sector) OHLCV data.

    Returns a DataFrame with DatetimeIndex and columns: Open, High, Low, Close, Volume.
    Similar to SPY but with sector-specific characteristics.
    """
    dates = pd.date_range('2024-01-01', periods=300, freq='D')

    data = []
    base_price = 200.0

    for i, date in enumerate(dates):
        # Tech sector - moderate trend with cyclical volatility
        trend = i * 0.35
        volatility = np.sin(i / 12) * 2.5

        open_price = base_price + trend + volatility
        high_price = open_price + 1.8 + abs(np.cos(i / 5.5))
        low_price = open_price - 1.8 - abs(np.sin(i / 6.5))
        close_price = open_price + 0.6 + np.sin(i / 3.5) * 0.6
        volume = 10000000 + i * 50000

        data.append({
            'Open': round(open_price, 2),
            'High': round(high_price, 2),
            'Low': round(low_price, 2),
            'Close': round(close_price, 2),
            'Volume': volume
        })

    df = pd.DataFrame(data, index=dates)
    return df
|
'Volume': volume
|
||||||
|
})
|
||||||
|
|
||||||
|
df = pd.DataFrame(data, index=dates)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def empty_dataframe():
|
||||||
|
"""Create empty DataFrame for validation testing."""
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def insufficient_data():
|
||||||
|
"""
|
||||||
|
Create DataFrame with only 50 days (insufficient for full RS calculation).
|
||||||
|
|
||||||
|
Insufficient for 252-day ROC calculation in relative strength.
|
||||||
|
"""
|
||||||
|
dates = pd.date_range('2024-01-01', periods=50, freq='D')
|
||||||
|
|
||||||
|
data = []
|
||||||
|
base_price = 100.0
|
||||||
|
|
||||||
|
for i, date in enumerate(dates):
|
||||||
|
open_price = base_price + i * 0.2
|
||||||
|
data.append({
|
||||||
|
'Open': round(open_price, 2),
|
||||||
|
'High': round(open_price + 1.0, 2),
|
||||||
|
'Low': round(open_price - 0.5, 2),
|
||||||
|
'Close': round(open_price + 0.3, 2),
|
||||||
|
'Volume': 1000000
|
||||||
|
})
|
||||||
|
|
||||||
|
df = pd.DataFrame(data, index=dates)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def missing_close_data():
|
||||||
|
"""Create DataFrame missing Close column."""
|
||||||
|
dates = pd.date_range('2024-01-01', periods=100, freq='D')
|
||||||
|
return pd.DataFrame({
|
||||||
|
'Open': [100.0 + i * 0.1 for i in range(100)],
|
||||||
|
'High': [102.0 + i * 0.1 for i in range(100)],
|
||||||
|
'Low': [99.0 + i * 0.1 for i in range(100)],
|
||||||
|
'Volume': [1000000] * 100,
|
||||||
|
}, index=dates)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def misaligned_dates_data(sample_spy_data):
|
||||||
|
"""
|
||||||
|
Create stock data with different date range than SPY.
|
||||||
|
|
||||||
|
50-day offset to test date alignment handling.
|
||||||
|
"""
|
||||||
|
offset_dates = pd.date_range('2024-02-20', periods=250, freq='D')
|
||||||
|
|
||||||
|
data = []
|
||||||
|
base_price = 150.0
|
||||||
|
|
||||||
|
for i, date in enumerate(offset_dates):
|
||||||
|
open_price = base_price + i * 0.2
|
||||||
|
data.append({
|
||||||
|
'Open': round(open_price, 2),
|
||||||
|
'High': round(open_price + 1.5, 2),
|
||||||
|
'Low': round(open_price - 1.0, 2),
|
||||||
|
'Close': round(open_price + 0.5, 2),
|
||||||
|
'Volume': 2000000
|
||||||
|
})
|
||||||
|
|
||||||
|
return pd.DataFrame(data, index=offset_dates)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def zero_returns_data():
|
||||||
|
"""Create DataFrame with zero returns (flat prices)."""
|
||||||
|
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||||
|
|
||||||
|
# All prices are constant (no returns)
|
||||||
|
constant_price = 100.0
|
||||||
|
|
||||||
|
data = []
|
||||||
|
for date in dates:
|
||||||
|
data.append({
|
||||||
|
'Open': constant_price,
|
||||||
|
'High': constant_price,
|
||||||
|
'Low': constant_price,
|
||||||
|
'Close': constant_price,
|
||||||
|
'Volume': 1000000
|
||||||
|
})
|
||||||
|
|
||||||
|
df = pd.DataFrame(data, index=dates)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def extreme_values_data():
|
||||||
|
"""Create DataFrame with extreme price values."""
|
||||||
|
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||||
|
|
||||||
|
data = []
|
||||||
|
for i, date in enumerate(dates):
|
||||||
|
# Extreme volatility: 50% daily moves
|
||||||
|
if i % 2 == 0:
|
||||||
|
close = 100.0 * (1.5 ** (i // 2))
|
||||||
|
else:
|
||||||
|
close = 100.0 * (1.5 ** (i // 2)) * 0.5
|
||||||
|
|
||||||
|
data.append({
|
||||||
|
'Open': close,
|
||||||
|
'High': close * 1.1,
|
||||||
|
'Low': close * 0.9,
|
||||||
|
'Close': close,
|
||||||
|
'Volume': 1000000
|
||||||
|
})
|
||||||
|
|
||||||
|
df = pd.DataFrame(data, index=dates)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
# Test Class: Benchmark Data Fetching
# ============================================================================


class TestBenchmarkDataFetching:
    """
    Test suite for benchmark data fetching functions.

    Tests:
    - get_benchmark_data() with valid symbols
    - get_spy_data() convenience wrapper
    - get_sector_etf_data() with valid sectors
    - Invalid symbol handling
    - Invalid sector handling
    - Invalid date handling
    - Empty data handling
    """

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_get_benchmark_data_valid_spy(self, mock_yf, sample_spy_data):
        """Test fetching valid SPY benchmark data."""
        from tradingagents.dataflows.benchmark import get_benchmark_data

        # Setup mock
        mock_ticker_instance = MagicMock()
        mock_yf.Ticker.return_value = mock_ticker_instance
        mock_ticker_instance.history.return_value = sample_spy_data

        # Execute
        result = get_benchmark_data('SPY', '2024-01-01', '2024-10-31')

        # Assert
        assert isinstance(result, pd.DataFrame)
        assert len(result) > 0
        assert 'Close' in result.columns
        assert 'Volume' in result.columns
        mock_yf.Ticker.assert_called_once_with('SPY')
        mock_ticker_instance.history.assert_called_once()

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_get_benchmark_data_valid_sector_etf(self, mock_yf, sample_sector_data):
        """Test fetching valid sector ETF data (XLK)."""
        from tradingagents.dataflows.benchmark import get_benchmark_data

        # Setup mock
        mock_ticker_instance = MagicMock()
        mock_yf.Ticker.return_value = mock_ticker_instance
        mock_ticker_instance.history.return_value = sample_sector_data

        # Execute
        result = get_benchmark_data('XLK', '2024-01-01', '2024-10-31')

        # Assert
        assert isinstance(result, pd.DataFrame)
        assert len(result) > 0
        assert 'Close' in result.columns
        mock_yf.Ticker.assert_called_once_with('XLK')

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_get_benchmark_data_invalid_symbol(self, mock_yf):
        """Test handling of invalid ticker symbol."""
        from tradingagents.dataflows.benchmark import get_benchmark_data

        # Setup mock to raise exception
        mock_ticker_instance = MagicMock()
        mock_yf.Ticker.return_value = mock_ticker_instance
        mock_ticker_instance.history.side_effect = Exception("Invalid ticker")

        # Execute
        result = get_benchmark_data('INVALID_SYMBOL', '2024-01-01', '2024-10-31')

        # Assert - should return error string
        assert isinstance(result, str)
        assert 'error' in result.lower() or 'invalid' in result.lower()

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_get_benchmark_data_empty_data(self, mock_yf):
        """Test handling when yfinance returns empty DataFrame."""
        from tradingagents.dataflows.benchmark import get_benchmark_data

        # Setup mock to return empty DataFrame
        mock_ticker_instance = MagicMock()
        mock_yf.Ticker.return_value = mock_ticker_instance
        mock_ticker_instance.history.return_value = pd.DataFrame()

        # Execute
        result = get_benchmark_data('SPY', '2024-01-01', '2024-01-02')

        # Assert
        assert isinstance(result, str)
        assert 'no data' in result.lower() or 'empty' in result.lower()

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_get_spy_data(self, mock_yf, sample_spy_data):
        """Test get_spy_data() convenience wrapper."""
        from tradingagents.dataflows.benchmark import get_spy_data

        # Setup mock
        mock_ticker_instance = MagicMock()
        mock_yf.Ticker.return_value = mock_ticker_instance
        mock_ticker_instance.history.return_value = sample_spy_data

        # Execute
        result = get_spy_data('2024-01-01', '2024-10-31')

        # Assert
        assert isinstance(result, pd.DataFrame)
        assert len(result) > 0
        mock_yf.Ticker.assert_called_once_with('SPY')

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_get_sector_etf_data_valid(self, mock_yf, sample_sector_data):
        """Test get_sector_etf_data() with valid sector."""
        from tradingagents.dataflows.benchmark import get_sector_etf_data

        # Setup mock
        mock_ticker_instance = MagicMock()
        mock_yf.Ticker.return_value = mock_ticker_instance
        mock_ticker_instance.history.return_value = sample_sector_data

        # Execute
        result = get_sector_etf_data('technology', '2024-01-01', '2024-10-31')

        # Assert
        assert isinstance(result, pd.DataFrame)
        assert len(result) > 0
        mock_yf.Ticker.assert_called_once_with('XLK')

    def test_get_sector_etf_data_invalid_sector(self):
        """Test get_sector_etf_data() with invalid sector name."""
        from tradingagents.dataflows.benchmark import get_sector_etf_data

        # Execute
        result = get_sector_etf_data('invalid_sector', '2024-01-01', '2024-10-31')

        # Assert
        assert isinstance(result, str)
        assert 'invalid' in result.lower() or 'unknown' in result.lower()

    @patch('tradingagents.dataflows.benchmark.yf')
    def test_get_benchmark_data_invalid_dates(self, mock_yf):
        """Test handling of invalid date format."""
        from tradingagents.dataflows.benchmark import get_benchmark_data

        # Setup mock to raise exception on invalid dates
        mock_ticker_instance = MagicMock()
        mock_yf.Ticker.return_value = mock_ticker_instance
        mock_ticker_instance.history.side_effect = ValueError("Invalid date format")

        # Execute
        result = get_benchmark_data('SPY', 'invalid-date', '2024-10-31')

        # Assert
        assert isinstance(result, str)
        assert 'error' in result.lower() or 'invalid' in result.lower()


# ============================================================================
# Test Class: Relative Strength Calculation
# ============================================================================


class TestRelativeStrength:
    """
    Test suite for IBD-style relative strength calculation.

    Tests:
    - RS calculation with IBD formula: 0.4*ROC(63) + 0.2*ROC(126) + 0.2*ROC(189) + 0.2*ROC(252)
    - Insufficient data handling (< 252 days)
    - Missing Close column
    - Date misalignment between stock and benchmark
    - Zero returns (flat prices)

    IBD Relative Strength:
    - 40% weight on 63-day (3-month) ROC
    - 20% weight on 126-day (6-month) ROC
    - 20% weight on 189-day (9-month) ROC
    - 20% weight on 252-day (12-month) ROC
    """

    def test_calculate_relative_strength_valid(self, sample_stock_data, sample_spy_data):
        """Test RS calculation with valid data."""
        from tradingagents.dataflows.benchmark import calculate_relative_strength

        # Execute
        result = calculate_relative_strength(sample_stock_data, sample_spy_data)

        # Assert
        assert isinstance(result, float)
        assert not np.isnan(result)
        assert not np.isinf(result)
        # RS should be reasonable (typically between -100 and 100 for normal stocks)
        assert -200 < result < 200

    def test_calculate_relative_strength_custom_periods(self, sample_stock_data, sample_spy_data):
        """Test RS calculation with custom periods."""
        from tradingagents.dataflows.benchmark import calculate_relative_strength

        # Execute with shorter periods (all data has 300 days)
        result = calculate_relative_strength(
            sample_stock_data,
            sample_spy_data,
            periods=[20, 60, 120, 180]
        )

        # Assert
        assert isinstance(result, float)
        assert not np.isnan(result)
        assert -200 < result < 200

    def test_calculate_relative_strength_insufficient_data(self, insufficient_data, sample_spy_data):
        """Test RS calculation with insufficient data (< 252 days)."""
        from tradingagents.dataflows.benchmark import calculate_relative_strength

        # Execute - only 50 days, need 252 for default periods
        result = calculate_relative_strength(insufficient_data, sample_spy_data)

        # Assert - should return error string
        assert isinstance(result, str)
        assert 'insufficient' in result.lower() or 'not enough' in result.lower()

    def test_calculate_relative_strength_missing_close(self, missing_close_data, sample_spy_data):
        """Test RS calculation with missing Close column."""
        from tradingagents.dataflows.benchmark import calculate_relative_strength

        # Execute
        result = calculate_relative_strength(missing_close_data, sample_spy_data)

        # Assert
        assert isinstance(result, str)
        assert 'close' in result.lower() or 'missing' in result.lower()

    def test_calculate_relative_strength_date_misalignment(self, misaligned_dates_data, sample_spy_data):
        """Test RS calculation with misaligned date ranges."""
        from tradingagents.dataflows.benchmark import calculate_relative_strength

        # Execute - stock data starts 50 days later than SPY
        result = calculate_relative_strength(misaligned_dates_data, sample_spy_data)

        # Assert - function should handle alignment or return error
        # Either valid RS (if aligned) or error message
        if isinstance(result, str):
            assert 'align' in result.lower() or 'date' in result.lower() or 'insufficient' in result.lower()
        else:
            assert isinstance(result, float)
            assert not np.isnan(result)

    def test_calculate_relative_strength_zero_returns(self, zero_returns_data, sample_spy_data):
        """Test RS calculation with zero returns (flat prices)."""
        from tradingagents.dataflows.benchmark import calculate_relative_strength

        # Execute
        result = calculate_relative_strength(zero_returns_data, sample_spy_data)

        # Assert - should handle zero returns gracefully
        # RS should be negative since stock has 0 returns while benchmark moves
        if isinstance(result, float):
            assert not np.isnan(result)
            assert result < 0  # Stock underperforming benchmark
        else:
            # Or return error for zero variance
            assert isinstance(result, str)


# ============================================================================
# Test Class: Correlation Analytics
# ============================================================================


class TestCorrelationAnalytics:
    """
    Test suite for correlation and beta calculations.

    Tests:
    - calculate_rolling_correlation() with various windows
    - calculate_beta() calculation
    - Window validation (must be >= 2)
    - Insufficient data for window size
    """

    def test_calculate_rolling_correlation_valid(self, sample_stock_data, sample_spy_data):
        """Test rolling correlation calculation with default window (63 days)."""
        from tradingagents.dataflows.benchmark import calculate_rolling_correlation

        # Execute
        result = calculate_rolling_correlation(sample_stock_data, sample_spy_data)

        # Assert
        assert isinstance(result, pd.Series)
        assert len(result) > 0
        # Correlation values should be between -1 and 1
        assert (result.dropna() >= -1.0).all()
        assert (result.dropna() <= 1.0).all()

    def test_calculate_rolling_correlation_custom_window(self, sample_stock_data, sample_spy_data):
        """Test rolling correlation with custom window size."""
        from tradingagents.dataflows.benchmark import calculate_rolling_correlation

        # Execute with 20-day window
        result = calculate_rolling_correlation(sample_stock_data, sample_spy_data, window=20)

        # Assert
        assert isinstance(result, pd.Series)
        assert len(result) > 0
        assert (result.dropna() >= -1.0).all()
        assert (result.dropna() <= 1.0).all()

    def test_calculate_rolling_correlation_invalid_window(self, sample_stock_data, sample_spy_data):
        """Test rolling correlation with invalid window (< 2)."""
        from tradingagents.dataflows.benchmark import calculate_rolling_correlation

        # Execute with window=1
        result = calculate_rolling_correlation(sample_stock_data, sample_spy_data, window=1)

        # Assert - should return error
        assert isinstance(result, str)
        assert 'window' in result.lower() or 'invalid' in result.lower()

    def test_calculate_rolling_correlation_insufficient_data(self, insufficient_data, sample_spy_data):
        """Test rolling correlation with insufficient data for window."""
        from tradingagents.dataflows.benchmark import calculate_rolling_correlation

        # Execute - only 50 days but default window is 63
        result = calculate_rolling_correlation(insufficient_data, sample_spy_data)

        # Assert - should return error or empty series
        if isinstance(result, str):
            assert 'insufficient' in result.lower() or 'not enough' in result.lower()
        elif isinstance(result, pd.Series):
            # May return series with all NaN values
            assert result.dropna().empty or len(result.dropna()) < 10

    def test_calculate_beta_valid(self, sample_stock_data, sample_spy_data):
        """Test beta calculation with default window (252 days)."""
        from tradingagents.dataflows.benchmark import calculate_beta

        # Execute
        result = calculate_beta(sample_stock_data, sample_spy_data)

        # Assert
        assert isinstance(result, float)
        assert not np.isnan(result)
        assert not np.isinf(result)
        # Beta typically ranges from -2 to 3 for normal stocks
        assert -5 < result < 5

    def test_calculate_beta_custom_window(self, sample_stock_data, sample_spy_data):
        """Test beta calculation with custom window."""
        from tradingagents.dataflows.benchmark import calculate_beta

        # Execute with 126-day window
        result = calculate_beta(sample_stock_data, sample_spy_data, window=126)

        # Assert
        assert isinstance(result, float)
        assert not np.isnan(result)
        assert -5 < result < 5

    def test_calculate_beta_insufficient_data(self, insufficient_data, sample_spy_data):
        """Test beta calculation with insufficient data."""
        from tradingagents.dataflows.benchmark import calculate_beta

        # Execute - only 50 days but default window is 252
        result = calculate_beta(insufficient_data, sample_spy_data)

        # Assert
        assert isinstance(result, str)
        assert 'insufficient' in result.lower() or 'not enough' in result.lower()

    def test_calculate_beta_zero_variance(self, zero_returns_data, sample_spy_data):
        """Test beta calculation when stock has zero variance."""
        from tradingagents.dataflows.benchmark import calculate_beta

        # Execute
        result = calculate_beta(zero_returns_data, sample_spy_data)

        # Assert - should handle zero variance
        if isinstance(result, float):
            assert result == 0.0 or np.isnan(result)
        else:
            assert isinstance(result, str)


# ============================================================================
# Test Class: Edge Cases
# ============================================================================


class TestEdgeCases:
    """
    Test suite for edge cases and boundary conditions.

    Tests:
    - Empty DataFrames
    - Single day data
    - Extreme values
    - Date index validation
    """

    def test_calculate_relative_strength_empty_stock(self, empty_dataframe, sample_spy_data):
        """Test RS calculation with empty stock DataFrame."""
        from tradingagents.dataflows.benchmark import calculate_relative_strength

        # Execute
        result = calculate_relative_strength(empty_dataframe, sample_spy_data)

        # Assert
        assert isinstance(result, str)
        assert 'empty' in result.lower() or 'no data' in result.lower()

    def test_calculate_relative_strength_empty_benchmark(self, sample_stock_data, empty_dataframe):
        """Test RS calculation with empty benchmark DataFrame."""
        from tradingagents.dataflows.benchmark import calculate_relative_strength

        # Execute
        result = calculate_relative_strength(sample_stock_data, empty_dataframe)

        # Assert
        assert isinstance(result, str)
        assert 'empty' in result.lower() or 'no data' in result.lower()

    def test_calculate_rolling_correlation_single_day(self, sample_spy_data):
        """Test rolling correlation with single day data."""
        from tradingagents.dataflows.benchmark import calculate_rolling_correlation

        # Create single day data
        single_day = sample_spy_data.iloc[[0]]

        # Execute
        result = calculate_rolling_correlation(single_day, sample_spy_data)

        # Assert
        assert isinstance(result, str) or (isinstance(result, pd.Series) and result.dropna().empty)

    def test_calculate_beta_extreme_values(self, extreme_values_data, sample_spy_data):
        """Test beta calculation with extreme price movements."""
        from tradingagents.dataflows.benchmark import calculate_beta

        # Execute
        result = calculate_beta(extreme_values_data, sample_spy_data)

        # Assert - should handle extreme values
        if isinstance(result, float):
            assert not np.isnan(result)
            # Beta can be very high for extreme volatility
            assert -100 < result < 100
        else:
            # Or return error for numerical issues
            assert isinstance(result, str)

    def test_get_benchmark_data_no_datetime_index(self):
        """Test that fetched data has proper DatetimeIndex."""
        # This tests that the implementation converts yfinance data correctly
        # Will be tested in integration tests with actual yfinance calls
        pass

    def test_sector_etf_constants_coverage(self):
        """Test that all expected sector ETFs are defined in SECTOR_ETFS constant."""
        from tradingagents.dataflows.benchmark import SECTOR_ETFS

        # Expected sectors
        expected_sectors = [
            'communication', 'consumer_discretionary', 'consumer_staples',
            'energy', 'financials', 'healthcare', 'industrials',
            'materials', 'real_estate', 'technology', 'utilities'
        ]

        # Assert all sectors exist
        for sector in expected_sectors:
            assert sector in SECTOR_ETFS, f"Missing sector: {sector}"

        # Assert expected symbols
        assert SECTOR_ETFS['technology'] == 'XLK'
        assert SECTOR_ETFS['financials'] == 'XLF'
        assert SECTOR_ETFS['energy'] == 'XLE'
        assert SECTOR_ETFS['healthcare'] == 'XLV'
        assert SECTOR_ETFS['industrials'] == 'XLI'
        assert SECTOR_ETFS['materials'] == 'XLB'
        assert SECTOR_ETFS['consumer_discretionary'] == 'XLY'
        assert SECTOR_ETFS['consumer_staples'] == 'XLP'
        assert SECTOR_ETFS['real_estate'] == 'XLRE'
        assert SECTOR_ETFS['utilities'] == 'XLU'
        assert SECTOR_ETFS['communication'] == 'XLC'
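The RS tests above assert an IBD-style weighted ROC blend (40% on the 63-day ROC, 20% each on 126/189/252-day). As a minimal sketch of how that weighting combines — assuming RS is the weighted sum of stock-minus-benchmark ROC spreads in percent; `ibd_relative_strength_sketch` is a hypothetical helper for illustration, not the module's `calculate_relative_strength`, which may normalize differently:

```python
import pandas as pd

def ibd_relative_strength_sketch(stock_close: pd.Series, bench_close: pd.Series) -> float:
    """Illustrative IBD-style RS: weighted sum of stock-vs-benchmark ROC spreads."""
    weights = {63: 0.4, 126: 0.2, 189: 0.2, 252: 0.2}  # period (days) -> weight
    rs = 0.0
    for period, weight in weights.items():
        # Rate of change over `period` trading days, in percent
        stock_roc = (stock_close.iloc[-1] / stock_close.iloc[-1 - period] - 1) * 100
        bench_roc = (bench_close.iloc[-1] / bench_close.iloc[-1 - period] - 1) * 100
        rs += weight * (stock_roc - bench_roc)
    return float(rs)
```

A positive value means the stock out-ran the benchmark on a weighted basis over the past year, which is why `test_calculate_relative_strength_zero_returns` expects a negative RS for a flat stock against a rising SPY.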
@@ -0,0 +1,441 @@
"""
|
||||||
|
Benchmark Data Retrieval and Analysis Functions.
|
||||||
|
|
||||||
|
This module provides functions for retrieving and analyzing benchmark data:
|
||||||
|
- Benchmark data fetching (SPY, sector ETFs)
|
||||||
|
- Relative strength calculations (IBD-style)
|
||||||
|
- Rolling correlation analysis
|
||||||
|
- Beta calculations
|
||||||
|
|
||||||
|
All functions return pandas DataFrames/Series/floats on success or error strings on failure.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from tradingagents.dataflows.benchmark import (
|
||||||
|
get_spy_data,
|
||||||
|
get_sector_etf_data,
|
||||||
|
calculate_relative_strength
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get SPY benchmark data
|
||||||
|
spy_data = get_spy_data('2024-01-01', '2024-12-31')
|
||||||
|
|
||||||
|
# Get sector ETF data
|
||||||
|
tech_data = get_sector_etf_data('technology', '2024-01-01', '2024-12-31')
|
||||||
|
|
||||||
|
# Calculate relative strength
|
||||||
|
rs = calculate_relative_strength(stock_data, spy_data)
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- yfinance package: pip install yfinance
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from typing import Union, List
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Try to import yfinance, but allow it to be mocked in tests
|
||||||
|
try:
|
||||||
|
import yfinance as yf
|
||||||
|
except ImportError:
|
||||||
|
yf = None
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# SECTOR ETF Mappings
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
SECTOR_ETFS = {
|
||||||
|
'communication': 'XLC',
|
||||||
|
'consumer_discretionary': 'XLY',
|
||||||
|
'consumer_staples': 'XLP',
|
||||||
|
'energy': 'XLE',
|
||||||
|
'financials': 'XLF',
|
||||||
|
'healthcare': 'XLV',
|
||||||
|
'industrials': 'XLI',
|
||||||
|
'materials': 'XLB',
|
||||||
|
'real_estate': 'XLRE',
|
||||||
|
'technology': 'XLK',
|
||||||
|
'utilities': 'XLU'
|
||||||
|
}
|
# ============================================================================
# Benchmark Data Fetching Functions
# ============================================================================


def get_benchmark_data(
    symbol: str,
    start_date: str,
    end_date: str
) -> Union[pd.DataFrame, str]:
    """
    Fetch benchmark OHLCV data via yfinance.

    Args:
        symbol: Ticker symbol (e.g., 'SPY', 'XLK')
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format

    Returns:
        pd.DataFrame with DatetimeIndex and columns: Open, High, Low, Close, Volume
        str with error message on failure

    Examples:
        >>> data = get_benchmark_data('SPY', '2024-01-01', '2024-12-31')
        >>> data = get_benchmark_data('XLK', '2024-01-01', '2024-12-31')
    """
    if yf is None:
        return "Error: yfinance package is not installed. Install with: pip install yfinance"

    try:
        # Validate date formats
        datetime.strptime(start_date, "%Y-%m-%d")
        datetime.strptime(end_date, "%Y-%m-%d")
    except ValueError as e:
        return f"Error: Invalid date format. Use YYYY-MM-DD. Details: {str(e)}"

    try:
        # Fetch data from yfinance
        ticker = yf.Ticker(symbol)
        data = ticker.history(start=start_date, end=end_date)

        # Check if data is empty
        if data.empty:
            return f"Error: No data found for symbol '{symbol}' between {start_date} and {end_date}"

        # Remove timezone info if present
        if data.index.tz is not None:
            data.index = data.index.tz_localize(None)

        return data

    except Exception as e:
        return f"Error fetching data for {symbol}: {str(e)}"

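Because the fetchers return either a DataFrame or an error string, callers need an `isinstance` check before touching the result. A minimal sketch of that pattern, with the network fetch replaced by a hypothetical stub so it runs offline (`fetch_stub` is illustrative, not part of the module):

```python
import pandas as pd
from typing import Union

def fetch_stub(symbol: str) -> Union[pd.DataFrame, str]:
    # Hypothetical stand-in for get_benchmark_data(): succeeds for 'SPY',
    # returns an error string for anything else.
    if symbol == 'SPY':
        idx = pd.date_range('2024-01-02', periods=3, freq='B')
        return pd.DataFrame({'Close': [470.0, 472.5, 471.0]}, index=idx)
    return f"Error: No data found for symbol '{symbol}'"

result = fetch_stub('SPY')
if isinstance(result, pd.DataFrame):
    print(len(result))   # number of rows fetched
else:
    print(result)        # error message
```

The same check applies to every function in this module that declares `Union[pd.DataFrame, str]` or `Union[float, str]`.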
def get_spy_data(
    start_date: str,
    end_date: str
) -> Union[pd.DataFrame, str]:
    """
    Fetch SPY benchmark data (convenience wrapper).

    Args:
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format

    Returns:
        pd.DataFrame with DatetimeIndex and columns: Open, High, Low, Close, Volume
        str with error message on failure

    Examples:
        >>> spy_data = get_spy_data('2024-01-01', '2024-12-31')
    """
    return get_benchmark_data('SPY', start_date, end_date)

def get_sector_etf_data(
    sector: str,
    start_date: str,
    end_date: str
) -> Union[pd.DataFrame, str]:
    """
    Fetch sector ETF data.

    Args:
        sector: Sector name (e.g., 'technology', 'financials', 'energy')
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format

    Returns:
        pd.DataFrame with DatetimeIndex and columns: Open, High, Low, Close, Volume
        str with error message on failure

    Valid Sectors:
        - communication (XLC)
        - consumer_discretionary (XLY)
        - consumer_staples (XLP)
        - energy (XLE)
        - financials (XLF)
        - healthcare (XLV)
        - industrials (XLI)
        - materials (XLB)
        - real_estate (XLRE)
        - technology (XLK)
        - utilities (XLU)

    Examples:
        >>> tech_data = get_sector_etf_data('technology', '2024-01-01', '2024-12-31')
        >>> finance_data = get_sector_etf_data('financials', '2024-01-01', '2024-12-31')
    """
    # Validate sector
    if sector not in SECTOR_ETFS:
        valid_sectors = ', '.join(sorted(SECTOR_ETFS.keys()))
        return f"Error: Invalid sector '{sector}'. Valid sectors: {valid_sectors}"

    # Get symbol for sector
    symbol = SECTOR_ETFS[sector]

    # Fetch data
    return get_benchmark_data(symbol, start_date, end_date)

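The sector-validation path can be exercised without any network calls by running the lookup against a local copy of the mapping. The dict below mirrors the module-level `SECTOR_ETFS` (all 11 SPDR sector funds), and `resolve_sector` is an illustrative helper with the same validation shape:

```python
SECTOR_ETFS = {
    'communication': 'XLC', 'consumer_discretionary': 'XLY',
    'consumer_staples': 'XLP', 'energy': 'XLE', 'financials': 'XLF',
    'healthcare': 'XLV', 'industrials': 'XLI', 'materials': 'XLB',
    'real_estate': 'XLRE', 'technology': 'XLK', 'utilities': 'XLU',
}

def resolve_sector(sector: str) -> str:
    # Same shape as the check in get_sector_etf_data(): unknown sectors
    # come back as an error string listing the valid keys.
    if sector not in SECTOR_ETFS:
        valid = ', '.join(sorted(SECTOR_ETFS.keys()))
        return f"Error: Invalid sector '{sector}'. Valid sectors: {valid}"
    return SECTOR_ETFS[sector]

print(resolve_sector('technology'))   # XLK
print(resolve_sector('crypto'))       # Error: Invalid sector 'crypto'. ...
```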
# ============================================================================
# Relative Strength Calculation
# ============================================================================


def calculate_relative_strength(
    stock_data: pd.DataFrame,
    benchmark_data: pd.DataFrame,
    periods: List[int] = [63, 126, 189, 252]
) -> Union[float, str]:
    """
    Calculate IBD-style relative strength.

    Uses IBD formula with weighted rate of change (ROC) calculations:
    - 40% weight on 63-day (3-month) ROC
    - 20% weight on 126-day (6-month) ROC
    - 20% weight on 189-day (9-month) ROC
    - 20% weight on 252-day (12-month) ROC

    Args:
        stock_data: DataFrame with 'Close' column
        benchmark_data: DataFrame with 'Close' column
        periods: List of periods for ROC calculation (default: [63, 126, 189, 252])

    Returns:
        float: Relative strength score (stock RS - benchmark RS)
            Positive = stock outperforming benchmark
            Negative = stock underperforming benchmark
        str: Error message on failure

    Examples:
        >>> rs = calculate_relative_strength(stock_data, spy_data)
        >>> rs = calculate_relative_strength(stock_data, spy_data, periods=[20, 60, 120, 180])
    """
    # Validate inputs
    if stock_data.empty:
        return "Error: Stock data is empty"

    if benchmark_data.empty:
        return "Error: Benchmark data is empty"

    if 'Close' not in stock_data.columns:
        return "Error: Stock data missing 'Close' column"

    if 'Close' not in benchmark_data.columns:
        return "Error: Benchmark data missing 'Close' column"

    try:
        # Align dates via inner join
        aligned = pd.DataFrame({
            'stock_close': stock_data['Close'],
            'benchmark_close': benchmark_data['Close']
        }).dropna()

        if aligned.empty:
            return "Error: No overlapping dates between stock and benchmark data"

        # Check sufficient data for the longest period.
        # Allow some flexibility for trading days (~250-252 trading days per year):
        # require at least 98% of the period (e.g., 250 days for a 252-day period).
        max_period = max(periods)
        min_required = int(max_period * 0.98)
        if len(aligned) < min_required:
            return f"Error: Insufficient data. Need at least {min_required} days, have {len(aligned)}"

        # Calculate ROC for each period
        stock_rocs = []
        benchmark_rocs = []

        for period in periods:
            # ROC = (close / close.shift(period)) - 1
            # Use min of period and available data for flexibility with trading days
            actual_period = min(period, len(aligned) - 1)
            stock_roc = (aligned['stock_close'] / aligned['stock_close'].shift(actual_period)) - 1
            benchmark_roc = (aligned['benchmark_close'] / aligned['benchmark_close'].shift(actual_period)) - 1

            # Get the most recent ROC value
            stock_rocs.append(stock_roc.iloc[-1])
            benchmark_rocs.append(benchmark_roc.iloc[-1])

        # Check for NaN values
        if any(np.isnan(stock_rocs)) or any(np.isnan(benchmark_rocs)):
            return "Error: Unable to calculate ROC for all periods (NaN values)"

        # Apply IBD weighting (0.4, 0.2, 0.2, 0.2). The weights pair positionally
        # with the four default periods, so a custom `periods` list should also
        # contain four entries.
        weights = [0.4, 0.2, 0.2, 0.2]

        # Calculate weighted RS
        stock_rs = sum(roc * weight for roc, weight in zip(stock_rocs, weights))
        benchmark_rs = sum(roc * weight for roc, weight in zip(benchmark_rocs, weights))

        # Return relative strength (stock RS - benchmark RS)
        relative_strength = stock_rs - benchmark_rs

        return float(relative_strength)

    except Exception as e:
        return f"Error calculating relative strength: {str(e)}"

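The weighted-ROC arithmetic is easy to verify by hand. A sketch with illustrative precomputed ROC values (the 40/20/20/20 weights match the 3/6/9/12-month defaults):

```python
# IBD-style weighted score from precomputed ROCs (illustrative numbers).
weights = [0.4, 0.2, 0.2, 0.2]           # 63/126/189/252-day weights
stock_rocs = [0.12, 0.20, 0.25, 0.30]    # stock rate of change per period
bench_rocs = [0.04, 0.08, 0.10, 0.12]    # benchmark rate of change per period

stock_rs = sum(r * w for r, w in zip(stock_rocs, weights))   # 0.198
bench_rs = sum(r * w for r, w in zip(bench_rocs, weights))   # 0.076
rs = stock_rs - bench_rs                                     # 0.122, stock outperforming
```

A positive score means the stock's weighted momentum exceeds the benchmark's; the recent 3-month leg dominates because it carries double the weight of each longer leg.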
# ============================================================================
# Correlation Analysis
# ============================================================================


def calculate_rolling_correlation(
    stock_data: pd.DataFrame,
    benchmark_data: pd.DataFrame,
    window: int = 63
) -> Union[pd.Series, str]:
    """
    Calculate rolling correlation between stock and benchmark.

    Args:
        stock_data: DataFrame with 'Close' column
        benchmark_data: DataFrame with 'Close' column
        window: Rolling window size in days (default: 63 for ~3 months)

    Returns:
        pd.Series: Rolling correlation values (range: -1 to 1)
        str: Error message on failure

    Examples:
        >>> corr = calculate_rolling_correlation(stock_data, spy_data)
        >>> corr = calculate_rolling_correlation(stock_data, spy_data, window=20)
    """
    # Validate window
    if window < 2:
        return "Error: Window must be at least 2"

    # Validate inputs
    if stock_data.empty:
        return "Error: Stock data is empty"

    if benchmark_data.empty:
        return "Error: Benchmark data is empty"

    if 'Close' not in stock_data.columns:
        return "Error: Stock data missing 'Close' column"

    if 'Close' not in benchmark_data.columns:
        return "Error: Benchmark data missing 'Close' column"

    try:
        # Align dates via inner join
        aligned = pd.DataFrame({
            'stock_close': stock_data['Close'],
            'benchmark_close': benchmark_data['Close']
        }).dropna()

        if aligned.empty:
            return "Error: No overlapping dates between stock and benchmark data"

        # Calculate rolling correlation
        rolling_corr = aligned['stock_close'].rolling(window=window).corr(aligned['benchmark_close'])

        # Clip to [-1, 1] to handle floating point precision issues
        rolling_corr = rolling_corr.clip(-1.0, 1.0)

        return rolling_corr

    except Exception as e:
        return f"Error calculating rolling correlation: {str(e)}"

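The core of the function is pandas' `Rolling.corr`, which can be checked on synthetic series. Below, the stock is an exact linear function of the benchmark, so every full window should report a correlation of 1.0; the first `window - 1` entries are NaN because the window is not yet full (data is made up for illustration):

```python
import pandas as pd

idx = pd.date_range('2024-01-01', periods=10, freq='B')
bench = pd.Series(range(100, 110), index=idx, dtype=float)
stock = bench * 2 + 5   # perfectly linearly related -> correlation 1

corr = stock.rolling(window=5).corr(bench).clip(-1.0, 1.0)
# corr.iloc[:4] are NaN; corr.iloc[4:] are all ~1.0
```

Note the function correlates price levels, not returns; trending prices therefore tend to push the value toward +1 even for loosely related assets.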
# ============================================================================
# Beta Calculation
# ============================================================================


def calculate_beta(
    stock_data: pd.DataFrame,
    benchmark_data: pd.DataFrame,
    window: int = 252
) -> Union[float, str]:
    """
    Calculate beta (systematic risk measure).

    Beta = Covariance(stock_returns, benchmark_returns) / Variance(benchmark_returns)

    Args:
        stock_data: DataFrame with 'Close' column
        benchmark_data: DataFrame with 'Close' column
        window: Number of days for calculation (default: 252 for ~1 year)

    Returns:
        float: Beta value
            Beta > 1: More volatile than benchmark
            Beta = 1: Same volatility as benchmark
            Beta < 1: Less volatile than benchmark
        str: Error message on failure

    Examples:
        >>> beta = calculate_beta(stock_data, spy_data)
        >>> beta = calculate_beta(stock_data, spy_data, window=126)
    """
    # Validate inputs
    if stock_data.empty:
        return "Error: Stock data is empty"

    if benchmark_data.empty:
        return "Error: Benchmark data is empty"

    if 'Close' not in stock_data.columns:
        return "Error: Stock data missing 'Close' column"

    if 'Close' not in benchmark_data.columns:
        return "Error: Benchmark data missing 'Close' column"

    try:
        # Align dates via inner join
        aligned = pd.DataFrame({
            'stock_close': stock_data['Close'],
            'benchmark_close': benchmark_data['Close']
        }).dropna()

        if aligned.empty:
            return "Error: No overlapping dates between stock and benchmark data"

        # Check sufficient data; allow some flexibility for trading days
        min_required = int(window * 0.98)
        if len(aligned) < min_required:
            return f"Error: Insufficient data. Need at least {min_required} days, have {len(aligned)}"

        # Calculate returns
        stock_returns = aligned['stock_close'].pct_change()
        benchmark_returns = aligned['benchmark_close'].pct_change()

        # Take last window days
        stock_returns_window = stock_returns.tail(window)
        benchmark_returns_window = benchmark_returns.tail(window)

        # Remove NaN values
        valid_data = pd.DataFrame({
            'stock': stock_returns_window,
            'benchmark': benchmark_returns_window
        }).dropna()

        if valid_data.empty:
            return "Error: No valid returns data after removing NaN values"

        # Calculate covariance and variance
        covariance = valid_data['stock'].cov(valid_data['benchmark'])
        variance = valid_data['benchmark'].var()

        # Handle zero variance
        if variance == 0 or np.isnan(variance):
            return "Error: Benchmark has zero variance (no price movement)"

        # Calculate beta
        beta = covariance / variance

        # Check for NaN
        if np.isnan(beta):
            return "Error: Beta calculation resulted in NaN"

        return float(beta)

    except Exception as e:
        return f"Error calculating beta: {str(e)}"
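The covariance/variance formula is easy to sanity-check on synthetic returns: if a stock's daily returns are exactly twice the benchmark's, Cov(2x, x) = 2·Var(x) and beta comes out as 2 (data below is made up for illustration; pandas uses ddof=1 for both `cov` and `var`, so the ratio is consistent):

```python
import pandas as pd

bench_ret = pd.Series([0.01, -0.02, 0.015, 0.005, -0.01, 0.02])
stock_ret = bench_ret * 2.0   # stock moves twice as hard as the benchmark

# Same formula the module uses: Cov(stock, bench) / Var(bench)
beta = stock_ret.cov(bench_ret) / bench_ret.var()
# beta ≈ 2.0
```

A beta above 1 flags the stock as more volatile than the benchmark, matching the interpretation in the docstring.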