855 lines
32 KiB
Python
855 lines
32 KiB
Python
"""Tests for Correlation Analyst agent.

Issue #15: [AGENT-14] Correlation Analyst - cross-asset, sector rotation

The helpers under test are re-defined locally (mirroring
correlation_analyst.py) to avoid langchain import issues.
"""
|
|
|
|
import pytest
|
|
import pandas as pd
|
|
import numpy as np
|
|
from datetime import datetime, timedelta
|
|
from unittest.mock import Mock, MagicMock
|
|
from enum import Enum
|
|
|
|
# Module-level pytest mark: tags every test in this file with the "unit" marker.
pytestmark = pytest.mark.unit
|
|
|
|
|
|
# ============================================================================
# Local Definitions (matching correlation_analyst.py)
# ============================================================================
|
|
|
|
class CorrelationStrength(str, Enum):
    """Named buckets describing the sign and strength of a correlation.

    Members are listed from the most positive bucket down to the most
    negative one. Because the class also derives from ``str``, each member
    compares equal to its lowercase snake_case value.
    """

    # Positive buckets, strongest first.
    VERY_STRONG_POSITIVE = "very_strong_positive"
    STRONG_POSITIVE = "strong_positive"
    MODERATE_POSITIVE = "moderate_positive"
    WEAK_POSITIVE = "weak_positive"

    # No meaningful relationship in either direction.
    NEGLIGIBLE = "negligible"

    # Negative buckets, weakest first.
    WEAK_NEGATIVE = "weak_negative"
    MODERATE_NEGATIVE = "moderate_negative"
    STRONG_NEGATIVE = "strong_negative"
    VERY_STRONG_NEGATIVE = "very_strong_negative"
|
|
|
|
|
|
class SectorPhase(str, Enum):
    """Economic-cycle phases used to drive sector rotation.

    String-valued enum: every member equals its lowercase snake_case value.
    """

    EARLY_CYCLE = "early_cycle"
    MID_CYCLE = "mid_cycle"
    LATE_CYCLE = "late_cycle"
    RECESSION = "recession"
|
|
|
|
|
|
class SectorLeadership(str, Enum):
    """Labels for how a sector is behaving relative to its benchmark.

    String-valued enum: every member equals its lowercase value.
    """

    LEADING = "leading"
    LAGGING = "lagging"
    IMPROVING = "improving"
    WEAKENING = "weakening"
|
|
|
|
|
|
# ============================================================================
# Helper Functions (matching correlation_analyst.py)
# ============================================================================
|
|
|
|
def _calculate_correlation(series1: pd.Series, series2: pd.Series) -> float:
|
|
"""Calculate Pearson correlation between two series."""
|
|
if len(series1) < 2 or len(series2) < 2:
|
|
return 0.0
|
|
min_len = min(len(series1), len(series2))
|
|
s1 = series1.iloc[-min_len:].values
|
|
s2 = series2.iloc[-min_len:].values
|
|
|
|
if np.std(s1) == 0 or np.std(s2) == 0:
|
|
return 0.0
|
|
|
|
return float(np.corrcoef(s1, s2)[0, 1])
|
|
|
|
|
|
def _calculate_rolling_correlation(
|
|
series1: pd.Series,
|
|
series2: pd.Series,
|
|
window: int = 20
|
|
) -> pd.Series:
|
|
"""Calculate rolling correlation between two series."""
|
|
if len(series1) < window or len(series2) < window:
|
|
return pd.Series([])
|
|
|
|
min_len = min(len(series1), len(series2))
|
|
s1 = series1.iloc[-min_len:]
|
|
s2 = series2.iloc[-min_len:]
|
|
|
|
rolling_corr = s1.rolling(window=window).corr(s2)
|
|
return rolling_corr.dropna()
|
|
|
|
|
|
def _classify_correlation(corr: float) -> CorrelationStrength:
    """Classify a correlation coefficient into a strength category.

    Thresholds are symmetric around zero at 0.2, 0.4, 0.6 and 0.8; a value
    sitting exactly on a threshold belongs to the stronger bucket
    (e.g. ``0.2`` is weak positive and ``-0.2`` is weak negative).

    Args:
        corr: Correlation coefficient.

    Returns:
        The matching ``CorrelationStrength`` member. NaN is reported as
        ``NEGLIGIBLE``.
    """
    # NaN fails every comparison below and would otherwise fall through to the
    # final branch and be labelled VERY_STRONG_NEGATIVE.
    if np.isnan(corr):
        return CorrelationStrength.NEGLIGIBLE

    if corr >= 0.8:
        return CorrelationStrength.VERY_STRONG_POSITIVE
    elif corr >= 0.6:
        return CorrelationStrength.STRONG_POSITIVE
    elif corr >= 0.4:
        return CorrelationStrength.MODERATE_POSITIVE
    elif corr >= 0.2:
        return CorrelationStrength.WEAK_POSITIVE
    elif corr > -0.2:
        return CorrelationStrength.NEGLIGIBLE
    elif corr > -0.4:
        return CorrelationStrength.WEAK_NEGATIVE
    elif corr > -0.6:
        return CorrelationStrength.MODERATE_NEGATIVE
    elif corr > -0.8:
        return CorrelationStrength.STRONG_NEGATIVE
    else:
        return CorrelationStrength.VERY_STRONG_NEGATIVE
