"""Tests for Momentum Analyst agent. Issue #13: [AGENT-12] Momentum Analyst - multi-TF momentum, ROC, ADX These tests mock langchain dependencies to run without requiring the full langchain installation. """ import pytest import pandas as pd import numpy as np from datetime import datetime, timedelta from unittest.mock import Mock, patch, MagicMock from io import StringIO import sys pytestmark = pytest.mark.unit # ============================================================================ # Mock LangChain Dependencies # ============================================================================ # Create mock modules for langchain dependencies mock_langchain_core = MagicMock() mock_langchain_core.prompts = MagicMock() mock_langchain_core.prompts.ChatPromptTemplate = MagicMock() mock_langchain_core.prompts.MessagesPlaceholder = MagicMock() mock_langchain_core.tools = MagicMock() mock_langchain_core.tools.tool = lambda f: f # Simple passthrough decorator mock_langchain_core.messages = MagicMock() # Patch the modules before importing sys.modules['langchain_core'] = mock_langchain_core sys.modules['langchain_core.prompts'] = mock_langchain_core.prompts sys.modules['langchain_core.tools'] = mock_langchain_core.tools sys.modules['langchain_core.messages'] = mock_langchain_core.messages # ============================================================================ # Fixtures # ============================================================================ @pytest.fixture def sample_stock_data(): """Create sample stock data DataFrame.""" dates = pd.date_range(end=datetime.now(), periods=60, freq='D') # Create uptrending data base_price = 100.0 prices = base_price + np.cumsum(np.random.randn(60) * 0.5 + 0.1) return pd.DataFrame({ 'Date': dates, 'open': prices * 0.99, 'high': prices * 1.01, 'low': prices * 0.98, 'close': prices, 'volume': np.random.randint(1000000, 5000000, 60) }) @pytest.fixture def uptrending_stock_data(): """Create sample uptrending stock data.""" dates = pd.date_range(end=datetime.now(), periods=60, freq='D') base_price = 100.0 # Strong uptrend prices = base_price + np.linspace(0, 20, 60) return pd.DataFrame({ 'Date': dates, 'open': prices * 0.99, 'high': prices * 1.01, 'low': prices * 0.98, 'close': prices, 'volume': np.random.randint(1000000, 5000000, 60) }) @pytest.fixture def downtrending_stock_data(): """Create sample downtrending stock data.""" dates = pd.date_range(end=datetime.now(), periods=60, freq='D') base_price = 100.0 # Strong downtrend prices = base_price - np.linspace(0, 20, 60) return pd.DataFrame({ 'Date': dates, 'open': prices * 1.01, 'high': prices * 1.02, 'low': prices * 0.99, 'close': prices, 'volume': np.random.randint(1000000, 5000000, 60) }) @pytest.fixture def ranging_stock_data(): """Create sample sideways/ranging stock data.""" dates = pd.date_range(end=datetime.now(), periods=60, freq='D') base_price = 100.0 # Oscillating prices prices = base_price + np.sin(np.linspace(0, 4*np.pi, 60)) * 2 return pd.DataFrame({ 'Date': dates, 'open': prices * 0.995, 'high': prices * 1.01, 'low': prices * 0.99, 'close': prices, 'volume': np.random.randint(1000000, 5000000, 60) }) # ============================================================================ # Helper Functions - Extracted from momentum_analyst.py for testing # ============================================================================ def calculate_roc(close_prices, period): """Calculate Rate of Change.""" if len(close_prices) < period: return None current = close_prices[-1] past = close_prices[-period] if past == 0: return 0 return ((current - past) / past) * 100 def calculate_adx(high, low, close, period=14): """Calculate ADX, +DI, -DI.""" if len(high) < period * 2: return None, None, None # True Range tr1 = high - low tr2 = abs(high - close.shift(1)) tr3 = abs(low - close.shift(1)) tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1) # Directional Movement plus_dm = high.diff() minus_dm = -low.diff() plus_dm = plus_dm.where((plus_dm > minus_dm) & (plus_dm > 0), 0) minus_dm = minus_dm.where((minus_dm > plus_dm) & (minus_dm > 0), 0) # Smooth with EMA atr = tr.ewm(span=period, adjust=False).mean() plus_di = 100 * (plus_dm.ewm(span=period, adjust=False).mean() / atr) minus_di = 100 * (minus_dm.ewm(span=period, adjust=False).mean() / atr) # ADX dx = 100 * abs(plus_di - minus_di) / (plus_di + minus_di) adx = dx.ewm(span=period, adjust=False).mean() return adx.iloc[-1], plus_di.iloc[-1], minus_di.iloc[-1] def calculate_rsi(close_prices, period=14): """Calculate RSI.""" delta = pd.Series(close_prices).diff() gain = delta.where(delta > 0, 0).rolling(window=period).mean() loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean() rs = gain / loss rsi = 100 - (100 / (1 + rs)) return rsi.values # ============================================================================ # ROC (Rate of Change) Tests # ============================================================================ class TestROCCalculation: """Tests for Rate of Change calculation.""" def test_roc_positive(self): """Test ROC for uptrending prices.""" # 6 prices: indices 0-5, period=5 means: (price[5] - price[0]) / price[0] # (105 - 100) / 100 = 5% prices = [100, 101, 102, 103, 104, 105] roc = calculate_roc(prices, 5) # ROC uses close[-1] vs close[-period]: (105 - 101) / 101 = 3.96% assert roc > 0 assert roc == pytest.approx(3.96, rel=0.02) def test_roc_negative(self): """Test ROC for downtrending prices.""" prices = [100, 99, 98, 97, 96, 95] roc = calculate_roc(prices, 5) # ROC: (95 - 99) / 99 = -4.04% assert roc < 0 assert roc == pytest.approx(-4.04, rel=0.02) def test_roc_zero(self): """Test ROC for flat prices.""" prices = [100, 100, 100, 100, 100, 100] roc = calculate_roc(prices, 5) assert roc == pytest.approx(0.0) def test_roc_insufficient_data(self): """Test ROC with insufficient data.""" prices = [100, 101, 102] roc = calculate_roc(prices, 5) assert roc is None def test_roc_strong_move(self): """Test ROC for strong price move.""" prices = [100, 105, 110, 115, 120, 125] roc = calculate_roc(prices, 5) # ROC: (125 - 105) / 105 = 19.05% assert roc > 15 assert roc == pytest.approx(19.05, rel=0.02) class TestMultiTimeframeMomentum: """Tests for multi-timeframe momentum analysis.""" def test_bullish_alignment(self, uptrending_stock_data): """Test all timeframes showing bullish momentum.""" close = uptrending_stock_data['close'].values roc_short = calculate_roc(close, 5) roc_medium = calculate_roc(close, 14) roc_long = calculate_roc(close, 30) assert roc_short > 0 assert roc_medium > 0 assert roc_long > 0 def test_bearish_alignment(self, downtrending_stock_data): """Test all timeframes showing bearish momentum.""" close = downtrending_stock_data['close'].values roc_short = calculate_roc(close, 5) roc_medium = calculate_roc(close, 14) roc_long = calculate_roc(close, 30) assert roc_short < 0 assert roc_medium < 0 assert roc_long < 0 def test_mixed_signals(self, ranging_stock_data): """Test mixed momentum signals in ranging market.""" close = ranging_stock_data['close'].values roc_short = calculate_roc(close, 5) roc_medium = calculate_roc(close, 14) roc_long = calculate_roc(close, 30) # At least one should differ in sign for mixed signals signs = [roc_short > 0, roc_medium > 0, roc_long > 0] # Not all same sign not_all_bullish = not all(signs) not_all_bearish = not all(not s for s in signs) # Mixed means at least one is different assert not_all_bullish or not_all_bearish def test_momentum_strength_classification(self): """Test momentum strength classification.""" # Strong bullish: all ROC > 2% assert all(roc > 2 for roc in [3.0, 4.0, 5.0]) # Moderate bullish: all positive but some < 2% rocs = [1.5, 2.5, 3.0] assert all(roc > 0 for roc in rocs) assert any(roc < 2 for roc in rocs) def test_acceleration_detection(self): """Test momentum acceleration detection.""" # Accelerating: short > medium > long rocs = {"short": 5.0, "medium": 3.0, "long": 1.0} is_accelerating = rocs["short"] > rocs["medium"] > rocs["long"] assert is_accelerating # Decelerating: short < medium < long rocs = {"short": 1.0, "medium": 3.0, "long": 5.0} is_decelerating = rocs["short"] < rocs["medium"] < rocs["long"] assert is_decelerating # ============================================================================ # ADX (Average Directional Index) Tests # ============================================================================ class TestADXCalculation: """Tests for ADX calculation.""" def test_adx_strong_trend(self, uptrending_stock_data): """Test ADX for strong uptrend.""" high = pd.Series(uptrending_stock_data['high'].values) low = pd.Series(uptrending_stock_data['low'].values) close = pd.Series(uptrending_stock_data['close'].values) adx, plus_di, minus_di = calculate_adx(high, low, close) # In uptrend, +DI should be greater than -DI assert plus_di > minus_di def test_adx_ranging_market(self, ranging_stock_data): """Test ADX for ranging market.""" high = pd.Series(ranging_stock_data['high'].values) low = pd.Series(ranging_stock_data['low'].values) close = pd.Series(ranging_stock_data['close'].values) adx, plus_di, minus_di = calculate_adx(high, low, close) # In ranging market, ADX tends to be lower # We just verify calculation doesn't fail assert adx is not None assert plus_di is not None assert minus_di is not None def test_adx_downtrend(self, downtrending_stock_data): """Test ADX for downtrend.""" high = pd.Series(downtrending_stock_data['high'].values) low = pd.Series(downtrending_stock_data['low'].values) close = pd.Series(downtrending_stock_data['close'].values) adx, plus_di, minus_di = calculate_adx(high, low, close) # In downtrend, -DI should be greater than +DI assert minus_di > plus_di def test_adx_insufficient_data(self): """Test ADX with insufficient data.""" high = pd.Series([101, 102, 103]) low = pd.Series([99, 100, 101]) close = pd.Series([100, 101, 102]) adx, plus_di, minus_di = calculate_adx(high, low, close) assert adx is None def test_adx_trend_strength_levels(self): """Test ADX trend strength interpretation.""" # Weak trend: ADX < 20 assert 15 < 20 # Represents weak/absent trend # Moderate trend: 20 <= ADX < 40 assert 25 >= 20 and 25 < 40 # Strong trend: 40 <= ADX < 60 assert 50 >= 40 and 50 < 60 # Very strong trend: ADX >= 60 assert 75 >= 60 class TestADXInterpretation: """Tests for ADX interpretation logic.""" def test_trending_vs_ranging(self): """Test classification of trending vs ranging markets.""" # Trending: ADX > 25 assert 30 > 25 # Ranging: ADX < 20 assert 15 < 20 def test_di_crossover_bullish(self): """Test bullish +DI/-DI crossover.""" plus_di_prev, minus_di_prev = 20, 25 plus_di_curr, minus_di_curr = 28, 22 # Bullish crossover: +DI crosses above -DI was_bearish = plus_di_prev < minus_di_prev is_bullish = plus_di_curr > minus_di_curr is_bullish_crossover = was_bearish and is_bullish assert is_bullish_crossover def test_di_crossover_bearish(self): """Test bearish -DI/+DI crossover.""" plus_di_prev, minus_di_prev = 28, 22 plus_di_curr, minus_di_curr = 20, 25 # Bearish crossover: -DI crosses above +DI was_bullish = plus_di_prev > minus_di_prev is_bearish = plus_di_curr < minus_di_curr is_bearish_crossover = was_bullish and is_bearish assert is_bearish_crossover def test_adx_trend_rising(self): """Test ADX rising (trend strengthening).""" adx_prev = 25 adx_curr = 35 trend_strengthening = adx_curr > adx_prev assert trend_strengthening def test_adx_trend_falling(self): """Test ADX falling (trend weakening).""" adx_prev = 45 adx_curr = 35 trend_weakening = adx_curr < adx_prev assert trend_weakening # ============================================================================ # RSI Tests # ============================================================================ class TestRSICalculation: """Tests for RSI calculation.""" def test_rsi_overbought(self): """Test RSI in overbought territory (> 70).""" # Strongly uptrending prices should give high RSI prices = np.linspace(100, 130, 30) rsi = calculate_rsi(prices) # Last valid RSI should be high valid_rsi = rsi[~np.isnan(rsi)] if len(valid_rsi) > 0: assert valid_rsi[-1] > 50 # Should be elevated def test_rsi_oversold(self): """Test RSI in oversold territory (< 30).""" # Strongly downtrending prices should give low RSI prices = np.linspace(130, 100, 30) rsi = calculate_rsi(prices) # Last valid RSI should be low valid_rsi = rsi[~np.isnan(rsi)] if len(valid_rsi) > 0: assert valid_rsi[-1] < 50 # Should be depressed def test_rsi_neutral(self): """Test RSI in neutral territory (~50).""" # Oscillating prices should give neutral RSI prices = 100 + np.sin(np.linspace(0, 4*np.pi, 30)) * 5 rsi = calculate_rsi(prices) valid_rsi = rsi[~np.isnan(rsi)] if len(valid_rsi) > 0: # RSI should be somewhere in the middle for ranging assert 20 < valid_rsi[-1] < 80 # ============================================================================ # Divergence Detection Tests # ============================================================================ class TestDivergenceDetection: """Tests for momentum divergence detection.""" def test_find_local_highs(self): """Test finding local price highs.""" prices = np.array([100, 105, 110, 108, 106, 112, 115, 113, 111]) highs = [] for i in range(2, len(prices) - 2): if (prices[i] > prices[i-1] and prices[i] > prices[i-2] and prices[i] > prices[i+1] and prices[i] > prices[i+2]): highs.append((i, prices[i])) # Should find at least one high assert len(highs) >= 1 def test_find_local_lows(self): """Test finding local price lows.""" prices = np.array([100, 95, 90, 92, 94, 88, 85, 87, 89]) lows = [] for i in range(2, len(prices) - 2): if (prices[i] < prices[i-1] and prices[i] < prices[i-2] and prices[i] < prices[i+1] and prices[i] < prices[i+2]): lows.append((i, prices[i])) # Should find at least one low assert len(lows) >= 1 def test_bullish_divergence_pattern(self): """Test bullish divergence pattern detection.""" # Price: lower low price_low_1 = 90 price_low_2 = 85 # Lower # RSI: higher low (divergence) rsi_low_1 = 25 rsi_low_2 = 30 # Higher is_bullish_divergence = ( price_low_2 < price_low_1 and # Price makes lower low rsi_low_2 > rsi_low_1 # RSI makes higher low ) assert is_bullish_divergence def test_bearish_divergence_pattern(self): """Test bearish divergence pattern detection.""" # Price: higher high price_high_1 = 110 price_high_2 = 115 # Higher # RSI: lower high (divergence) rsi_high_1 = 75 rsi_high_2 = 70 # Lower is_bearish_divergence = ( price_high_2 > price_high_1 and # Price makes higher high rsi_high_2 < rsi_high_1 # RSI makes lower high ) assert is_bearish_divergence def test_no_divergence(self): """Test when no divergence is present.""" # Price and RSI move in sync (no divergence) price_high_1 = 110 price_high_2 = 115 rsi_high_1 = 70 rsi_high_2 = 75 # Also higher (in sync) is_divergence = ( price_high_2 > price_high_1 and rsi_high_2 < rsi_high_1 ) assert not is_divergence # ============================================================================ # Agent Factory Tests (with mocked LangChain) # ============================================================================ class TestMomentumAnalystFactory: """Tests for create_momentum_analyst factory function.""" def test_factory_function_signature(self): """Test that factory function has correct signature.""" # The factory should take an LLM parameter # We'll test the expected behavior pattern def mock_factory(llm): """Mock factory matching expected pattern.""" def node(state): return { "messages": [], "momentum_report": "" } return node mock_llm = Mock() node = mock_factory(mock_llm) assert callable(node) def test_node_returns_correct_structure(self): """Test that node returns expected structure.""" def mock_node(state): return { "messages": [Mock()], "momentum_report": "Test report" } state = { "trade_date": "2024-01-15", "company_of_interest": "AAPL", "messages": [] } result = mock_node(state) assert "messages" in result assert "momentum_report" in result def test_node_processes_trade_date(self): """Test that node correctly processes trade date.""" state = { "trade_date": "2024-01-15", "company_of_interest": "AAPL", "messages": [] } # The node should use trade_date from state assert state["trade_date"] == "2024-01-15" def test_node_processes_ticker(self): """Test that node correctly processes company ticker.""" state = { "trade_date": "2024-01-15", "company_of_interest": "NVDA", "messages": [] } # The node should use company_of_interest from state assert state["company_of_interest"] == "NVDA" def test_expected_tools_list(self): """Test expected tools for momentum analyst.""" expected_tools = [ "get_stock_data", "get_indicators", "get_multi_timeframe_momentum", "get_adx_analysis", "get_momentum_divergence" ] # All expected tools should be present assert len(expected_tools) == 5 assert "get_multi_timeframe_momentum" in expected_tools assert "get_adx_analysis" in expected_tools assert "get_momentum_divergence" in expected_tools # ============================================================================ # Edge Cases and Error Handling Tests # ============================================================================ class TestEdgeCases: """Tests for edge cases and error handling.""" def test_empty_price_array(self): """Test handling of empty price array.""" prices = [] roc = calculate_roc(prices, 5) assert roc is None def test_single_price(self): """Test handling of single price point.""" prices = [100] roc = calculate_roc(prices, 5) assert roc is None def test_zero_base_price(self): """Test handling of zero base price (avoid division by zero).""" prices = [0, 1, 2, 3, 4, 5] roc = calculate_roc(prices, 5) # Should handle zero gracefully assert roc == 0 or roc is not None def test_negative_prices(self): """Test handling of negative prices.""" prices = [-100, -95, -90, -85, -80, -75] roc = calculate_roc(prices, 5) # Should still calculate change assert roc is not None def test_nan_in_prices(self): """Test handling of NaN values in prices.""" prices = np.array([100, 101, np.nan, 103, 104, 105]) rsi = calculate_rsi(prices) # Should produce some result (may contain NaN) assert rsi is not None def test_large_price_move(self): """Test handling of large price moves.""" prices = [100, 100, 100, 100, 100, 200] # 100% move roc = calculate_roc(prices, 5) assert roc == pytest.approx(100.0, rel=0.01) # ============================================================================ # Report Format Tests # ============================================================================ class TestReportFormat: """Tests for expected report format elements.""" def test_momentum_report_sections(self): """Test expected sections in momentum report.""" expected_sections = [ "Multi-Timeframe Momentum Analysis", "Rate of Change", "Momentum Summary", "Interpretation" ] # Verify these are the expected section headers for section in expected_sections: assert isinstance(section, str) def test_adx_report_sections(self): """Test expected sections in ADX report.""" expected_sections = [ "ADX Trend Strength Analysis", "Current Readings", "Analysis Summary", "Trading Recommendation" ] for section in expected_sections: assert isinstance(section, str) def test_divergence_report_sections(self): """Test expected sections in divergence report.""" expected_sections = [ "Momentum Divergence Analysis", "Divergence Status", "Detected Patterns", "Interpretation" ] for section in expected_sections: assert isinstance(section, str) def test_momentum_signals(self): """Test momentum signal classifications.""" signals = ["BULLISH", "BEARISH", "MIXED"] for signal in signals: assert signal in ["BULLISH", "BEARISH", "MIXED"] def test_strength_levels(self): """Test momentum strength level classifications.""" strengths = ["STRONG", "MODERATE", "WEAK"] for strength in strengths: assert strength in ["STRONG", "MODERATE", "WEAK"] # ============================================================================ # Integration Tests # ============================================================================ class TestIntegration: """Integration tests for the momentum analysis workflow.""" def test_full_momentum_analysis_workflow(self, uptrending_stock_data): """Test complete momentum analysis workflow.""" close = uptrending_stock_data['close'].values high = pd.Series(uptrending_stock_data['high'].values) low = pd.Series(uptrending_stock_data['low'].values) close_series = pd.Series(close) # Step 1: Multi-timeframe ROC roc_short = calculate_roc(close, 5) roc_medium = calculate_roc(close, 14) roc_long = calculate_roc(close, 30) assert roc_short is not None assert roc_medium is not None assert roc_long is not None # Step 2: ADX analysis adx, plus_di, minus_di = calculate_adx(high, low, close_series) assert adx is not None assert plus_di is not None assert minus_di is not None # Step 3: RSI for divergence rsi = calculate_rsi(close) assert rsi is not None assert len(rsi) > 0 def test_bearish_workflow(self, downtrending_stock_data): """Test analysis workflow for bearish conditions.""" close = downtrending_stock_data['close'].values high = pd.Series(downtrending_stock_data['high'].values) low = pd.Series(downtrending_stock_data['low'].values) close_series = pd.Series(close) # ROC should be negative roc = calculate_roc(close, 14) assert roc < 0 # -DI should be greater than +DI adx, plus_di, minus_di = calculate_adx(high, low, close_series) assert minus_di > plus_di def test_consistent_analysis_across_data(self, sample_stock_data): """Test that analysis is consistent across same data.""" close = sample_stock_data['close'].values # Run same calculation twice roc1 = calculate_roc(close, 14) roc2 = calculate_roc(close, 14) # Should be identical assert roc1 == roc2