198 lines
6.4 KiB
Python
198 lines
6.4 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
class TestCalculateRsiScore:
|
|
def test_rsi_oversold_bullish_score(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_rsi_score,
|
|
)
|
|
|
|
score = calculate_rsi_score(25.0)
|
|
assert 0.8 <= score <= 1.0
|
|
|
|
score = calculate_rsi_score(30.0)
|
|
assert 0.7 <= score <= 0.9
|
|
|
|
def test_rsi_neutral_score(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_rsi_score,
|
|
)
|
|
|
|
score = calculate_rsi_score(50.0)
|
|
assert 0.5 <= score <= 0.7
|
|
|
|
score = calculate_rsi_score(40.0)
|
|
assert 0.5 <= score <= 0.7
|
|
|
|
def test_rsi_overbought_warning_score(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_rsi_score,
|
|
)
|
|
|
|
score = calculate_rsi_score(70.0)
|
|
assert 0.3 <= score <= 0.5
|
|
|
|
score = calculate_rsi_score(75.0)
|
|
assert 0.2 <= score <= 0.5
|
|
|
|
def test_rsi_extreme_overbought_score(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_rsi_score,
|
|
)
|
|
|
|
score = calculate_rsi_score(85.0)
|
|
assert 0.0 <= score <= 0.3
|
|
|
|
score = calculate_rsi_score(95.0)
|
|
assert 0.0 <= score <= 0.2
|
|
|
|
|
|
class TestCalculateMacdScore:
|
|
def test_macd_bullish_crossover_high_score(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_macd_score,
|
|
)
|
|
|
|
score = calculate_macd_score(macd=1.5, signal=1.0, histogram=0.5)
|
|
assert 0.7 <= score <= 1.0
|
|
|
|
def test_macd_bearish_crossover_low_score(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_macd_score,
|
|
)
|
|
|
|
score = calculate_macd_score(macd=-1.5, signal=-1.0, histogram=-0.5)
|
|
assert 0.0 <= score <= 0.4
|
|
|
|
def test_macd_neutral_score(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_macd_score,
|
|
)
|
|
|
|
score = calculate_macd_score(macd=0.1, signal=0.1, histogram=0.0)
|
|
assert 0.4 <= score <= 0.6
|
|
|
|
def test_macd_expanding_histogram_bullish(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_macd_score,
|
|
)
|
|
|
|
score = calculate_macd_score(macd=2.0, signal=1.5, histogram=0.8)
|
|
assert 0.6 <= score <= 1.0
|
|
|
|
|
|
class TestCalculateSmaScore:
|
|
def test_price_above_both_smas_high_score(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_sma_score,
|
|
)
|
|
|
|
score = calculate_sma_score(price=150.0, sma50=140.0, sma200=130.0)
|
|
assert 0.7 <= score <= 1.0
|
|
|
|
def test_price_below_both_smas_low_score(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_sma_score,
|
|
)
|
|
|
|
score = calculate_sma_score(price=120.0, sma50=140.0, sma200=150.0)
|
|
assert 0.0 <= score <= 0.4
|
|
|
|
def test_price_between_smas_moderate_score(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_sma_score,
|
|
)
|
|
|
|
score = calculate_sma_score(price=145.0, sma50=150.0, sma200=130.0)
|
|
assert 0.4 <= score <= 0.7
|
|
|
|
def test_golden_alignment_bonus(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_sma_score,
|
|
)
|
|
|
|
score = calculate_sma_score(price=160.0, sma50=150.0, sma200=140.0)
|
|
assert score >= 0.8
|
|
|
|
|
|
class TestCalculateEmaDirection:
|
|
def test_ema_upward_trend(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_ema_direction,
|
|
)
|
|
|
|
ema_values = [100.0, 102.0, 104.0, 106.0, 108.0]
|
|
direction = calculate_ema_direction(ema_values)
|
|
assert direction == "up"
|
|
|
|
def test_ema_downward_trend(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_ema_direction,
|
|
)
|
|
|
|
ema_values = [108.0, 106.0, 104.0, 102.0, 100.0]
|
|
direction = calculate_ema_direction(ema_values)
|
|
assert direction == "down"
|
|
|
|
def test_ema_flat_trend(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_ema_direction,
|
|
)
|
|
|
|
ema_values = [100.0, 100.1, 99.9, 100.0, 100.1]
|
|
direction = calculate_ema_direction(ema_values)
|
|
assert direction == "flat"
|
|
|
|
def test_ema_empty_list_returns_flat(self):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_ema_direction,
|
|
)
|
|
|
|
ema_values = []
|
|
direction = calculate_ema_direction(ema_values)
|
|
assert direction == "flat"
|
|
|
|
|
|
class TestCalculateMomentumScore:
|
|
@patch("tradingagents.agents.discovery.indicators.momentum._get_stock_stats_bulk")
|
|
def test_calculate_momentum_score_returns_dict(self, mock_get_stats):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_momentum_score,
|
|
)
|
|
|
|
mock_get_stats.side_effect = lambda symbol, indicator, date: {
|
|
"2024-01-15": "50.0" if indicator == "rsi" else "1.0",
|
|
"2024-01-14": "49.0" if indicator == "rsi" else "0.9",
|
|
"2024-01-13": "48.0" if indicator == "rsi" else "0.8",
|
|
"2024-01-12": "47.0" if indicator == "rsi" else "0.7",
|
|
"2024-01-11": "46.0" if indicator == "rsi" else "0.6",
|
|
}
|
|
|
|
result = calculate_momentum_score("AAPL", "2024-01-15")
|
|
|
|
assert isinstance(result, dict)
|
|
assert "rsi" in result
|
|
assert "macd" in result
|
|
assert "macd_signal" in result
|
|
assert "macd_histogram" in result
|
|
assert "price_vs_sma50" in result
|
|
assert "price_vs_sma200" in result
|
|
assert "ema10_direction" in result
|
|
assert "momentum_score" in result
|
|
assert 0.0 <= result["momentum_score"] <= 1.0
|
|
|
|
@patch("tradingagents.agents.discovery.indicators.momentum._get_stock_stats_bulk")
|
|
def test_calculate_momentum_score_handles_missing_data(self, mock_get_stats):
|
|
from tradingagents.agents.discovery.indicators.momentum import (
|
|
calculate_momentum_score,
|
|
)
|
|
|
|
mock_get_stats.return_value = {}
|
|
|
|
result = calculate_momentum_score("INVALID", "2024-01-15")
|
|
|
|
assert isinstance(result, dict)
|
|
assert result["momentum_score"] == 0.5
|