|
|
|
|
|
|
def _detect_correlation_breakdown(
|
|
rolling_corr: pd.Series,
|
|
threshold_change: float = 0.3
|
|
) -> dict:
|
|
"""Detect significant correlation breakdown events."""
|
|
if len(rolling_corr) < 10:
|
|
return {"detected": False, "details": "Insufficient data"}
|
|
|
|
corr_diff = rolling_corr.diff()
|
|
large_changes = corr_diff[abs(corr_diff) > threshold_change]
|
|
|
|
if len(large_changes) == 0:
|
|
return {"detected": False, "details": "No significant correlation changes"}
|
|
|
|
recent_change = corr_diff.iloc[-20:] if len(corr_diff) >= 20 else corr_diff
|
|
max_change_idx = recent_change.abs().idxmax()
|
|
max_change = recent_change.loc[max_change_idx]
|
|
|
|
return {
|
|
"detected": abs(max_change) > threshold_change,
|
|
"change_magnitude": float(max_change),
|
|
"direction": "increasing" if max_change > 0 else "decreasing",
|
|
"current_correlation": float(rolling_corr.iloc[-1]),
|
|
"prior_correlation": float(rolling_corr.iloc[-1] - max_change)
|
|
}
|
|
|
|
|
|
def _calculate_relative_strength(
|
|
returns: pd.Series,
|
|
benchmark_returns: pd.Series,
|
|
window: int = 20
|
|
) -> pd.Series:
|
|
"""Calculate relative strength vs benchmark."""
|
|
if len(returns) < window or len(benchmark_returns) < window:
|
|
return pd.Series([])
|
|
|
|
min_len = min(len(returns), len(benchmark_returns))
|
|
ret = returns.iloc[-min_len:]
|
|
bench = benchmark_returns.iloc[-min_len:]
|
|
|
|
cum_ret = (1 + ret).cumprod()
|
|
cum_bench = (1 + bench).cumprod()
|
|
|
|
relative_strength = cum_ret / cum_bench
|
|
return relative_strength
|
|
|
|
|
|
def _classify_sector_leadership(
    relative_strength: pd.Series,
    window: int = 20
) -> SectorLeadership:
    """Classify sector leadership from the recent relative-strength trend.

    Looks at the last ``window`` relative-strength values and combines the
    final level (above or below 1) with the fractional change from the first
    to the last value of that window.

    Args:
        relative_strength: Relative-strength series, oldest first.
        window: Number of trailing observations to evaluate.

    Returns:
        ``LEADING`` when the final value is above 1 and it rose by more than
        2% over the window; ``WEAKENING`` when above 1 but falling;
        ``IMPROVING`` when below 1 but rising; ``LAGGING`` otherwise,
        including when fewer than ``window`` observations are available.
    """
    if len(relative_strength) < window:
        return SectorLeadership.LAGGING

    recent = relative_strength.iloc[-window:]
    rs_start = recent.iloc[0]
    rs_end = recent.iloc[-1]

    # Fractional change over the window; a zero start is treated as no change
    # to avoid dividing by zero.
    current_vs_start = (rs_end - rs_start) / rs_start if rs_start != 0 else 0

    # NOTE(review): a final value above 1 with a change between 0 and 2%
    # falls through to LAGGING - confirm that is the intended label.
    if rs_end > 1 and current_vs_start > 0.02:
        return SectorLeadership.LEADING
    elif rs_end > 1 and current_vs_start < 0:
        return SectorLeadership.WEAKENING
    elif rs_end < 1 and current_vs_start > 0:
        return SectorLeadership.IMPROVING
    else:
        return SectorLeadership.LAGGING
|
|
|
|
|
|
def _identify_cycle_phase(indicators: dict) -> SectorPhase:
    """Map a set of market indicators onto an economic cycle phase.

    Reads ``pmi`` (default 50), ``leading_index`` (default 0) and
    ``yield_curve_slope`` (default 0) from ``indicators``.

    Returns:
        With ``pmi`` above 50: ``EARLY_CYCLE`` when both the leading index and
        the yield-curve slope are positive, ``MID_CYCLE`` when only the
        leading index is positive, and ``LATE_CYCLE`` when the leading index
        is negative. Every other combination yields ``RECESSION``, including
        a leading index of exactly 0.
    """
    pmi = indicators.get('pmi', 50)
    leading = indicators.get('leading_index', 0)
    slope = indicators.get('yield_curve_slope', 0)

    # Every non-recession phase requires a PMI above 50.
    if not pmi > 50:
        return SectorPhase.RECESSION

    if leading > 0:
        if slope > 0:
            return SectorPhase.EARLY_CYCLE
        return SectorPhase.MID_CYCLE

    if leading < 0:
        return SectorPhase.LATE_CYCLE

    return SectorPhase.RECESSION
