feat(dataflows): add benchmark data module with SPY, sector ETFs, RS, correlation, beta - Fixes #10
This commit is contained in:
parent
19171a4b31
commit
bbd85c91b6
25
CHANGELOG.md
25
CHANGELOG.md
|
|
@ -89,6 +89,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- Test coverage including rate limit handling, caching behavior, and date range filtering
|
||||
- Total: 108 tests added for FRED API feature
|
||||
|
||||
- Benchmark data retrieval and analysis (Issue #10)
|
||||
- Benchmark data module for SPY, sector ETF, and analysis functions [file:tradingagents/dataflows/benchmark.py](tradingagents/dataflows/benchmark.py) (441 lines)
|
||||
- Sector ETF mappings for all 11 SPDR sector funds (communication, consumer discretionary/staples, energy, financials, healthcare, industrials, materials, real estate, technology, utilities) [file:tradingagents/dataflows/benchmark.py:48-59](tradingagents/dataflows/benchmark.py)
|
||||
- get_benchmark_data() function for fetching OHLCV data via yfinance with date validation [file:tradingagents/dataflows/benchmark.py:67-115](tradingagents/dataflows/benchmark.py)
|
||||
- get_spy_data() convenience wrapper for S&P 500 benchmark data [file:tradingagents/dataflows/benchmark.py:117-136](tradingagents/dataflows/benchmark.py)
|
||||
- get_sector_etf_data() function for retrieving sector-specific benchmark data with sector validation [file:tradingagents/dataflows/benchmark.py:138-186](tradingagents/dataflows/benchmark.py)
|
||||
- calculate_relative_strength() function with IBD-style weighted ROC formula [file:tradingagents/dataflows/benchmark.py:188-285](tradingagents/dataflows/benchmark.py)
|
||||
- Relative strength calculation using weighted periods (40% 63-day, 20% 126-day, 20% 189-day, 20% 252-day ROC)
|
||||
- Customizable ROC periods with default IBD-style weighting
|
||||
- Data alignment via inner join with validation for overlapping dates
|
||||
- calculate_rolling_correlation() function for time-series correlation analysis [file:tradingagents/dataflows/benchmark.py:287-349](tradingagents/dataflows/benchmark.py)
|
||||
- Configurable rolling window sizes with default 60-day window
|
||||
- Comprehensive validation for data alignment and minimum data requirements
|
||||
- calculate_beta() function for volatility and systematic risk measurement [file:tradingagents/dataflows/benchmark.py:351-441](tradingagents/dataflows/benchmark.py)
|
||||
- Beta calculation using covariance-variance approach with optional smoothing
|
||||
- Optional rolling beta calculation with customizable window (default 252 days)
|
||||
- Markdown rolling window implementation for efficient computation
|
||||
- All functions return DataFrames/Series/floats on success, error strings on failure
|
||||
- Comprehensive error handling with descriptive messages and validation logic
|
||||
- Comprehensive docstrings with examples for all public functions
|
||||
- Unit test suite for benchmark functions [file:tests/unit/dataflows/test_benchmark.py](tests/unit/dataflows/test_benchmark.py) (753 lines, 28 tests)
|
||||
- Integration test suite for benchmark workflows [file:tests/integration/dataflows/test_benchmark_integration.py](tests/integration/dataflows/test_benchmark_integration.py) (593 lines, 7 tests)
|
||||
- Test coverage includes data fetching, sector validation, relative strength calculation, correlation analysis, and beta calculation
|
||||
- Total: 35 tests added for benchmark data feature
|
||||
|
||||
- Multi-timeframe OHLCV aggregation functions (Issue #9)
|
||||
- Multi-timeframe aggregation module for daily to weekly/monthly resampling [file:tradingagents/dataflows/multi_timeframe.py](tradingagents/dataflows/multi_timeframe.py) (320 lines)
|
||||
- Core OHLCV aggregation validation function _validate_ohlcv_dataframe() [file:tradingagents/dataflows/multi_timeframe.py:38-75](tradingagents/dataflows/multi_timeframe.py)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,593 @@
|
|||
"""
|
||||
Test suite for Benchmark Integration Tests.
|
||||
|
||||
This module tests:
|
||||
1. End-to-end workflows with benchmark data
|
||||
2. Multi-sector comparison analysis
|
||||
3. Real-world data format handling (yfinance compatibility)
|
||||
4. Combined analytics (RS + correlation + beta)
|
||||
5. All sector ETFs availability
|
||||
|
||||
Test Coverage:
|
||||
- Integration with yfinance data formats
|
||||
- Complete benchmark analysis workflow
|
||||
- Multi-sector relative strength comparison
|
||||
- Portfolio-level analytics
|
||||
- Date alignment across multiple datasets
|
||||
- All 11 sector ETFs (XLC, XLY, XLP, XLE, XLF, XLV, XLI, XLB, XLRE, XLK, XLU)
|
||||
|
||||
Workflow:
|
||||
1. Fetch benchmark data (SPY)
|
||||
2. Fetch stock data
|
||||
3. Calculate RS, correlation, beta
|
||||
4. Compare across sectors
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def yfinance_spy_data():
|
||||
"""
|
||||
Create SPY data in yfinance format.
|
||||
|
||||
yfinance returns:
|
||||
- DatetimeIndex (timezone-aware or naive)
|
||||
- Capitalized column names
|
||||
- Business day frequency
|
||||
"""
|
||||
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||
|
||||
data = pd.DataFrame({
|
||||
'Open': [450.0 + i * 0.3 for i in range(300)],
|
||||
'High': [452.0 + i * 0.3 for i in range(300)],
|
||||
'Low': [449.0 + i * 0.3 for i in range(300)],
|
||||
'Close': [451.0 + i * 0.3 for i in range(300)],
|
||||
'Volume': [80000000 + i * 100000 for i in range(300)],
|
||||
}, index=dates)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def yfinance_stock_data():
|
||||
"""Create stock data in yfinance format (AAPL-like)."""
|
||||
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||
|
||||
data = pd.DataFrame({
|
||||
'Open': [180.0 + i * 0.4 for i in range(300)],
|
||||
'High': [182.0 + i * 0.4 for i in range(300)],
|
||||
'Low': [179.0 + i * 0.4 for i in range(300)],
|
||||
'Close': [181.0 + i * 0.4 for i in range(300)],
|
||||
'Volume': [50000000 + i * 80000 for i in range(300)],
|
||||
}, index=dates)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def yfinance_sector_data_xlk():
|
||||
"""Create XLK sector ETF data in yfinance format."""
|
||||
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||
|
||||
data = pd.DataFrame({
|
||||
'Open': [200.0 + i * 0.35 for i in range(300)],
|
||||
'High': [202.0 + i * 0.35 for i in range(300)],
|
||||
'Low': [199.0 + i * 0.35 for i in range(300)],
|
||||
'Close': [201.0 + i * 0.35 for i in range(300)],
|
||||
'Volume': [10000000 + i * 50000 for i in range(300)],
|
||||
}, index=dates)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def yfinance_sector_data_xlf():
|
||||
"""Create XLF sector ETF data in yfinance format."""
|
||||
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||
|
||||
data = pd.DataFrame({
|
||||
'Open': [38.0 + i * 0.02 for i in range(300)],
|
||||
'High': [38.5 + i * 0.02 for i in range(300)],
|
||||
'Low': [37.5 + i * 0.02 for i in range(300)],
|
||||
'Close': [38.2 + i * 0.02 for i in range(300)],
|
||||
'Volume': [60000000 + i * 200000 for i in range(300)],
|
||||
}, index=dates)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def yfinance_sector_data_xle():
|
||||
"""Create XLE sector ETF data in yfinance format."""
|
||||
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||
|
||||
data = pd.DataFrame({
|
||||
'Open': [85.0 + i * 0.1 for i in range(300)],
|
||||
'High': [86.0 + i * 0.1 for i in range(300)],
|
||||
'Low': [84.0 + i * 0.1 for i in range(300)],
|
||||
'Close': [85.5 + i * 0.1 for i in range(300)],
|
||||
'Volume': [25000000 + i * 100000 for i in range(300)],
|
||||
}, index=dates)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def all_sector_etf_data():
|
||||
"""
|
||||
Create data for all 11 sector ETFs.
|
||||
|
||||
Returns dict mapping sector names to DataFrames.
|
||||
"""
|
||||
sectors_data = {}
|
||||
sector_configs = {
|
||||
'communication': {'base': 75.0, 'increment': 0.08},
|
||||
'consumer_discretionary': {'base': 180.0, 'increment': 0.15},
|
||||
'consumer_staples': {'base': 75.0, 'increment': 0.05},
|
||||
'energy': {'base': 85.0, 'increment': 0.1},
|
||||
'financials': {'base': 38.0, 'increment': 0.02},
|
||||
'healthcare': {'base': 130.0, 'increment': 0.12},
|
||||
'industrials': {'base': 105.0, 'increment': 0.09},
|
||||
'materials': {'base': 85.0, 'increment': 0.07},
|
||||
'real_estate': {'base': 40.0, 'increment': 0.03},
|
||||
'technology': {'base': 200.0, 'increment': 0.35},
|
||||
'utilities': {'base': 65.0, 'increment': 0.04},
|
||||
}
|
||||
|
||||
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||
|
||||
for sector, config in sector_configs.items():
|
||||
base = config['base']
|
||||
inc = config['increment']
|
||||
|
||||
data = pd.DataFrame({
|
||||
'Open': [base + i * inc for i in range(300)],
|
||||
'High': [base + 1.0 + i * inc for i in range(300)],
|
||||
'Low': [base - 0.5 + i * inc for i in range(300)],
|
||||
'Close': [base + 0.5 + i * inc for i in range(300)],
|
||||
'Volume': [15000000 + i * 50000 for i in range(300)],
|
||||
}, index=dates)
|
||||
|
||||
sectors_data[sector] = data
|
||||
|
||||
return sectors_data
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Class: Benchmark Integration
|
||||
# ============================================================================
|
||||
|
||||
class TestBenchmarkIntegration:
|
||||
"""
|
||||
Test suite for end-to-end benchmark workflows.
|
||||
|
||||
Tests:
|
||||
- Complete analysis workflow (fetch + RS + correlation + beta)
|
||||
- Multi-sector comparison
|
||||
- All sector ETFs availability
|
||||
- Combined analytics
|
||||
"""
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_end_to_end_benchmark_analysis(
|
||||
self,
|
||||
mock_yf,
|
||||
yfinance_stock_data,
|
||||
yfinance_spy_data
|
||||
):
|
||||
"""
|
||||
Test complete benchmark analysis workflow.
|
||||
|
||||
Workflow:
|
||||
1. Fetch SPY benchmark data
|
||||
2. Fetch stock data
|
||||
3. Calculate relative strength
|
||||
4. Calculate rolling correlation
|
||||
5. Calculate beta
|
||||
"""
|
||||
from tradingagents.dataflows.benchmark import (
|
||||
get_spy_data,
|
||||
get_benchmark_data,
|
||||
calculate_relative_strength,
|
||||
calculate_rolling_correlation,
|
||||
calculate_beta
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
def ticker_side_effect(symbol):
|
||||
mock_ticker_instance = MagicMock()
|
||||
if symbol == 'SPY':
|
||||
mock_ticker_instance.history.return_value = yfinance_spy_data
|
||||
else: # AAPL
|
||||
mock_ticker_instance.history.return_value = yfinance_stock_data
|
||||
return mock_ticker_instance
|
||||
|
||||
mock_yf.Ticker.side_effect = ticker_side_effect
|
||||
|
||||
# Step 1: Fetch SPY benchmark
|
||||
spy_data = get_spy_data('2024-01-01', '2024-10-31')
|
||||
assert isinstance(spy_data, pd.DataFrame)
|
||||
assert len(spy_data) > 0
|
||||
|
||||
# Step 2: Fetch stock data
|
||||
stock_data = get_benchmark_data('AAPL', '2024-01-01', '2024-10-31')
|
||||
assert isinstance(stock_data, pd.DataFrame)
|
||||
assert len(stock_data) > 0
|
||||
|
||||
# Step 3: Calculate relative strength
|
||||
rs = calculate_relative_strength(stock_data, spy_data)
|
||||
assert isinstance(rs, float)
|
||||
assert not np.isnan(rs)
|
||||
|
||||
# Step 4: Calculate rolling correlation
|
||||
correlation = calculate_rolling_correlation(stock_data, spy_data, window=63)
|
||||
assert isinstance(correlation, pd.Series)
|
||||
assert len(correlation.dropna()) > 0
|
||||
|
||||
# Step 5: Calculate beta
|
||||
beta = calculate_beta(stock_data, spy_data, window=252)
|
||||
assert isinstance(beta, float)
|
||||
assert not np.isnan(beta)
|
||||
|
||||
# Verify reasonable values
|
||||
assert -200 < rs < 200
|
||||
assert (correlation.dropna() >= -1.0).all()
|
||||
assert (correlation.dropna() <= 1.0).all()
|
||||
# Beta can be high for synthetic test data with varying volatility
|
||||
assert -10 < beta < 10
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_multi_sector_comparison(
|
||||
self,
|
||||
mock_yf,
|
||||
yfinance_stock_data,
|
||||
yfinance_spy_data,
|
||||
yfinance_sector_data_xlk,
|
||||
yfinance_sector_data_xlf,
|
||||
yfinance_sector_data_xle
|
||||
):
|
||||
"""
|
||||
Test comparing stock performance against multiple sector ETFs.
|
||||
|
||||
Workflow:
|
||||
1. Fetch stock data
|
||||
2. Fetch SPY and multiple sector ETFs
|
||||
3. Calculate RS against each benchmark
|
||||
4. Compare results
|
||||
"""
|
||||
from tradingagents.dataflows.benchmark import (
|
||||
get_benchmark_data,
|
||||
get_sector_etf_data,
|
||||
calculate_relative_strength
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
def ticker_side_effect(symbol):
|
||||
mock_ticker_instance = MagicMock()
|
||||
data_map = {
|
||||
'AAPL': yfinance_stock_data,
|
||||
'SPY': yfinance_spy_data,
|
||||
'XLK': yfinance_sector_data_xlk,
|
||||
'XLF': yfinance_sector_data_xlf,
|
||||
'XLE': yfinance_sector_data_xle,
|
||||
}
|
||||
mock_ticker_instance.history.return_value = data_map.get(
|
||||
symbol,
|
||||
pd.DataFrame()
|
||||
)
|
||||
return mock_ticker_instance
|
||||
|
||||
mock_yf.Ticker.side_effect = ticker_side_effect
|
||||
|
||||
# Fetch stock data
|
||||
stock_data = get_benchmark_data('AAPL', '2024-01-01', '2024-10-31')
|
||||
assert isinstance(stock_data, pd.DataFrame)
|
||||
|
||||
# Calculate RS against multiple benchmarks
|
||||
rs_results = {}
|
||||
|
||||
# vs SPY
|
||||
spy_data = get_benchmark_data('SPY', '2024-01-01', '2024-10-31')
|
||||
rs_results['SPY'] = calculate_relative_strength(stock_data, spy_data)
|
||||
|
||||
# vs Technology (XLK)
|
||||
tech_data = get_sector_etf_data('technology', '2024-01-01', '2024-10-31')
|
||||
rs_results['XLK'] = calculate_relative_strength(stock_data, tech_data)
|
||||
|
||||
# vs Financials (XLF)
|
||||
finance_data = get_sector_etf_data('financials', '2024-01-01', '2024-10-31')
|
||||
rs_results['XLF'] = calculate_relative_strength(stock_data, finance_data)
|
||||
|
||||
# vs Energy (XLE)
|
||||
energy_data = get_sector_etf_data('energy', '2024-01-01', '2024-10-31')
|
||||
rs_results['XLE'] = calculate_relative_strength(stock_data, energy_data)
|
||||
|
||||
# Assert all RS calculations succeeded
|
||||
for benchmark, rs in rs_results.items():
|
||||
assert isinstance(rs, float), f"RS vs {benchmark} failed"
|
||||
assert not np.isnan(rs), f"RS vs {benchmark} is NaN"
|
||||
assert -200 < rs < 200, f"RS vs {benchmark} out of range"
|
||||
|
||||
# AAPL should have different RS against different sectors
|
||||
unique_values = len(set(rs_results.values()))
|
||||
assert unique_values > 1, "RS should differ across sectors"
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_all_sector_etfs_available(self, mock_yf, all_sector_etf_data):
|
||||
"""
|
||||
Test that all 11 sector ETFs can be fetched.
|
||||
|
||||
Sectors:
|
||||
- communication (XLC)
|
||||
- consumer_discretionary (XLY)
|
||||
- consumer_staples (XLP)
|
||||
- energy (XLE)
|
||||
- financials (XLF)
|
||||
- healthcare (XLV)
|
||||
- industrials (XLI)
|
||||
- materials (XLB)
|
||||
- real_estate (XLRE)
|
||||
- technology (XLK)
|
||||
- utilities (XLU)
|
||||
"""
|
||||
from tradingagents.dataflows.benchmark import get_sector_etf_data, SECTOR_ETFS
|
||||
|
||||
# Setup mocks
|
||||
def ticker_side_effect(symbol):
|
||||
mock_ticker_instance = MagicMock()
|
||||
# Find which sector this symbol belongs to
|
||||
for sector, etf_symbol in SECTOR_ETFS.items():
|
||||
if etf_symbol == symbol:
|
||||
mock_ticker_instance.history.return_value = all_sector_etf_data[sector]
|
||||
return mock_ticker_instance
|
||||
# Default empty
|
||||
mock_ticker_instance.history.return_value = pd.DataFrame()
|
||||
return mock_ticker_instance
|
||||
|
||||
mock_yf.Ticker.side_effect = ticker_side_effect
|
||||
|
||||
# Test each sector
|
||||
sectors = [
|
||||
'communication',
|
||||
'consumer_discretionary',
|
||||
'consumer_staples',
|
||||
'energy',
|
||||
'financials',
|
||||
'healthcare',
|
||||
'industrials',
|
||||
'materials',
|
||||
'real_estate',
|
||||
'technology',
|
||||
'utilities'
|
||||
]
|
||||
|
||||
for sector in sectors:
|
||||
result = get_sector_etf_data(sector, '2024-01-01', '2024-10-31')
|
||||
assert isinstance(result, pd.DataFrame), f"Sector {sector} failed"
|
||||
assert len(result) > 0, f"Sector {sector} returned empty data"
|
||||
assert 'Close' in result.columns, f"Sector {sector} missing Close column"
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_portfolio_level_analytics(
|
||||
self,
|
||||
mock_yf,
|
||||
yfinance_spy_data,
|
||||
all_sector_etf_data
|
||||
):
|
||||
"""
|
||||
Test portfolio-level analytics across all sectors.
|
||||
|
||||
Workflow:
|
||||
1. Fetch all sector ETFs
|
||||
2. Calculate correlation matrix with SPY
|
||||
3. Calculate beta for each sector
|
||||
4. Identify high/low correlation sectors
|
||||
"""
|
||||
from tradingagents.dataflows.benchmark import (
|
||||
get_spy_data,
|
||||
get_sector_etf_data,
|
||||
calculate_rolling_correlation,
|
||||
calculate_beta,
|
||||
SECTOR_ETFS
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
def ticker_side_effect(symbol):
|
||||
mock_ticker_instance = MagicMock()
|
||||
if symbol == 'SPY':
|
||||
mock_ticker_instance.history.return_value = yfinance_spy_data
|
||||
else:
|
||||
# Find sector for this symbol
|
||||
for sector, etf_symbol in SECTOR_ETFS.items():
|
||||
if etf_symbol == symbol:
|
||||
mock_ticker_instance.history.return_value = all_sector_etf_data[sector]
|
||||
break
|
||||
return mock_ticker_instance
|
||||
|
||||
mock_yf.Ticker.side_effect = ticker_side_effect
|
||||
|
||||
# Fetch SPY
|
||||
spy_data = get_spy_data('2024-01-01', '2024-10-31')
|
||||
assert isinstance(spy_data, pd.DataFrame)
|
||||
|
||||
# Calculate analytics for each sector
|
||||
sector_analytics = {}
|
||||
|
||||
for sector in all_sector_etf_data.keys():
|
||||
sector_data = get_sector_etf_data(sector, '2024-01-01', '2024-10-31')
|
||||
|
||||
if isinstance(sector_data, pd.DataFrame) and len(sector_data) > 0:
|
||||
# Calculate correlation
|
||||
correlation = calculate_rolling_correlation(
|
||||
sector_data,
|
||||
spy_data,
|
||||
window=63
|
||||
)
|
||||
|
||||
# Calculate beta
|
||||
beta = calculate_beta(sector_data, spy_data, window=252)
|
||||
|
||||
sector_analytics[sector] = {
|
||||
'avg_correlation': correlation.dropna().mean() if isinstance(correlation, pd.Series) else None,
|
||||
'beta': beta if isinstance(beta, float) else None
|
||||
}
|
||||
|
||||
# Assert we got analytics for all sectors
|
||||
assert len(sector_analytics) == 11, "Should have analytics for all 11 sectors"
|
||||
|
||||
# Assert all analytics are valid
|
||||
for sector, analytics in sector_analytics.items():
|
||||
if analytics['avg_correlation'] is not None:
|
||||
assert -1.0 <= analytics['avg_correlation'] <= 1.0, \
|
||||
f"Correlation for {sector} out of range"
|
||||
|
||||
if analytics['beta'] is not None:
|
||||
assert not np.isnan(analytics['beta']), \
|
||||
f"Beta for {sector} is NaN"
|
||||
# Beta can be high for synthetic test data with varying volatility
|
||||
assert -10 < analytics['beta'] < 10, \
|
||||
f"Beta for {sector} out of reasonable range"
|
||||
|
||||
# Identify high correlation sectors (should correlate well with SPY)
|
||||
high_corr_sectors = [
|
||||
sector for sector, analytics in sector_analytics.items()
|
||||
if analytics['avg_correlation'] is not None and analytics['avg_correlation'] > 0.7
|
||||
]
|
||||
|
||||
# Most sectors should have positive correlation with market
|
||||
assert len(high_corr_sectors) >= 1, "At least one sector should correlate with SPY"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Class: Real-World Data Format Handling
|
||||
# ============================================================================
|
||||
|
||||
class TestRealWorldDataFormat:
|
||||
"""
|
||||
Test suite for handling real-world data format quirks.
|
||||
|
||||
Tests:
|
||||
- Timezone-aware DatetimeIndex
|
||||
- Column name variations
|
||||
- Missing data handling
|
||||
- Date range alignment
|
||||
"""
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_timezone_aware_data(self, mock_yf):
|
||||
"""Test handling of timezone-aware yfinance data."""
|
||||
from tradingagents.dataflows.benchmark import get_benchmark_data
|
||||
|
||||
# Create timezone-aware data
|
||||
dates = pd.date_range('2024-01-01', periods=300, freq='D', tz='America/New_York')
|
||||
tz_data = pd.DataFrame({
|
||||
'Open': [100.0 + i * 0.1 for i in range(300)],
|
||||
'High': [101.0 + i * 0.1 for i in range(300)],
|
||||
'Low': [99.0 + i * 0.1 for i in range(300)],
|
||||
'Close': [100.5 + i * 0.1 for i in range(300)],
|
||||
'Volume': [1000000] * 300,
|
||||
}, index=dates)
|
||||
|
||||
# Setup mock
|
||||
mock_ticker_instance = MagicMock()
|
||||
mock_yf.Ticker.return_value = mock_ticker_instance
|
||||
mock_ticker_instance.history.return_value = tz_data
|
||||
|
||||
# Execute
|
||||
result = get_benchmark_data('SPY', '2024-01-01', '2024-10-31')
|
||||
|
||||
# Assert - should handle timezone-aware data
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) > 0
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_business_day_frequency(self, mock_yf):
|
||||
"""Test handling of business day frequency data (no weekends)."""
|
||||
from tradingagents.dataflows.benchmark import get_benchmark_data, calculate_relative_strength
|
||||
|
||||
# Create business day data
|
||||
dates = pd.bdate_range('2024-01-01', periods=250, freq='B')
|
||||
|
||||
spy_data = pd.DataFrame({
|
||||
'Open': [450.0 + i * 0.3 for i in range(250)],
|
||||
'High': [452.0 + i * 0.3 for i in range(250)],
|
||||
'Low': [449.0 + i * 0.3 for i in range(250)],
|
||||
'Close': [451.0 + i * 0.3 for i in range(250)],
|
||||
'Volume': [80000000] * 250,
|
||||
}, index=dates)
|
||||
|
||||
stock_data = pd.DataFrame({
|
||||
'Open': [180.0 + i * 0.4 for i in range(250)],
|
||||
'High': [182.0 + i * 0.4 for i in range(250)],
|
||||
'Low': [179.0 + i * 0.4 for i in range(250)],
|
||||
'Close': [181.0 + i * 0.4 for i in range(250)],
|
||||
'Volume': [50000000] * 250,
|
||||
}, index=dates)
|
||||
|
||||
# Setup mock
|
||||
def ticker_side_effect(symbol):
|
||||
mock_ticker_instance = MagicMock()
|
||||
if symbol == 'SPY':
|
||||
mock_ticker_instance.history.return_value = spy_data
|
||||
else:
|
||||
mock_ticker_instance.history.return_value = stock_data
|
||||
return mock_ticker_instance
|
||||
|
||||
mock_yf.Ticker.side_effect = ticker_side_effect
|
||||
|
||||
# Fetch data
|
||||
result_spy = get_benchmark_data('SPY', '2024-01-01', '2024-12-31')
|
||||
result_stock = get_benchmark_data('AAPL', '2024-01-01', '2024-12-31')
|
||||
|
||||
# Calculate RS
|
||||
rs = calculate_relative_strength(result_stock, result_spy)
|
||||
|
||||
# Assert - should handle business days correctly
|
||||
assert isinstance(rs, float)
|
||||
assert not np.isnan(rs)
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_date_range_alignment(self, mock_yf):
|
||||
"""Test automatic date range alignment between stock and benchmark."""
|
||||
from tradingagents.dataflows.benchmark import calculate_relative_strength
|
||||
|
||||
# Create overlapping but not identical date ranges
|
||||
spy_dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||
stock_dates = pd.date_range('2024-01-15', periods=280, freq='D') # Starts 14 days later
|
||||
|
||||
spy_data = pd.DataFrame({
|
||||
'Close': [450.0 + i * 0.3 for i in range(300)],
|
||||
'Volume': [80000000] * 300,
|
||||
}, index=spy_dates)
|
||||
|
||||
stock_data = pd.DataFrame({
|
||||
'Close': [180.0 + i * 0.4 for i in range(280)],
|
||||
'Volume': [50000000] * 280,
|
||||
}, index=stock_dates)
|
||||
|
||||
# Add other required columns
|
||||
for df in [spy_data, stock_data]:
|
||||
df['Open'] = df['Close'] - 0.5
|
||||
df['High'] = df['Close'] + 1.0
|
||||
df['Low'] = df['Close'] - 1.0
|
||||
|
||||
# Execute RS calculation - should align dates internally
|
||||
result = calculate_relative_strength(stock_data, spy_data)
|
||||
|
||||
# Assert - should handle date alignment
|
||||
# Either returns valid RS or error message
|
||||
if isinstance(result, float):
|
||||
assert not np.isnan(result)
|
||||
else:
|
||||
assert isinstance(result, str)
|
||||
|
|
@ -0,0 +1,753 @@
|
|||
"""
|
||||
Test suite for Benchmark Data Functions (benchmark.py).
|
||||
|
||||
This module tests:
|
||||
1. get_benchmark_data() - Generic benchmark data fetcher via yfinance
|
||||
2. get_spy_data() - Convenience wrapper for SPY
|
||||
3. get_sector_etf_data() - Sector ETF data (XLF, XLK, XLE, etc.)
|
||||
4. calculate_relative_strength() - IBD-style RS calculation
|
||||
5. calculate_rolling_correlation() - Rolling correlation between stock and benchmark
|
||||
6. calculate_beta() - Beta calculation (Cov/Var)
|
||||
|
||||
Test Coverage:
|
||||
- Unit tests for each function
|
||||
- Valid data fetching (SPY, sector ETFs)
|
||||
- Invalid inputs (bad symbols, dates, sectors)
|
||||
- RS calculation with IBD formula: 0.4*ROC(63) + 0.2*ROC(126) + 0.2*ROC(189) + 0.2*ROC(252)
|
||||
- Rolling correlation with configurable window
|
||||
- Beta calculation with market variance
|
||||
- Edge cases (empty data, insufficient data, missing columns, date misalignment)
|
||||
- Zero returns handling
|
||||
- Extreme values
|
||||
|
||||
SECTOR_ETFS Constants:
|
||||
- communication: XLC
|
||||
- consumer_discretionary: XLY
|
||||
- consumer_staples: XLP
|
||||
- energy: XLE
|
||||
- financials: XLF
|
||||
- healthcare: XLV
|
||||
- industrials: XLI
|
||||
- materials: XLB
|
||||
- real_estate: XLRE
|
||||
- technology: XLK
|
||||
- utilities: XLU
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def sample_spy_data():
|
||||
"""
|
||||
Create 300 days of sample SPY OHLCV data.
|
||||
|
||||
Returns a DataFrame with DatetimeIndex and columns: Open, High, Low, Close, Volume.
|
||||
Simulates realistic market data with upward trend.
|
||||
"""
|
||||
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||
|
||||
data = []
|
||||
base_price = 450.0
|
||||
|
||||
for i, date in enumerate(dates):
|
||||
# Simulate gradual upward trend with some volatility
|
||||
trend = i * 0.3
|
||||
volatility = np.sin(i / 10) * 2
|
||||
|
||||
open_price = base_price + trend + volatility
|
||||
high_price = open_price + 1.5 + abs(np.cos(i / 5))
|
||||
low_price = open_price - 1.5 - abs(np.sin(i / 7))
|
||||
close_price = open_price + 0.5 + np.sin(i / 3) * 0.5
|
||||
volume = 80000000 + i * 100000
|
||||
|
||||
data.append({
|
||||
'Open': round(open_price, 2),
|
||||
'High': round(high_price, 2),
|
||||
'Low': round(low_price, 2),
|
||||
'Close': round(close_price, 2),
|
||||
'Volume': volume
|
||||
})
|
||||
|
||||
df = pd.DataFrame(data, index=dates)
|
||||
return df
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_stock_data():
|
||||
"""
|
||||
Create 300 days of sample stock OHLCV data (AAPL-like pattern).
|
||||
|
||||
Returns a DataFrame with DatetimeIndex and columns: Open, High, Low, Close, Volume.
|
||||
Simulates tech stock with higher volatility than SPY.
|
||||
"""
|
||||
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||
|
||||
data = []
|
||||
base_price = 180.0
|
||||
|
||||
for i, date in enumerate(dates):
|
||||
# Higher volatility and stronger trend than SPY
|
||||
trend = i * 0.4
|
||||
volatility = np.sin(i / 8) * 3
|
||||
|
||||
open_price = base_price + trend + volatility
|
||||
high_price = open_price + 2.0 + abs(np.cos(i / 4))
|
||||
low_price = open_price - 2.0 - abs(np.sin(i / 6))
|
||||
close_price = open_price + 0.8 + np.sin(i / 2.5) * 0.8
|
||||
volume = 50000000 + i * 80000
|
||||
|
||||
data.append({
|
||||
'Open': round(open_price, 2),
|
||||
'High': round(high_price, 2),
|
||||
'Low': round(low_price, 2),
|
||||
'Close': round(close_price, 2),
|
||||
'Volume': volume
|
||||
})
|
||||
|
||||
df = pd.DataFrame(data, index=dates)
|
||||
return df
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_sector_data():
|
||||
"""
|
||||
Create 300 days of sample XLK (technology sector) OHLCV data.
|
||||
|
||||
Returns a DataFrame with DatetimeIndex and columns: Open, High, Low, Close, Volume.
|
||||
Similar to SPY but with sector-specific characteristics.
|
||||
"""
|
||||
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||
|
||||
data = []
|
||||
base_price = 200.0
|
||||
|
||||
for i, date in enumerate(dates):
|
||||
# Tech sector - moderate trend with cyclical volatility
|
||||
trend = i * 0.35
|
||||
volatility = np.sin(i / 12) * 2.5
|
||||
|
||||
open_price = base_price + trend + volatility
|
||||
high_price = open_price + 1.8 + abs(np.cos(i / 5.5))
|
||||
low_price = open_price - 1.8 - abs(np.sin(i / 6.5))
|
||||
close_price = open_price + 0.6 + np.sin(i / 3.5) * 0.6
|
||||
volume = 10000000 + i * 50000
|
||||
|
||||
data.append({
|
||||
'Open': round(open_price, 2),
|
||||
'High': round(high_price, 2),
|
||||
'Low': round(low_price, 2),
|
||||
'Close': round(close_price, 2),
|
||||
'Volume': volume
|
||||
})
|
||||
|
||||
df = pd.DataFrame(data, index=dates)
|
||||
return df
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_dataframe():
|
||||
"""Create empty DataFrame for validation testing."""
|
||||
return pd.DataFrame()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def insufficient_data():
|
||||
"""
|
||||
Create DataFrame with only 50 days (insufficient for full RS calculation).
|
||||
|
||||
Insufficient for 252-day ROC calculation in relative strength.
|
||||
"""
|
||||
dates = pd.date_range('2024-01-01', periods=50, freq='D')
|
||||
|
||||
data = []
|
||||
base_price = 100.0
|
||||
|
||||
for i, date in enumerate(dates):
|
||||
open_price = base_price + i * 0.2
|
||||
data.append({
|
||||
'Open': round(open_price, 2),
|
||||
'High': round(open_price + 1.0, 2),
|
||||
'Low': round(open_price - 0.5, 2),
|
||||
'Close': round(open_price + 0.3, 2),
|
||||
'Volume': 1000000
|
||||
})
|
||||
|
||||
df = pd.DataFrame(data, index=dates)
|
||||
return df
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def missing_close_data():
|
||||
"""Create DataFrame missing Close column."""
|
||||
dates = pd.date_range('2024-01-01', periods=100, freq='D')
|
||||
return pd.DataFrame({
|
||||
'Open': [100.0 + i * 0.1 for i in range(100)],
|
||||
'High': [102.0 + i * 0.1 for i in range(100)],
|
||||
'Low': [99.0 + i * 0.1 for i in range(100)],
|
||||
'Volume': [1000000] * 100,
|
||||
}, index=dates)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def misaligned_dates_data(sample_spy_data):
|
||||
"""
|
||||
Create stock data with different date range than SPY.
|
||||
|
||||
50-day offset to test date alignment handling.
|
||||
"""
|
||||
offset_dates = pd.date_range('2024-02-20', periods=250, freq='D')
|
||||
|
||||
data = []
|
||||
base_price = 150.0
|
||||
|
||||
for i, date in enumerate(offset_dates):
|
||||
open_price = base_price + i * 0.2
|
||||
data.append({
|
||||
'Open': round(open_price, 2),
|
||||
'High': round(open_price + 1.5, 2),
|
||||
'Low': round(open_price - 1.0, 2),
|
||||
'Close': round(open_price + 0.5, 2),
|
||||
'Volume': 2000000
|
||||
})
|
||||
|
||||
return pd.DataFrame(data, index=offset_dates)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zero_returns_data():
|
||||
"""Create DataFrame with zero returns (flat prices)."""
|
||||
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||
|
||||
# All prices are constant (no returns)
|
||||
constant_price = 100.0
|
||||
|
||||
data = []
|
||||
for date in dates:
|
||||
data.append({
|
||||
'Open': constant_price,
|
||||
'High': constant_price,
|
||||
'Low': constant_price,
|
||||
'Close': constant_price,
|
||||
'Volume': 1000000
|
||||
})
|
||||
|
||||
df = pd.DataFrame(data, index=dates)
|
||||
return df
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def extreme_values_data():
|
||||
"""Create DataFrame with extreme price values."""
|
||||
dates = pd.date_range('2024-01-01', periods=300, freq='D')
|
||||
|
||||
data = []
|
||||
for i, date in enumerate(dates):
|
||||
# Extreme volatility: 50% daily moves
|
||||
if i % 2 == 0:
|
||||
close = 100.0 * (1.5 ** (i // 2))
|
||||
else:
|
||||
close = 100.0 * (1.5 ** (i // 2)) * 0.5
|
||||
|
||||
data.append({
|
||||
'Open': close,
|
||||
'High': close * 1.1,
|
||||
'Low': close * 0.9,
|
||||
'Close': close,
|
||||
'Volume': 1000000
|
||||
})
|
||||
|
||||
df = pd.DataFrame(data, index=dates)
|
||||
return df
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Class: Benchmark Data Fetching
|
||||
# ============================================================================
|
||||
|
||||
class TestBenchmarkDataFetching:
|
||||
"""
|
||||
Test suite for benchmark data fetching functions.
|
||||
|
||||
Tests:
|
||||
- get_benchmark_data() with valid symbols
|
||||
- get_spy_data() convenience wrapper
|
||||
- get_sector_etf_data() with valid sectors
|
||||
- Invalid symbol handling
|
||||
- Invalid sector handling
|
||||
- Invalid date handling
|
||||
- Empty data handling
|
||||
"""
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_get_benchmark_data_valid_spy(self, mock_yf, sample_spy_data):
|
||||
"""Test fetching valid SPY benchmark data."""
|
||||
from tradingagents.dataflows.benchmark import get_benchmark_data
|
||||
|
||||
# Setup mock
|
||||
mock_ticker_instance = MagicMock()
|
||||
mock_yf.Ticker.return_value = mock_ticker_instance
|
||||
mock_ticker_instance.history.return_value = sample_spy_data
|
||||
|
||||
# Execute
|
||||
result = get_benchmark_data('SPY', '2024-01-01', '2024-10-31')
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) > 0
|
||||
assert 'Close' in result.columns
|
||||
assert 'Volume' in result.columns
|
||||
mock_yf.Ticker.assert_called_once_with('SPY')
|
||||
mock_ticker_instance.history.assert_called_once()
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_get_benchmark_data_valid_sector_etf(self, mock_yf, sample_sector_data):
|
||||
"""Test fetching valid sector ETF data (XLK)."""
|
||||
from tradingagents.dataflows.benchmark import get_benchmark_data
|
||||
|
||||
# Setup mock
|
||||
mock_ticker_instance = MagicMock()
|
||||
mock_yf.Ticker.return_value = mock_ticker_instance
|
||||
mock_ticker_instance.history.return_value = sample_sector_data
|
||||
|
||||
# Execute
|
||||
result = get_benchmark_data('XLK', '2024-01-01', '2024-10-31')
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) > 0
|
||||
assert 'Close' in result.columns
|
||||
mock_yf.Ticker.assert_called_once_with('XLK')
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_get_benchmark_data_invalid_symbol(self, mock_yf):
|
||||
"""Test handling of invalid ticker symbol."""
|
||||
from tradingagents.dataflows.benchmark import get_benchmark_data
|
||||
|
||||
# Setup mock to raise exception
|
||||
mock_ticker_instance = MagicMock()
|
||||
mock_yf.Ticker.return_value = mock_ticker_instance
|
||||
mock_ticker_instance.history.side_effect = Exception("Invalid ticker")
|
||||
|
||||
# Execute
|
||||
result = get_benchmark_data('INVALID_SYMBOL', '2024-01-01', '2024-10-31')
|
||||
|
||||
# Assert - should return error string
|
||||
assert isinstance(result, str)
|
||||
assert 'error' in result.lower() or 'invalid' in result.lower()
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_get_benchmark_data_empty_data(self, mock_yf):
|
||||
"""Test handling when yfinance returns empty DataFrame."""
|
||||
from tradingagents.dataflows.benchmark import get_benchmark_data
|
||||
|
||||
# Setup mock to return empty DataFrame
|
||||
mock_ticker_instance = MagicMock()
|
||||
mock_yf.Ticker.return_value = mock_ticker_instance
|
||||
mock_ticker_instance.history.return_value = pd.DataFrame()
|
||||
|
||||
# Execute
|
||||
result = get_benchmark_data('SPY', '2024-01-01', '2024-01-02')
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str)
|
||||
assert 'no data' in result.lower() or 'empty' in result.lower()
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_get_spy_data(self, mock_yf, sample_spy_data):
|
||||
"""Test get_spy_data() convenience wrapper."""
|
||||
from tradingagents.dataflows.benchmark import get_spy_data
|
||||
|
||||
# Setup mock
|
||||
mock_ticker_instance = MagicMock()
|
||||
mock_yf.Ticker.return_value = mock_ticker_instance
|
||||
mock_ticker_instance.history.return_value = sample_spy_data
|
||||
|
||||
# Execute
|
||||
result = get_spy_data('2024-01-01', '2024-10-31')
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) > 0
|
||||
mock_yf.Ticker.assert_called_once_with('SPY')
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_get_sector_etf_data_valid(self, mock_yf, sample_sector_data):
|
||||
"""Test get_sector_etf_data() with valid sector."""
|
||||
from tradingagents.dataflows.benchmark import get_sector_etf_data
|
||||
|
||||
# Setup mock
|
||||
mock_ticker_instance = MagicMock()
|
||||
mock_yf.Ticker.return_value = mock_ticker_instance
|
||||
mock_ticker_instance.history.return_value = sample_sector_data
|
||||
|
||||
# Execute
|
||||
result = get_sector_etf_data('technology', '2024-01-01', '2024-10-31')
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) > 0
|
||||
mock_yf.Ticker.assert_called_once_with('XLK')
|
||||
|
||||
def test_get_sector_etf_data_invalid_sector(self):
|
||||
"""Test get_sector_etf_data() with invalid sector name."""
|
||||
from tradingagents.dataflows.benchmark import get_sector_etf_data
|
||||
|
||||
# Execute
|
||||
result = get_sector_etf_data('invalid_sector', '2024-01-01', '2024-10-31')
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str)
|
||||
assert 'invalid' in result.lower() or 'unknown' in result.lower()
|
||||
|
||||
@patch('tradingagents.dataflows.benchmark.yf')
|
||||
def test_get_benchmark_data_invalid_dates(self, mock_yf):
|
||||
"""Test handling of invalid date format."""
|
||||
from tradingagents.dataflows.benchmark import get_benchmark_data
|
||||
|
||||
# Setup mock to raise exception on invalid dates
|
||||
mock_ticker_instance = MagicMock()
|
||||
mock_yf.Ticker.return_value = mock_ticker_instance
|
||||
mock_ticker_instance.history.side_effect = ValueError("Invalid date format")
|
||||
|
||||
# Execute
|
||||
result = get_benchmark_data('SPY', 'invalid-date', '2024-10-31')
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str)
|
||||
assert 'error' in result.lower() or 'invalid' in result.lower()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Class: Relative Strength Calculation
|
||||
# ============================================================================
|
||||
|
||||
class TestRelativeStrength:
|
||||
"""
|
||||
Test suite for IBD-style relative strength calculation.
|
||||
|
||||
Tests:
|
||||
- RS calculation with IBD formula: 0.4*ROC(63) + 0.2*ROC(126) + 0.2*ROC(189) + 0.2*ROC(252)
|
||||
- Insufficient data handling (< 252 days)
|
||||
- Missing Close column
|
||||
- Date misalignment between stock and benchmark
|
||||
- Zero returns (flat prices)
|
||||
|
||||
IBD Relative Strength:
|
||||
- 40% weight on 63-day (3-month) ROC
|
||||
- 20% weight on 126-day (6-month) ROC
|
||||
- 20% weight on 189-day (9-month) ROC
|
||||
- 20% weight on 252-day (12-month) ROC
|
||||
"""
|
||||
|
||||
def test_calculate_relative_strength_valid(self, sample_stock_data, sample_spy_data):
|
||||
"""Test RS calculation with valid data."""
|
||||
from tradingagents.dataflows.benchmark import calculate_relative_strength
|
||||
|
||||
# Execute
|
||||
result = calculate_relative_strength(sample_stock_data, sample_spy_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, float)
|
||||
assert not np.isnan(result)
|
||||
assert not np.isinf(result)
|
||||
# RS should be reasonable (typically between -100 and 100 for normal stocks)
|
||||
assert -200 < result < 200
|
||||
|
||||
def test_calculate_relative_strength_custom_periods(self, sample_stock_data, sample_spy_data):
|
||||
"""Test RS calculation with custom periods."""
|
||||
from tradingagents.dataflows.benchmark import calculate_relative_strength
|
||||
|
||||
# Execute with shorter periods (all data has 300 days)
|
||||
result = calculate_relative_strength(
|
||||
sample_stock_data,
|
||||
sample_spy_data,
|
||||
periods=[20, 60, 120, 180]
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, float)
|
||||
assert not np.isnan(result)
|
||||
assert -200 < result < 200
|
||||
|
||||
def test_calculate_relative_strength_insufficient_data(self, insufficient_data, sample_spy_data):
|
||||
"""Test RS calculation with insufficient data (< 252 days)."""
|
||||
from tradingagents.dataflows.benchmark import calculate_relative_strength
|
||||
|
||||
# Execute - only 50 days, need 252 for default periods
|
||||
result = calculate_relative_strength(insufficient_data, sample_spy_data)
|
||||
|
||||
# Assert - should return error string
|
||||
assert isinstance(result, str)
|
||||
assert 'insufficient' in result.lower() or 'not enough' in result.lower()
|
||||
|
||||
def test_calculate_relative_strength_missing_close(self, missing_close_data, sample_spy_data):
|
||||
"""Test RS calculation with missing Close column."""
|
||||
from tradingagents.dataflows.benchmark import calculate_relative_strength
|
||||
|
||||
# Execute
|
||||
result = calculate_relative_strength(missing_close_data, sample_spy_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str)
|
||||
assert 'close' in result.lower() or 'missing' in result.lower()
|
||||
|
||||
def test_calculate_relative_strength_date_misalignment(self, misaligned_dates_data, sample_spy_data):
|
||||
"""Test RS calculation with misaligned date ranges."""
|
||||
from tradingagents.dataflows.benchmark import calculate_relative_strength
|
||||
|
||||
# Execute - stock data starts 50 days later than SPY
|
||||
result = calculate_relative_strength(misaligned_dates_data, sample_spy_data)
|
||||
|
||||
# Assert - function should handle alignment or return error
|
||||
# Either valid RS (if aligned) or error message
|
||||
if isinstance(result, str):
|
||||
assert 'align' in result.lower() or 'date' in result.lower() or 'insufficient' in result.lower()
|
||||
else:
|
||||
assert isinstance(result, float)
|
||||
assert not np.isnan(result)
|
||||
|
||||
def test_calculate_relative_strength_zero_returns(self, zero_returns_data, sample_spy_data):
|
||||
"""Test RS calculation with zero returns (flat prices)."""
|
||||
from tradingagents.dataflows.benchmark import calculate_relative_strength
|
||||
|
||||
# Execute
|
||||
result = calculate_relative_strength(zero_returns_data, sample_spy_data)
|
||||
|
||||
# Assert - should handle zero returns gracefully
|
||||
# RS should be negative since stock has 0 returns while benchmark moves
|
||||
if isinstance(result, float):
|
||||
assert not np.isnan(result)
|
||||
assert result < 0 # Stock underperforming benchmark
|
||||
else:
|
||||
# Or return error for zero variance
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Class: Correlation Analytics
|
||||
# ============================================================================
|
||||
|
||||
class TestCorrelationAnalytics:
|
||||
"""
|
||||
Test suite for correlation and beta calculations.
|
||||
|
||||
Tests:
|
||||
- calculate_rolling_correlation() with various windows
|
||||
- calculate_beta() calculation
|
||||
- Window validation (must be >= 2)
|
||||
- Insufficient data for window size
|
||||
"""
|
||||
|
||||
def test_calculate_rolling_correlation_valid(self, sample_stock_data, sample_spy_data):
|
||||
"""Test rolling correlation calculation with default window (63 days)."""
|
||||
from tradingagents.dataflows.benchmark import calculate_rolling_correlation
|
||||
|
||||
# Execute
|
||||
result = calculate_rolling_correlation(sample_stock_data, sample_spy_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, pd.Series)
|
||||
assert len(result) > 0
|
||||
# Correlation values should be between -1 and 1
|
||||
assert (result.dropna() >= -1.0).all()
|
||||
assert (result.dropna() <= 1.0).all()
|
||||
|
||||
def test_calculate_rolling_correlation_custom_window(self, sample_stock_data, sample_spy_data):
|
||||
"""Test rolling correlation with custom window size."""
|
||||
from tradingagents.dataflows.benchmark import calculate_rolling_correlation
|
||||
|
||||
# Execute with 20-day window
|
||||
result = calculate_rolling_correlation(sample_stock_data, sample_spy_data, window=20)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, pd.Series)
|
||||
assert len(result) > 0
|
||||
assert (result.dropna() >= -1.0).all()
|
||||
assert (result.dropna() <= 1.0).all()
|
||||
|
||||
def test_calculate_rolling_correlation_invalid_window(self, sample_stock_data, sample_spy_data):
|
||||
"""Test rolling correlation with invalid window (< 2)."""
|
||||
from tradingagents.dataflows.benchmark import calculate_rolling_correlation
|
||||
|
||||
# Execute with window=1
|
||||
result = calculate_rolling_correlation(sample_stock_data, sample_spy_data, window=1)
|
||||
|
||||
# Assert - should return error
|
||||
assert isinstance(result, str)
|
||||
assert 'window' in result.lower() or 'invalid' in result.lower()
|
||||
|
||||
def test_calculate_rolling_correlation_insufficient_data(self, insufficient_data, sample_spy_data):
|
||||
"""Test rolling correlation with insufficient data for window."""
|
||||
from tradingagents.dataflows.benchmark import calculate_rolling_correlation
|
||||
|
||||
# Execute - only 50 days but default window is 63
|
||||
result = calculate_rolling_correlation(insufficient_data, sample_spy_data)
|
||||
|
||||
# Assert - should return error or empty series
|
||||
if isinstance(result, str):
|
||||
assert 'insufficient' in result.lower() or 'not enough' in result.lower()
|
||||
elif isinstance(result, pd.Series):
|
||||
# May return series with all NaN values
|
||||
assert result.dropna().empty or len(result.dropna()) < 10
|
||||
|
||||
def test_calculate_beta_valid(self, sample_stock_data, sample_spy_data):
|
||||
"""Test beta calculation with default window (252 days)."""
|
||||
from tradingagents.dataflows.benchmark import calculate_beta
|
||||
|
||||
# Execute
|
||||
result = calculate_beta(sample_stock_data, sample_spy_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, float)
|
||||
assert not np.isnan(result)
|
||||
assert not np.isinf(result)
|
||||
# Beta typically ranges from -2 to 3 for normal stocks
|
||||
assert -5 < result < 5
|
||||
|
||||
def test_calculate_beta_custom_window(self, sample_stock_data, sample_spy_data):
|
||||
"""Test beta calculation with custom window."""
|
||||
from tradingagents.dataflows.benchmark import calculate_beta
|
||||
|
||||
# Execute with 126-day window
|
||||
result = calculate_beta(sample_stock_data, sample_spy_data, window=126)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, float)
|
||||
assert not np.isnan(result)
|
||||
assert -5 < result < 5
|
||||
|
||||
def test_calculate_beta_insufficient_data(self, insufficient_data, sample_spy_data):
|
||||
"""Test beta calculation with insufficient data."""
|
||||
from tradingagents.dataflows.benchmark import calculate_beta
|
||||
|
||||
# Execute - only 50 days but default window is 252
|
||||
result = calculate_beta(insufficient_data, sample_spy_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str)
|
||||
assert 'insufficient' in result.lower() or 'not enough' in result.lower()
|
||||
|
||||
def test_calculate_beta_zero_variance(self, zero_returns_data, sample_spy_data):
|
||||
"""Test beta calculation when stock has zero variance."""
|
||||
from tradingagents.dataflows.benchmark import calculate_beta
|
||||
|
||||
# Execute
|
||||
result = calculate_beta(zero_returns_data, sample_spy_data)
|
||||
|
||||
# Assert - should handle zero variance
|
||||
if isinstance(result, float):
|
||||
assert result == 0.0 or np.isnan(result)
|
||||
else:
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Class: Edge Cases
|
||||
# ============================================================================
|
||||
|
||||
class TestEdgeCases:
|
||||
"""
|
||||
Test suite for edge cases and boundary conditions.
|
||||
|
||||
Tests:
|
||||
- Empty DataFrames
|
||||
- Single day data
|
||||
- Extreme values
|
||||
- Date index validation
|
||||
"""
|
||||
|
||||
def test_calculate_relative_strength_empty_stock(self, empty_dataframe, sample_spy_data):
|
||||
"""Test RS calculation with empty stock DataFrame."""
|
||||
from tradingagents.dataflows.benchmark import calculate_relative_strength
|
||||
|
||||
# Execute
|
||||
result = calculate_relative_strength(empty_dataframe, sample_spy_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str)
|
||||
assert 'empty' in result.lower() or 'no data' in result.lower()
|
||||
|
||||
def test_calculate_relative_strength_empty_benchmark(self, sample_stock_data, empty_dataframe):
|
||||
"""Test RS calculation with empty benchmark DataFrame."""
|
||||
from tradingagents.dataflows.benchmark import calculate_relative_strength
|
||||
|
||||
# Execute
|
||||
result = calculate_relative_strength(sample_stock_data, empty_dataframe)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str)
|
||||
assert 'empty' in result.lower() or 'no data' in result.lower()
|
||||
|
||||
def test_calculate_rolling_correlation_single_day(self, sample_spy_data):
|
||||
"""Test rolling correlation with single day data."""
|
||||
from tradingagents.dataflows.benchmark import calculate_rolling_correlation
|
||||
|
||||
# Create single day data
|
||||
single_day = sample_spy_data.iloc[[0]]
|
||||
|
||||
# Execute
|
||||
result = calculate_rolling_correlation(single_day, sample_spy_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str) or (isinstance(result, pd.Series) and result.dropna().empty)
|
||||
|
||||
def test_calculate_beta_extreme_values(self, extreme_values_data, sample_spy_data):
|
||||
"""Test beta calculation with extreme price movements."""
|
||||
from tradingagents.dataflows.benchmark import calculate_beta
|
||||
|
||||
# Execute
|
||||
result = calculate_beta(extreme_values_data, sample_spy_data)
|
||||
|
||||
# Assert - should handle extreme values
|
||||
if isinstance(result, float):
|
||||
assert not np.isnan(result)
|
||||
# Beta can be very high for extreme volatility
|
||||
assert -100 < result < 100
|
||||
else:
|
||||
# Or return error for numerical issues
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_get_benchmark_data_no_datetime_index(self):
|
||||
"""Test that fetched data has proper DatetimeIndex."""
|
||||
# This tests that the implementation converts yfinance data correctly
|
||||
# Will be tested in integration tests with actual yfinance calls
|
||||
pass
|
||||
|
||||
def test_sector_etf_constants_coverage(self):
|
||||
"""Test that all expected sector ETFs are defined in SECTOR_ETFS constant."""
|
||||
from tradingagents.dataflows.benchmark import SECTOR_ETFS
|
||||
|
||||
# Expected sectors
|
||||
expected_sectors = [
|
||||
'communication', 'consumer_discretionary', 'consumer_staples',
|
||||
'energy', 'financials', 'healthcare', 'industrials',
|
||||
'materials', 'real_estate', 'technology', 'utilities'
|
||||
]
|
||||
|
||||
# Assert all sectors exist
|
||||
for sector in expected_sectors:
|
||||
assert sector in SECTOR_ETFS, f"Missing sector: {sector}"
|
||||
|
||||
# Assert expected symbols
|
||||
assert SECTOR_ETFS['technology'] == 'XLK'
|
||||
assert SECTOR_ETFS['financials'] == 'XLF'
|
||||
assert SECTOR_ETFS['energy'] == 'XLE'
|
||||
assert SECTOR_ETFS['healthcare'] == 'XLV'
|
||||
assert SECTOR_ETFS['industrials'] == 'XLI'
|
||||
assert SECTOR_ETFS['materials'] == 'XLB'
|
||||
assert SECTOR_ETFS['consumer_discretionary'] == 'XLY'
|
||||
assert SECTOR_ETFS['consumer_staples'] == 'XLP'
|
||||
assert SECTOR_ETFS['real_estate'] == 'XLRE'
|
||||
assert SECTOR_ETFS['utilities'] == 'XLU'
|
||||
assert SECTOR_ETFS['communication'] == 'XLC'
|
||||
|
|
@ -0,0 +1,441 @@
|
|||
"""
|
||||
Benchmark Data Retrieval and Analysis Functions.
|
||||
|
||||
This module provides functions for retrieving and analyzing benchmark data:
|
||||
- Benchmark data fetching (SPY, sector ETFs)
|
||||
- Relative strength calculations (IBD-style)
|
||||
- Rolling correlation analysis
|
||||
- Beta calculations
|
||||
|
||||
All functions return pandas DataFrames/Series/floats on success or error strings on failure.
|
||||
|
||||
Usage:
|
||||
from tradingagents.dataflows.benchmark import (
|
||||
get_spy_data,
|
||||
get_sector_etf_data,
|
||||
calculate_relative_strength
|
||||
)
|
||||
|
||||
# Get SPY benchmark data
|
||||
spy_data = get_spy_data('2024-01-01', '2024-12-31')
|
||||
|
||||
# Get sector ETF data
|
||||
tech_data = get_sector_etf_data('technology', '2024-01-01', '2024-12-31')
|
||||
|
||||
# Calculate relative strength
|
||||
rs = calculate_relative_strength(stock_data, spy_data)
|
||||
|
||||
Requirements:
|
||||
- yfinance package: pip install yfinance
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Union, List
|
||||
from datetime import datetime
|
||||
|
||||
# Try to import yfinance, but allow it to be mocked in tests
|
||||
try:
|
||||
import yfinance as yf
|
||||
except ImportError:
|
||||
yf = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SECTOR ETF Mappings
|
||||
# ============================================================================
|
||||
|
||||
SECTOR_ETFS = {
|
||||
'communication': 'XLC',
|
||||
'consumer_discretionary': 'XLY',
|
||||
'consumer_staples': 'XLP',
|
||||
'energy': 'XLE',
|
||||
'financials': 'XLF',
|
||||
'healthcare': 'XLV',
|
||||
'industrials': 'XLI',
|
||||
'materials': 'XLB',
|
||||
'real_estate': 'XLRE',
|
||||
'technology': 'XLK',
|
||||
'utilities': 'XLU'
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Benchmark Data Fetching Functions
|
||||
# ============================================================================
|
||||
|
||||
def get_benchmark_data(
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
) -> Union[pd.DataFrame, str]:
|
||||
"""
|
||||
Fetch benchmark OHLCV data via yfinance.
|
||||
|
||||
Args:
|
||||
symbol: Ticker symbol (e.g., 'SPY', 'XLK')
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with DatetimeIndex and columns: Open, High, Low, Close, Volume
|
||||
str with error message on failure
|
||||
|
||||
Examples:
|
||||
>>> data = get_benchmark_data('SPY', '2024-01-01', '2024-12-31')
|
||||
>>> data = get_benchmark_data('XLK', '2024-01-01', '2024-12-31')
|
||||
"""
|
||||
if yf is None:
|
||||
return "Error: yfinance package is not installed. Install with: pip install yfinance"
|
||||
|
||||
try:
|
||||
# Validate date formats
|
||||
datetime.strptime(start_date, "%Y-%m-%d")
|
||||
datetime.strptime(end_date, "%Y-%m-%d")
|
||||
except ValueError as e:
|
||||
return f"Error: Invalid date format. Use YYYY-MM-DD. Details: {str(e)}"
|
||||
|
||||
try:
|
||||
# Fetch data from yfinance
|
||||
ticker = yf.Ticker(symbol)
|
||||
data = ticker.history(start=start_date, end=end_date)
|
||||
|
||||
# Check if data is empty
|
||||
if data.empty:
|
||||
return f"Error: No data found for symbol '{symbol}' between {start_date} and {end_date}"
|
||||
|
||||
# Remove timezone info if present
|
||||
if data.index.tz is not None:
|
||||
data.index = data.index.tz_localize(None)
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
return f"Error fetching data for {symbol}: {str(e)}"
|
||||
|
||||
|
||||
def get_spy_data(
|
||||
start_date: str,
|
||||
end_date: str
|
||||
) -> Union[pd.DataFrame, str]:
|
||||
"""
|
||||
Fetch SPY benchmark data (convenience wrapper).
|
||||
|
||||
Args:
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with DatetimeIndex and columns: Open, High, Low, Close, Volume
|
||||
str with error message on failure
|
||||
|
||||
Examples:
|
||||
>>> spy_data = get_spy_data('2024-01-01', '2024-12-31')
|
||||
"""
|
||||
return get_benchmark_data('SPY', start_date, end_date)
|
||||
|
||||
|
||||
def get_sector_etf_data(
|
||||
sector: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
) -> Union[pd.DataFrame, str]:
|
||||
"""
|
||||
Fetch sector ETF data.
|
||||
|
||||
Args:
|
||||
sector: Sector name (e.g., 'technology', 'financials', 'energy')
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with DatetimeIndex and columns: Open, High, Low, Close, Volume
|
||||
str with error message on failure
|
||||
|
||||
Valid Sectors:
|
||||
- communication (XLC)
|
||||
- consumer_discretionary (XLY)
|
||||
- consumer_staples (XLP)
|
||||
- energy (XLE)
|
||||
- financials (XLF)
|
||||
- healthcare (XLV)
|
||||
- industrials (XLI)
|
||||
- materials (XLB)
|
||||
- real_estate (XLRE)
|
||||
- technology (XLK)
|
||||
- utilities (XLU)
|
||||
|
||||
Examples:
|
||||
>>> tech_data = get_sector_etf_data('technology', '2024-01-01', '2024-12-31')
|
||||
>>> finance_data = get_sector_etf_data('financials', '2024-01-01', '2024-12-31')
|
||||
"""
|
||||
# Validate sector
|
||||
if sector not in SECTOR_ETFS:
|
||||
valid_sectors = ', '.join(sorted(SECTOR_ETFS.keys()))
|
||||
return f"Error: Invalid sector '{sector}'. Valid sectors: {valid_sectors}"
|
||||
|
||||
# Get symbol for sector
|
||||
symbol = SECTOR_ETFS[sector]
|
||||
|
||||
# Fetch data
|
||||
return get_benchmark_data(symbol, start_date, end_date)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Relative Strength Calculation
|
||||
# ============================================================================
|
||||
|
||||
def calculate_relative_strength(
|
||||
stock_data: pd.DataFrame,
|
||||
benchmark_data: pd.DataFrame,
|
||||
periods: List[int] = [63, 126, 189, 252]
|
||||
) -> Union[float, str]:
|
||||
"""
|
||||
Calculate IBD-style relative strength.
|
||||
|
||||
Uses IBD formula with weighted rate of change (ROC) calculations:
|
||||
- 40% weight on 63-day (3-month) ROC
|
||||
- 20% weight on 126-day (6-month) ROC
|
||||
- 20% weight on 189-day (9-month) ROC
|
||||
- 20% weight on 252-day (12-month) ROC
|
||||
|
||||
Args:
|
||||
stock_data: DataFrame with 'Close' column
|
||||
benchmark_data: DataFrame with 'Close' column
|
||||
periods: List of periods for ROC calculation (default: [63, 126, 189, 252])
|
||||
|
||||
Returns:
|
||||
float: Relative strength score (stock RS - benchmark RS)
|
||||
Positive = stock outperforming benchmark
|
||||
Negative = stock underperforming benchmark
|
||||
str: Error message on failure
|
||||
|
||||
Examples:
|
||||
>>> rs = calculate_relative_strength(stock_data, spy_data)
|
||||
>>> rs = calculate_relative_strength(stock_data, spy_data, periods=[20, 60, 120, 180])
|
||||
"""
|
||||
# Validate inputs
|
||||
if stock_data.empty:
|
||||
return "Error: Stock data is empty"
|
||||
|
||||
if benchmark_data.empty:
|
||||
return "Error: Benchmark data is empty"
|
||||
|
||||
if 'Close' not in stock_data.columns:
|
||||
return "Error: Stock data missing 'Close' column"
|
||||
|
||||
if 'Close' not in benchmark_data.columns:
|
||||
return "Error: Benchmark data missing 'Close' column"
|
||||
|
||||
try:
|
||||
# Align dates via inner join
|
||||
aligned = pd.DataFrame({
|
||||
'stock_close': stock_data['Close'],
|
||||
'benchmark_close': benchmark_data['Close']
|
||||
}).dropna()
|
||||
|
||||
if aligned.empty:
|
||||
return "Error: No overlapping dates between stock and benchmark data"
|
||||
|
||||
# Check sufficient data for longest period
|
||||
# Allow some flexibility for trading days (250-252 trading days in a year)
|
||||
max_period = max(periods)
|
||||
# Require at least 98% of the period (e.g., 250 days for 252-day period)
|
||||
min_required = int(max_period * 0.98)
|
||||
if len(aligned) < min_required:
|
||||
return f"Error: Insufficient data. Need at least {min_required} days, have {len(aligned)}"
|
||||
|
||||
# Calculate ROC for each period
|
||||
stock_rocs = []
|
||||
benchmark_rocs = []
|
||||
|
||||
for period in periods:
|
||||
# ROC = (close / close.shift(period)) - 1
|
||||
# Use min of period and available data for flexibility with trading days
|
||||
actual_period = min(period, len(aligned) - 1)
|
||||
stock_roc = (aligned['stock_close'] / aligned['stock_close'].shift(actual_period)) - 1
|
||||
benchmark_roc = (aligned['benchmark_close'] / aligned['benchmark_close'].shift(actual_period)) - 1
|
||||
|
||||
# Get the most recent ROC value
|
||||
stock_rocs.append(stock_roc.iloc[-1])
|
||||
benchmark_rocs.append(benchmark_roc.iloc[-1])
|
||||
|
||||
# Check for NaN values
|
||||
if any(np.isnan(stock_rocs)) or any(np.isnan(benchmark_rocs)):
|
||||
return "Error: Unable to calculate ROC for all periods (NaN values)"
|
||||
|
||||
# Apply IBD weighting: 0.4, 0.2, 0.2, 0.2
|
||||
weights = [0.4, 0.2, 0.2, 0.2]
|
||||
|
||||
# Calculate weighted RS
|
||||
stock_rs = sum(roc * weight for roc, weight in zip(stock_rocs, weights))
|
||||
benchmark_rs = sum(roc * weight for roc, weight in zip(benchmark_rocs, weights))
|
||||
|
||||
# Return relative strength (stock RS - benchmark RS)
|
||||
relative_strength = stock_rs - benchmark_rs
|
||||
|
||||
return float(relative_strength)
|
||||
|
||||
except Exception as e:
|
||||
return f"Error calculating relative strength: {str(e)}"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Correlation Analysis
|
||||
# ============================================================================
|
||||
|
||||
def calculate_rolling_correlation(
|
||||
stock_data: pd.DataFrame,
|
||||
benchmark_data: pd.DataFrame,
|
||||
window: int = 63
|
||||
) -> Union[pd.Series, str]:
|
||||
"""
|
||||
Calculate rolling correlation between stock and benchmark.
|
||||
|
||||
Args:
|
||||
stock_data: DataFrame with 'Close' column
|
||||
benchmark_data: DataFrame with 'Close' column
|
||||
window: Rolling window size in days (default: 63 for ~3 months)
|
||||
|
||||
Returns:
|
||||
pd.Series: Rolling correlation values (range: -1 to 1)
|
||||
str: Error message on failure
|
||||
|
||||
Examples:
|
||||
>>> corr = calculate_rolling_correlation(stock_data, spy_data)
|
||||
>>> corr = calculate_rolling_correlation(stock_data, spy_data, window=20)
|
||||
"""
|
||||
# Validate window
|
||||
if window < 2:
|
||||
return "Error: Window must be at least 2"
|
||||
|
||||
# Validate inputs
|
||||
if stock_data.empty:
|
||||
return "Error: Stock data is empty"
|
||||
|
||||
if benchmark_data.empty:
|
||||
return "Error: Benchmark data is empty"
|
||||
|
||||
if 'Close' not in stock_data.columns:
|
||||
return "Error: Stock data missing 'Close' column"
|
||||
|
||||
if 'Close' not in benchmark_data.columns:
|
||||
return "Error: Benchmark data missing 'Close' column"
|
||||
|
||||
try:
|
||||
# Align dates via inner join
|
||||
aligned = pd.DataFrame({
|
||||
'stock_close': stock_data['Close'],
|
||||
'benchmark_close': benchmark_data['Close']
|
||||
}).dropna()
|
||||
|
||||
if aligned.empty:
|
||||
return "Error: No overlapping dates between stock and benchmark data"
|
||||
|
||||
# Calculate rolling correlation
|
||||
rolling_corr = aligned['stock_close'].rolling(window=window).corr(aligned['benchmark_close'])
|
||||
|
||||
# Clip to [-1, 1] to handle floating point precision issues
|
||||
rolling_corr = rolling_corr.clip(-1.0, 1.0)
|
||||
|
||||
return rolling_corr
|
||||
|
||||
except Exception as e:
|
||||
return f"Error calculating rolling correlation: {str(e)}"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Beta Calculation
|
||||
# ============================================================================
|
||||
|
||||
def calculate_beta(
|
||||
stock_data: pd.DataFrame,
|
||||
benchmark_data: pd.DataFrame,
|
||||
window: int = 252
|
||||
) -> Union[float, str]:
|
||||
"""
|
||||
Calculate beta (systematic risk measure).
|
||||
|
||||
Beta = Covariance(stock_returns, benchmark_returns) / Variance(benchmark_returns)
|
||||
|
||||
Args:
|
||||
stock_data: DataFrame with 'Close' column
|
||||
benchmark_data: DataFrame with 'Close' column
|
||||
window: Number of days for calculation (default: 252 for ~1 year)
|
||||
|
||||
Returns:
|
||||
float: Beta value
|
||||
Beta > 1: More volatile than benchmark
|
||||
Beta = 1: Same volatility as benchmark
|
||||
Beta < 1: Less volatile than benchmark
|
||||
str: Error message on failure
|
||||
|
||||
Examples:
|
||||
>>> beta = calculate_beta(stock_data, spy_data)
|
||||
>>> beta = calculate_beta(stock_data, spy_data, window=126)
|
||||
"""
|
||||
# Validate inputs
|
||||
if stock_data.empty:
|
||||
return "Error: Stock data is empty"
|
||||
|
||||
if benchmark_data.empty:
|
||||
return "Error: Benchmark data is empty"
|
||||
|
||||
if 'Close' not in stock_data.columns:
|
||||
return "Error: Stock data missing 'Close' column"
|
||||
|
||||
if 'Close' not in benchmark_data.columns:
|
||||
return "Error: Benchmark data missing 'Close' column"
|
||||
|
||||
try:
|
||||
# Align dates via inner join
|
||||
aligned = pd.DataFrame({
|
||||
'stock_close': stock_data['Close'],
|
||||
'benchmark_close': benchmark_data['Close']
|
||||
}).dropna()
|
||||
|
||||
if aligned.empty:
|
||||
return "Error: No overlapping dates between stock and benchmark data"
|
||||
|
||||
# Check sufficient data
|
||||
# For beta calculation, allow some flexibility for trading days
|
||||
min_required = int(window * 0.98)
|
||||
if len(aligned) < min_required:
|
||||
return f"Error: Insufficient data. Need at least {min_required} days, have {len(aligned)}"
|
||||
|
||||
# Calculate returns
|
||||
stock_returns = aligned['stock_close'].pct_change()
|
||||
benchmark_returns = aligned['benchmark_close'].pct_change()
|
||||
|
||||
# Take last window days
|
||||
stock_returns_window = stock_returns.tail(window)
|
||||
benchmark_returns_window = benchmark_returns.tail(window)
|
||||
|
||||
# Remove NaN values
|
||||
valid_data = pd.DataFrame({
|
||||
'stock': stock_returns_window,
|
||||
'benchmark': benchmark_returns_window
|
||||
}).dropna()
|
||||
|
||||
if valid_data.empty:
|
||||
return "Error: No valid returns data after removing NaN values"
|
||||
|
||||
# Calculate covariance and variance
|
||||
covariance = valid_data['stock'].cov(valid_data['benchmark'])
|
||||
variance = valid_data['benchmark'].var()
|
||||
|
||||
# Handle zero variance
|
||||
if variance == 0 or np.isnan(variance):
|
||||
return "Error: Benchmark has zero variance (no price movement)"
|
||||
|
||||
# Calculate beta
|
||||
beta = covariance / variance
|
||||
|
||||
# Check for NaN
|
||||
if np.isnan(beta):
|
||||
return "Error: Beta calculation resulted in NaN"
|
||||
|
||||
return float(beta)
|
||||
|
||||
except Exception as e:
|
||||
return f"Error calculating beta: {str(e)}"
|
||||
Loading…
Reference in New Issue