TradingAgents/tests/discovery/test_support_resistance.py

225 lines
6.7 KiB
Python

from unittest.mock import MagicMock, patch
import pytest
class TestFindSupportLevels:
def test_find_support_from_lows(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
find_support_levels,
)
lows = (
[100.0, 98.0, 97.0, 99.0, 96.0, 95.0, 97.0, 98.0, 94.0, 93.0]
+ [95.0, 96.0, 94.0, 93.0, 92.0, 91.0, 90.0, 89.0, 88.0, 87.0]
+ [88.0, 89.0, 87.0, 86.0, 85.0] * 6
)
support_20d, support_50d = find_support_levels(
lows, lookback_20=20, lookback_50=50
)
assert support_20d <= 100.0
assert support_50d <= 100.0
def test_find_support_empty_list(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
find_support_levels,
)
lows = []
support_20d, support_50d = find_support_levels(
lows, lookback_20=20, lookback_50=50
)
assert support_20d == 0.0
assert support_50d == 0.0
class TestFindResistanceLevels:
def test_find_resistance_from_highs(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
find_resistance_levels,
)
highs = (
[100.0, 102.0, 105.0, 103.0, 108.0, 106.0, 110.0, 107.0, 112.0, 109.0]
+ [111.0, 113.0, 115.0, 114.0, 117.0, 116.0, 120.0, 118.0, 122.0, 119.0]
+ [120.0, 121.0, 123.0, 124.0, 125.0] * 6
)
resistance_20d, resistance_50d = find_resistance_levels(
highs, lookback_20=20, lookback_50=50
)
assert resistance_20d >= 100.0
assert resistance_50d >= 100.0
def test_find_resistance_empty_list(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
find_resistance_levels,
)
highs = []
resistance_20d, resistance_50d = find_resistance_levels(
highs, lookback_20=20, lookback_50=50
)
assert resistance_20d == 0.0
assert resistance_50d == 0.0
class TestDetectSwingPoints:
def test_detect_swing_highs_and_lows(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
detect_swing_points,
)
prices = [
100,
102,
105,
103,
98,
95,
97,
100,
103,
107,
105,
102,
99,
96,
94,
97,
101,
]
swing_lows, swing_highs = detect_swing_points(prices, n_bars=3)
assert len(swing_lows) >= 0
assert len(swing_highs) >= 0
def test_detect_swing_points_short_list(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
detect_swing_points,
)
prices = [100, 105, 102]
swing_lows, swing_highs = detect_swing_points(prices, n_bars=3)
assert isinstance(swing_lows, list)
assert isinstance(swing_highs, list)
def test_detect_swing_points_empty_list(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
detect_swing_points,
)
prices = []
swing_lows, swing_highs = detect_swing_points(prices, n_bars=5)
assert swing_lows == []
assert swing_highs == []
class TestGetNearestLevels:
def test_get_nearest_support_and_resistance(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
get_nearest_levels,
)
price = 150.0
supports = [140.0, 145.0, 130.0, 120.0]
resistances = [155.0, 160.0, 170.0, 180.0]
nearest_support, nearest_resistance = get_nearest_levels(
price, supports, resistances
)
assert nearest_support == 145.0
assert nearest_resistance == 155.0
def test_get_nearest_levels_no_support_below(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
get_nearest_levels,
)
price = 100.0
supports = [110.0, 120.0, 130.0]
resistances = [150.0, 160.0]
nearest_support, nearest_resistance = get_nearest_levels(
price, supports, resistances
)
assert nearest_support == 0.0
assert nearest_resistance == 150.0
def test_get_nearest_levels_empty_lists(self):
from tradingagents.agents.discovery.indicators.support_resistance import (
get_nearest_levels,
)
price = 150.0
supports = []
resistances = []
nearest_support, nearest_resistance = get_nearest_levels(
price, supports, resistances
)
assert nearest_support == 0.0
assert nearest_resistance == 0.0
class TestCalculateSupportResistanceMetrics:
@patch(
"tradingagents.agents.discovery.indicators.support_resistance._get_ohlc_data"
)
@patch("tradingagents.agents.discovery.indicators.support_resistance._get_atr")
def test_calculate_sr_metrics_returns_dict(self, mock_atr, mock_ohlc):
from tradingagents.agents.discovery.indicators.support_resistance import (
calculate_support_resistance_metrics,
)
mock_ohlc.return_value = {
"highs": [105.0 + i * 0.5 for i in range(60)],
"lows": [95.0 + i * 0.3 for i in range(60)],
"closes": [100.0 + i * 0.4 for i in range(60)],
"current_price": 125.0,
}
mock_atr.return_value = 2.5
result = calculate_support_resistance_metrics("AAPL", "2024-01-15")
assert isinstance(result, dict)
assert "support_level" in result
assert "resistance_level" in result
assert "atr" in result
assert "suggested_stop" in result
assert "reward_target" in result
assert "risk_reward_ratio" in result
assert "risk_reward_score" in result
@patch(
"tradingagents.agents.discovery.indicators.support_resistance._get_ohlc_data"
)
@patch("tradingagents.agents.discovery.indicators.support_resistance._get_atr")
def test_calculate_sr_metrics_handles_missing_data(self, mock_atr, mock_ohlc):
from tradingagents.agents.discovery.indicators.support_resistance import (
calculate_support_resistance_metrics,
)
mock_ohlc.return_value = {
"highs": [],
"lows": [],
"closes": [],
"current_price": None,
}
mock_atr.return_value = None
result = calculate_support_resistance_metrics("INVALID", "2024-01-15")
assert isinstance(result, dict)
assert result["risk_reward_score"] == 0.5