|
|
|
|
|
|
def _get_cycle_sector_recommendations(phase: SectorPhase) -> dict:
    """Return the sector tilt recommended for a cycle phase.

    Returns:
        A dict with ``overweight`` and ``underweight`` ticker lists and a
        ``rationale`` string. A phase not in the table gets empty lists and
        the rationale ``"Unknown phase"``. The table is rebuilt on each call,
        so the returned dict is not shared between calls.
    """
    unknown = {"overweight": [], "underweight": [], "rationale": "Unknown phase"}

    by_phase = {
        SectorPhase.EARLY_CYCLE: {
            "overweight": ["XLF", "XLY", "XLI", "XLB"],
            "underweight": ["XLP", "XLU", "XLRE"],
            "rationale": "Economic recovery favors cyclical sectors with high operating leverage",
        },
        SectorPhase.MID_CYCLE: {
            "overweight": ["XLK", "XLI", "XLB"],
            "underweight": ["XLU", "XLP"],
            "rationale": "Sustained growth benefits sectors with secular trends and industrial production",
        },
        SectorPhase.LATE_CYCLE: {
            "overweight": ["XLE", "XLB", "XLI"],
            "underweight": ["XLK", "XLY", "XLF"],
            "rationale": "Inflation hedge and commodity exposure preferred as cycle matures",
        },
        SectorPhase.RECESSION: {
            "overweight": ["XLU", "XLP", "XLV"],
            "underweight": ["XLY", "XLI", "XLB"],
            "rationale": "Defensive sectors with stable cash flows outperform during contractions",
        },
    }

    return by_phase.get(phase, unknown)
|
|
|
|
|
|
def _interpret_cross_asset_correlation(
|
|
stock_bond_corr: float,
|
|
stock_gold_corr: float,
|
|
stock_oil_corr: float
|
|
) -> str:
|
|
"""Interpret cross-asset correlations for market regime."""
|
|
interpretations = []
|
|
|
|
if stock_bond_corr > 0.3:
|
|
interpretations.append("RISK-OFF REGIME: Positive stock-bond correlation suggests flight to quality")
|
|
elif stock_bond_corr < -0.3:
|
|
interpretations.append("NORMAL REGIME: Negative stock-bond correlation indicates balanced risk appetite")
|
|
else:
|
|
interpretations.append("TRANSITIONAL REGIME: Low stock-bond correlation may signal regime change")
|
|
|
|
if stock_gold_corr < -0.3:
|
|
interpretations.append("HEDGING ACTIVE: Gold acting as portfolio hedge against equity risk")
|
|
elif stock_gold_corr > 0.3:
|
|
interpretations.append("LIQUIDITY DRIVEN: Both assets rising suggests monetary expansion")
|
|
|
|
if stock_oil_corr > 0.5:
|
|
interpretations.append("GROWTH SENSITIVE: Strong stock-oil correlation reflects economic growth expectations")
|
|
elif stock_oil_corr < -0.3:
|
|
interpretations.append("SUPPLY SHOCK: Negative correlation may indicate energy cost pressure on equities")
|
|
|
|
return "\n".join(interpretations) if interpretations else "Normal cross-asset relationships"
|
|
|
|
|
|
def _format_correlation_signal(corr: float) -> str:
    """Render a correlation as ``"<value> (<Strength Label>)"``.

    The value is shown with three decimals. Strictly positive values get an
    explicit leading ``+``; zero and negative values are printed as-is. The
    label is the strength category in title case.
    """
    label = _classify_correlation(corr).value.replace('_', ' ').title()
    sign = "+" if corr > 0 else ""
    return f"{sign}{corr:.3f} ({label})"
|
|
|
|
|
|
# ============================================================================
# Test Fixtures
# ============================================================================
|
|
|
|
@pytest.fixture
def sample_returns():
    """Reproducible series of 100 daily returns (seed 42, mean 0.001, std 0.02)."""
    np.random.seed(42)
    values = np.random.normal(0.001, 0.02, 100)
    index = pd.date_range(start='2024-01-01', periods=100, freq='D')
    return pd.Series(values, index=index)
|
|
|
|
|
|
@pytest.fixture
def correlated_returns():
    """Pair of positively correlated daily return series (seed 42).

    The second series is 0.8 times the first plus a small amount of
    independent noise.
    """
    np.random.seed(42)
    # Draw order matters for reproducibility: base first, then the noise.
    base = np.random.normal(0.001, 0.02, 100)
    noise = np.random.normal(0, 0.005, 100)
    index = pd.date_range(start='2024-01-01', periods=100, freq='D')
    return pd.Series(base, index=index), pd.Series(base * 0.8 + noise, index=index)
|
|
|
|
|
|
@pytest.fixture
def negatively_correlated_returns():
    """Pair of negatively correlated daily return series (seed 42).

    The second series is -0.8 times the first plus a small amount of
    independent noise.
    """
    np.random.seed(42)
    # Draw order matters for reproducibility: base first, then the noise.
    base = np.random.normal(0.001, 0.02, 100)
    noise = np.random.normal(0, 0.005, 100)
    index = pd.date_range(start='2024-01-01', periods=100, freq='D')
    return pd.Series(base, index=index), pd.Series(-base * 0.8 + noise, index=index)
|
|
|
|
|
|
@pytest.fixture
def uncorrelated_returns():
    """Pair of independently generated daily return series.

    Each series is drawn after re-seeding the generator (42, then 99), so the
    two share a distribution but not their random draws.
    """
    index = pd.date_range(start='2024-01-01', periods=100, freq='D')

    np.random.seed(42)
    first = pd.Series(np.random.normal(0.001, 0.02, 100), index=index)

    np.random.seed(99)  # Different seed for the second series
    second = pd.Series(np.random.normal(0.001, 0.02, 100), index=index)

    return first, second
|
|
|
|
|
|
@pytest.fixture
def benchmark_returns():
    """Reproducible benchmark (SPY-like) daily returns: seed 42, mean 0.0005, std 0.015."""
    np.random.seed(42)
    values = np.random.normal(0.0005, 0.015, 100)
    index = pd.date_range(start='2024-01-01', periods=100, freq='D')
    return pd.Series(values, index=index)
|
|
|
|
|
|
@pytest.fixture
def outperforming_sector_returns(benchmark_returns):
    """Sector returns drawn with a higher mean than the benchmark fixture.

    Uses seed 43, mean 0.002 and std 0.018, on the benchmark's own index.
    """
    np.random.seed(43)
    index = benchmark_returns.index
    values = np.random.normal(0.002, 0.018, len(index))
    return pd.Series(values, index=index)
|
|
|
|
|
|
@pytest.fixture
def underperforming_sector_returns(benchmark_returns):
    """Sector returns drawn with a negative mean, below the benchmark fixture.

    Uses seed 44, mean -0.001 and std 0.02, on the benchmark's own index.
    """
    np.random.seed(44)
    index = benchmark_returns.index
    values = np.random.normal(-0.001, 0.02, len(index))
    return pd.Series(values, index=index)
|
|
|
|
|
|
# ============================================================================
# Test Classes
# ============================================================================
|
|
|
|
class TestCorrelationCalculation:
    """Tests for correlation calculation."""

    def test_perfect_positive_correlation(self):
        """A series and an exact multiple of it correlate at +1."""
        series1 = pd.Series([1, 2, 3, 4, 5])
        series2 = pd.Series([2, 4, 6, 8, 10])  # 2x series1
        corr = _calculate_correlation(series1, series2)
        assert abs(corr - 1.0) < 0.001

    def test_perfect_negative_correlation(self):
        """A series and its reverse correlate at -1."""
        series1 = pd.Series([1, 2, 3, 4, 5])
        series2 = pd.Series([5, 4, 3, 2, 1])  # Reverse
        corr = _calculate_correlation(series1, series2)
        assert abs(corr - (-1.0)) < 0.001

    def test_zero_correlation_with_constant(self):
        """A constant series yields 0.0 rather than an undefined result."""
        series1 = pd.Series([1, 2, 3, 4, 5])
        series2 = pd.Series([5, 5, 5, 5, 5])  # Constant
        corr = _calculate_correlation(series1, series2)
        assert corr == 0.0

    def test_insufficient_data(self):
        """Fewer than two observations yields 0.0."""
        series1 = pd.Series([1])
        series2 = pd.Series([2])
        corr = _calculate_correlation(series1, series2)
        assert corr == 0.0

    def test_empty_series(self):
        """Empty inputs yield 0.0."""
        # Explicit dtype: a bare pd.Series([]) has a version-dependent dtype
        # and emits a warning on pandas 1.x.
        series1 = pd.Series([], dtype=float)
        series2 = pd.Series([], dtype=float)
        corr = _calculate_correlation(series1, series2)
        assert corr == 0.0

    def test_different_length_series(self):
        """Inputs of different lengths are aligned on their tails."""
        series1 = pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        series2 = pd.Series([2, 4, 6, 8, 10])  # Shorter
        corr = _calculate_correlation(series1, series2)
        # Should use last 5 elements of series1
        assert corr > 0.9  # High positive correlation

    def test_real_world_correlation(self, correlated_returns):
        """Realistic correlated return data produces a strong positive value."""
        series1, series2 = correlated_returns
        corr = _calculate_correlation(series1, series2)
        assert corr > 0.7  # Strong positive
|
|
|
|
|
|
class TestRollingCorrelation:
    """Tests for rolling correlation calculation."""

    def test_rolling_correlation_calculation(self, correlated_returns):
        """Rolling correlation is non-empty and every value lies in [-1, 1]."""
        first, second = correlated_returns
        result = _calculate_rolling_correlation(first, second, window=20)
        assert len(result) > 0
        assert all(abs(value) <= 1 for value in result)

    def test_rolling_correlation_insufficient_data(self):
        """Inputs shorter than the window give an empty result."""
        first = pd.Series([1, 2, 3, 4, 5])
        second = pd.Series([2, 4, 6, 8, 10])
        result = _calculate_rolling_correlation(first, second, window=20)
        assert len(result) == 0

    def test_rolling_window_size(self, correlated_returns):
        """A wider window leaves fewer complete windows."""
        first, second = correlated_returns
        narrow = _calculate_rolling_correlation(first, second, window=20)
        wide = _calculate_rolling_correlation(first, second, window=40)
        # Larger window = fewer results
        assert len(narrow) > len(wide)
|
|
|
|
|
|
class TestCorrelationClassification:
    """Tests for correlation strength classification."""

    def test_very_strong_positive(self):
        """Values of 0.8 and above are very strong positive."""
        assert _classify_correlation(0.85) == CorrelationStrength.VERY_STRONG_POSITIVE
        assert _classify_correlation(0.95) == CorrelationStrength.VERY_STRONG_POSITIVE

    def test_strong_positive(self):
        """Values in [0.6, 0.8) are strong positive."""
        assert _classify_correlation(0.65) == CorrelationStrength.STRONG_POSITIVE
        assert _classify_correlation(0.75) == CorrelationStrength.STRONG_POSITIVE

    def test_moderate_positive(self):
        """Values in [0.4, 0.6) are moderate positive."""
        assert _classify_correlation(0.45) == CorrelationStrength.MODERATE_POSITIVE
        assert _classify_correlation(0.55) == CorrelationStrength.MODERATE_POSITIVE

    def test_weak_positive(self):
        """Values in [0.2, 0.4) are weak positive."""
        assert _classify_correlation(0.25) == CorrelationStrength.WEAK_POSITIVE
        assert _classify_correlation(0.35) == CorrelationStrength.WEAK_POSITIVE

    def test_negligible(self):
        """Values strictly between -0.2 and 0.2 are negligible."""
        assert _classify_correlation(0.0) == CorrelationStrength.NEGLIGIBLE
        assert _classify_correlation(0.15) == CorrelationStrength.NEGLIGIBLE
        assert _classify_correlation(-0.15) == CorrelationStrength.NEGLIGIBLE

    def test_weak_negative(self):
        """Values in (-0.4, -0.2] are weak negative."""
        assert _classify_correlation(-0.25) == CorrelationStrength.WEAK_NEGATIVE
        assert _classify_correlation(-0.35) == CorrelationStrength.WEAK_NEGATIVE

    def test_moderate_negative(self):
        """Values in (-0.6, -0.4] are moderate negative."""
        assert _classify_correlation(-0.45) == CorrelationStrength.MODERATE_NEGATIVE
        assert _classify_correlation(-0.55) == CorrelationStrength.MODERATE_NEGATIVE

    def test_strong_negative(self):
        """Values in (-0.8, -0.6] are strong negative."""
        assert _classify_correlation(-0.65) == CorrelationStrength.STRONG_NEGATIVE
        assert _classify_correlation(-0.75) == CorrelationStrength.STRONG_NEGATIVE

    def test_very_strong_negative(self):
        """Values of -0.8 and below are very strong negative."""
        assert _classify_correlation(-0.85) == CorrelationStrength.VERY_STRONG_NEGATIVE
        assert _classify_correlation(-0.95) == CorrelationStrength.VERY_STRONG_NEGATIVE

    def test_boundary_values(self):
        """A value exactly on a threshold belongs to the stronger bucket."""
        assert _classify_correlation(0.8) == CorrelationStrength.VERY_STRONG_POSITIVE
        assert _classify_correlation(0.6) == CorrelationStrength.STRONG_POSITIVE
        assert _classify_correlation(0.4) == CorrelationStrength.MODERATE_POSITIVE
        assert _classify_correlation(0.2) == CorrelationStrength.WEAK_POSITIVE
        # Negative thresholds mirror the positive side.
        assert _classify_correlation(-0.2) == CorrelationStrength.WEAK_NEGATIVE
        assert _classify_correlation(-0.4) == CorrelationStrength.MODERATE_NEGATIVE
        assert _classify_correlation(-0.6) == CorrelationStrength.STRONG_NEGATIVE
        assert _classify_correlation(-0.8) == CorrelationStrength.VERY_STRONG_NEGATIVE
|
|
|
|
|
|
class TestCorrelationBreakdownDetection:
    """Tests for correlation breakdown detection."""

    def test_breakdown_detected_increasing(self):
        """An upward jump larger than the threshold is reported as increasing."""
        # Create series with a jump
        rolling = pd.Series([0.2, 0.2, 0.2, 0.2, 0.2, 0.6, 0.6, 0.6, 0.6, 0.6])
        result = _detect_correlation_breakdown(rolling, threshold_change=0.3)
        # Truthiness check instead of "== True"; also accepts numpy booleans.
        assert result["detected"]
        assert result["direction"] == "increasing"

    def test_breakdown_detected_decreasing(self):
        """A downward jump larger than the threshold is reported as decreasing."""
        rolling = pd.Series([0.8, 0.8, 0.8, 0.8, 0.8, 0.3, 0.3, 0.3, 0.3, 0.3])
        result = _detect_correlation_breakdown(rolling, threshold_change=0.3)
        assert result["detected"]
        assert result["direction"] == "decreasing"

    def test_no_breakdown_stable(self):
        """Small fluctuations do not count as a breakdown."""
        rolling = pd.Series([0.5, 0.52, 0.48, 0.51, 0.49, 0.50, 0.51, 0.49, 0.50, 0.51])
        result = _detect_correlation_breakdown(rolling, threshold_change=0.3)
        assert not result["detected"]

    def test_insufficient_data(self):
        """Fewer than 10 points is reported as insufficient data."""
        rolling = pd.Series([0.5, 0.6, 0.7])
        result = _detect_correlation_breakdown(rolling)
        assert not result["detected"]
        assert "Insufficient" in result["details"]
|
|
|
|
|
|
class TestRelativeStrength:
    """Tests for relative strength calculation."""

    def test_outperforming_relative_strength(self, outperforming_sector_returns, benchmark_returns):
        """An outperforming sector finishes with relative strength above 1."""
        strength = _calculate_relative_strength(outperforming_sector_returns, benchmark_returns)
        assert len(strength) > 0
        # Should end above 1.0 for outperforming
        assert strength.iloc[-1] > 1.0

    def test_underperforming_relative_strength(self, underperforming_sector_returns, benchmark_returns):
        """An underperforming sector finishes with relative strength below 1."""
        strength = _calculate_relative_strength(underperforming_sector_returns, benchmark_returns)
        assert len(strength) > 0
        # Should end below 1.0 for underperforming
        assert strength.iloc[-1] < 1.0

    def test_insufficient_data(self):
        """Inputs shorter than the window give an empty result."""
        sector = pd.Series([0.01, 0.02])
        bench = pd.Series([0.01, 0.02])
        strength = _calculate_relative_strength(sector, bench, window=20)
        assert len(strength) == 0
|
|
|
|
|
|
class TestSectorLeadershipClassification:
    """Tests for sector leadership classification."""

    def test_leading_sector(self, outperforming_sector_returns, benchmark_returns):
        """An outperforming sector is classified as leading or improving."""
        rs = _calculate_relative_strength(outperforming_sector_returns, benchmark_returns)
        # Assert the precondition instead of guarding with "if", so the test
        # cannot pass without ever checking the classification.
        assert len(rs) >= 20
        leadership = _classify_sector_leadership(rs)
        assert leadership in [SectorLeadership.LEADING, SectorLeadership.IMPROVING]

    def test_lagging_sector(self, underperforming_sector_returns, benchmark_returns):
        """An underperforming sector is classified as lagging or weakening."""
        rs = _calculate_relative_strength(underperforming_sector_returns, benchmark_returns)
        assert len(rs) >= 20
        leadership = _classify_sector_leadership(rs)
        assert leadership in [SectorLeadership.LAGGING, SectorLeadership.WEAKENING]

    def test_insufficient_data_defaults_to_lagging(self):
        """A series shorter than the window defaults to LAGGING."""
        short_rs = pd.Series([1.0, 1.01, 1.02])
        leadership = _classify_sector_leadership(short_rs, window=20)
        assert leadership == SectorLeadership.LAGGING
|
|
|
|
|
|
class TestCyclePhaseIdentification:
    """Tests for economic cycle phase identification."""

    def test_early_cycle(self):
        """PMI above 50 with positive leading index and slope is early cycle."""
        readings = {'pmi': 55, 'leading_index': 0.5, 'yield_curve_slope': 0.5}
        assert _identify_cycle_phase(readings) == SectorPhase.EARLY_CYCLE

    def test_mid_cycle(self):
        """PMI above 50 with positive leading index but a flat curve is mid cycle."""
        readings = {
            'pmi': 55,
            'leading_index': 0.3,
            'yield_curve_slope': 0,  # Flat curve
        }
        assert _identify_cycle_phase(readings) == SectorPhase.MID_CYCLE

    def test_late_cycle(self):
        """PMI above 50 with a negative leading index is late cycle."""
        readings = {'pmi': 52, 'leading_index': -0.2, 'yield_curve_slope': -0.3}
        assert _identify_cycle_phase(readings) == SectorPhase.LATE_CYCLE

    def test_recession(self):
        """PMI below 50 is a recession."""
        readings = {'pmi': 45, 'leading_index': -0.5, 'yield_curve_slope': -0.5}
        assert _identify_cycle_phase(readings) == SectorPhase.RECESSION
|
|
|
|
|
|
class TestSectorRecommendations:
    """Tests for cycle-based sector recommendations."""

    def test_early_cycle_recommendations(self):
        """Early cycle favours cyclicals over defensives."""
        tilt = _get_cycle_sector_recommendations(SectorPhase.EARLY_CYCLE)
        assert "XLF" in tilt["overweight"]  # Financials
        assert "XLY" in tilt["overweight"]  # Consumer Discretionary
        assert "XLU" in tilt["underweight"]  # Utilities

    def test_recession_recommendations(self):
        """Recession favours defensives over cyclicals."""
        tilt = _get_cycle_sector_recommendations(SectorPhase.RECESSION)
        assert "XLU" in tilt["overweight"]  # Utilities
        assert "XLP" in tilt["overweight"]  # Consumer Staples
        assert "XLY" in tilt["underweight"]  # Consumer Discretionary

    def test_late_cycle_recommendations(self):
        """Late cycle favours energy and underweights tech."""
        tilt = _get_cycle_sector_recommendations(SectorPhase.LATE_CYCLE)
        assert "XLE" in tilt["overweight"]  # Energy
        assert "XLK" in tilt["underweight"]  # Tech

    def test_all_phases_have_rationale(self):
        """Every phase comes with a non-empty rationale."""
        for phase in SectorPhase:
            tilt = _get_cycle_sector_recommendations(phase)
            assert "rationale" in tilt
            assert len(tilt["rationale"]) > 0
|
|
|
|
|
|
class TestCrossAssetInterpretation:
    """Tests for cross-asset correlation interpretation."""

    def test_risk_off_regime(self):
        """Positive stock-bond correlation reads as risk-off."""
        text = _interpret_cross_asset_correlation(
            stock_bond_corr=0.5, stock_gold_corr=0.0, stock_oil_corr=0.0
        )
        assert "RISK-OFF" in text

    def test_normal_regime(self):
        """Negative stock-bond correlation reads as the normal regime."""
        text = _interpret_cross_asset_correlation(
            stock_bond_corr=-0.5, stock_gold_corr=0.0, stock_oil_corr=0.0
        )
        assert "NORMAL" in text

    def test_hedging_active(self):
        """Negative stock-gold correlation reads as active hedging."""
        text = _interpret_cross_asset_correlation(
            stock_bond_corr=0.0, stock_gold_corr=-0.5, stock_oil_corr=0.0
        )
        assert "HEDGING" in text

    def test_liquidity_driven(self):
        """Positive stock-gold correlation reads as liquidity driven."""
        text = _interpret_cross_asset_correlation(
            stock_bond_corr=0.0, stock_gold_corr=0.5, stock_oil_corr=0.0
        )
        assert "LIQUIDITY" in text

    def test_growth_sensitive(self):
        """Strong positive stock-oil correlation reads as growth sensitive."""
        text = _interpret_cross_asset_correlation(
            stock_bond_corr=0.0, stock_gold_corr=0.0, stock_oil_corr=0.7
        )
        assert "GROWTH" in text

    def test_supply_shock(self):
        """Negative stock-oil correlation reads as a supply shock."""
        text = _interpret_cross_asset_correlation(
            stock_bond_corr=0.0, stock_gold_corr=0.0, stock_oil_corr=-0.5
        )
        assert "SUPPLY SHOCK" in text
|
|
|
|
|
|
class TestCorrelationSignalFormatting:
    """Tests for correlation signal formatting."""

    def test_positive_correlation_format(self):
        """Positive values carry an explicit plus sign and a title-cased label."""
        formatted = _format_correlation_signal(0.75)
        assert "+0.750" in formatted
        assert "Strong Positive" in formatted

    def test_negative_correlation_format(self):
        """Negative values keep their minus sign and a title-cased label."""
        formatted = _format_correlation_signal(-0.65)
        assert "-0.650" in formatted
        assert "Strong Negative" in formatted

    def test_negligible_correlation_format(self):
        """Values near zero are labelled negligible."""
        formatted = _format_correlation_signal(0.05)
        assert "Negligible" in formatted
|
|
|
|
|
|
class TestEnumValues:
    """Tests for enum value consistency."""

    def test_correlation_strength_values(self):
        """Spot-check the string values at both extremes and the midpoint."""
        assert CorrelationStrength.VERY_STRONG_POSITIVE.value == "very_strong_positive"
        assert CorrelationStrength.NEGLIGIBLE.value == "negligible"
        assert CorrelationStrength.VERY_STRONG_NEGATIVE.value == "very_strong_negative"

    def test_sector_phase_values(self):
        """Spot-check the first and last cycle phase values."""
        assert SectorPhase.EARLY_CYCLE.value == "early_cycle"
        assert SectorPhase.RECESSION.value == "recession"

    def test_sector_leadership_values(self):
        """Spot-check the leading and lagging values."""
        assert SectorLeadership.LEADING.value == "leading"
        assert SectorLeadership.LAGGING.value == "lagging"
|
|
|
|
|
|
class TestEdgeCases:
    """Tests for edge cases and error handling."""

    def test_nan_handling_in_correlation(self):
        """Correlation of inputs with NaNs removed is itself not NaN."""
        series1 = pd.Series([1.0, np.nan, 3.0, 4.0, 5.0])
        series2 = pd.Series([2.0, 4.0, np.nan, 8.0, 10.0])
        # NaNs are dropped here, by the caller, before the helper runs; each
        # series is cleaned independently and both end up with four values.
        corr = _calculate_correlation(series1.dropna(), series2.dropna())
        assert not np.isnan(corr)

    def test_inf_handling_in_relative_strength(self):
        """Relative strength stays finite for ordinary return data."""
        returns = pd.Series([0.01, 0.02, 0.01, 0.0, -0.01] * 20)
        # Zero benchmark returns are harmless: the denominator is the
        # cumulative product of (1 + return), which stays positive here.
        benchmark = pd.Series([0.01, 0.0, 0.01, 0.02, 0.01] * 20)
        rs = _calculate_relative_strength(returns, benchmark)
        # "len(rs) >= 0" could never fail; check the actual shape and values.
        assert len(rs) == len(returns)
        assert np.isfinite(rs).all()

    def test_single_value_series(self):
        """Single-element inputs yield 0.0."""
        series1 = pd.Series([5])
        series2 = pd.Series([10])
        corr = _calculate_correlation(series1, series2)
        assert corr == 0.0  # Insufficient data

    def test_all_same_values(self):
        """A constant first series yields 0.0."""
        series1 = pd.Series([5, 5, 5, 5, 5])
        series2 = pd.Series([1, 2, 3, 4, 5])
        corr = _calculate_correlation(series1, series2)
        assert corr == 0.0  # Zero std
|
|
|
|
|
|
class TestIntegration:
    """Integration tests for combined functionality."""

    def test_full_correlation_workflow(self, correlated_returns):
        """Correlation, classification, rolling window and breakdown chain together."""
        series1, series2 = correlated_returns

        # Calculate correlation
        corr = _calculate_correlation(series1, series2)
        assert corr > 0.5

        # Classify strength
        strength = _classify_correlation(corr)
        assert strength in [CorrelationStrength.STRONG_POSITIVE, CorrelationStrength.VERY_STRONG_POSITIVE]

        # Calculate rolling
        rolling = _calculate_rolling_correlation(series1, series2)
        assert len(rolling) > 0

        # Check for breakdown
        breakdown = _detect_correlation_breakdown(rolling)
        assert "detected" in breakdown

    def test_sector_rotation_workflow(self, outperforming_sector_returns, benchmark_returns):
        """Relative strength, leadership, cycle phase and recommendations chain together."""
        # Calculate relative strength
        rs = _calculate_relative_strength(outperforming_sector_returns, benchmark_returns)

        # Assert the precondition instead of guarding with "if", so the
        # leadership check cannot be skipped silently.
        assert len(rs) >= 20

        # Classify leadership
        leadership = _classify_sector_leadership(rs)
        assert leadership in list(SectorLeadership)

        # Get cycle phase
        indicators = {'pmi': 55, 'leading_index': 0.3, 'yield_curve_slope': 0.2}
        phase = _identify_cycle_phase(indicators)
        assert phase in list(SectorPhase)

        # Get recommendations
        recs = _get_cycle_sector_recommendations(phase)
        assert len(recs["overweight"]) > 0
        assert len(recs["underweight"]) > 0

    def test_cross_asset_regime_workflow(self):
        """Regime commentary reflects the supplied correlations."""
        # Normal market regime
        interpretation = _interpret_cross_asset_correlation(
            stock_bond_corr=-0.4,
            stock_gold_corr=-0.2,
            stock_oil_corr=0.3
        )
        assert "NORMAL" in interpretation

        # Risk-off regime
        interpretation = _interpret_cross_asset_correlation(
            stock_bond_corr=0.5,
            stock_gold_corr=-0.4,
            stock_oil_corr=-0.3
        )
        assert "RISK-OFF" in interpretation
        assert "HEDGING" in interpretation
|
|
|
|
|
|
class TestFactoryExpectations:
    """Tests for factory function expectations."""

    def test_expected_tools_list(self):
        """The expected correlation analyst tool names are non-empty strings."""
        # NOTE(review): this only inspects literals defined here; it does not
        # touch the real agent's tool list.
        expected_tools = [
            "get_cross_asset_correlation_analysis",
            "get_sector_rotation_analysis",
            "get_correlation_matrix",
            "get_rolling_correlation_trend",
        ]
        # Just verify the expected tool names are defined
        for name in expected_tools:
            assert len(name) > 0

    def test_factory_expected_signature(self):
        """A factory taking an LLM returns a node that yields a correlation report."""
        # NOTE(review): the factory is a local stand-in, so this documents the
        # expected calling pattern rather than exercising production code.
        def mock_factory(llm):
            def node(state):
                return {"messages": [], "correlation_report": ""}
            return node

        # Should work without error
        fake_llm = Mock()
        node = mock_factory(fake_llm)
        output = node({})
        assert "correlation_report" in output
|