779 lines
28 KiB
Python
779 lines
28 KiB
Python
"""Tests for Strategy Comparator.

Issue #34: [SIM-33] Strategy comparator - performance comparison, stats

Tests cover:
- RankingCriteria and ComparisonStatus enums
- StrategyMetrics dataclass
- PairwiseComparison and ComparisonResult dataclasses
- StrategyComparator comparison logic
- Statistical significance testing
- Ranking and recommendations
- Edge cases
"""
|
|
|
|
import pytest
|
|
from datetime import date
|
|
from decimal import Decimal
|
|
|
|
from tradingagents.simulation.strategy_comparator import (
|
|
RankingCriteria,
|
|
ComparisonStatus,
|
|
StrategyMetrics,
|
|
PairwiseComparison,
|
|
ComparisonResult,
|
|
StrategyComparator,
|
|
)
|
|
|
|
|
|
# ==============================================================================
|
|
# RankingCriteria Enum Tests
|
|
# ==============================================================================
|
|
|
|
|
|
class TestRankingCriteria:
    """Checks the string values and the member count of ``RankingCriteria``."""

    def test_total_return_value(self):
        """TOTAL_RETURN carries the value ``total_return``."""
        expected = "total_return"
        assert RankingCriteria.TOTAL_RETURN.value == expected

    def test_sharpe_ratio_value(self):
        """SHARPE_RATIO carries the value ``sharpe_ratio``."""
        expected = "sharpe_ratio"
        assert RankingCriteria.SHARPE_RATIO.value == expected

    def test_sortino_ratio_value(self):
        """SORTINO_RATIO carries the value ``sortino_ratio``."""
        expected = "sortino_ratio"
        assert RankingCriteria.SORTINO_RATIO.value == expected

    def test_max_drawdown_value(self):
        """MAX_DRAWDOWN carries the value ``max_drawdown``."""
        expected = "max_drawdown"
        assert RankingCriteria.MAX_DRAWDOWN.value == expected

    def test_win_rate_value(self):
        """WIN_RATE carries the value ``win_rate``."""
        expected = "win_rate"
        assert RankingCriteria.WIN_RATE.value == expected

    def test_profit_factor_value(self):
        """PROFIT_FACTOR carries the value ``profit_factor``."""
        expected = "profit_factor"
        assert RankingCriteria.PROFIT_FACTOR.value == expected

    def test_all_criteria_exist(self):
        """Iterating the enum yields exactly eight members."""
        member_count = len(list(RankingCriteria))
        assert member_count == 8
|
|
|
|
|
|
# ==============================================================================
|
|
# ComparisonStatus Enum Tests
|
|
# ==============================================================================
|
|
|
|
|
|
class TestComparisonStatus:
    """Checks the string values of the ``ComparisonStatus`` members."""

    def test_valid_value(self):
        """VALID carries the value ``valid``."""
        expected = "valid"
        assert ComparisonStatus.VALID.value == expected

    def test_insufficient_data_value(self):
        """INSUFFICIENT_DATA carries the value ``insufficient_data``."""
        expected = "insufficient_data"
        assert ComparisonStatus.INSUFFICIENT_DATA.value == expected

    def test_incomparable_value(self):
        """INCOMPARABLE carries the value ``incomparable``."""
        expected = "incomparable"
        assert ComparisonStatus.INCOMPARABLE.value == expected
|
|
|
|
|
|
# ==============================================================================
|
|
# StrategyMetrics Tests
|
|
# ==============================================================================
|
|
|
|
|
|
class TestStrategyMetrics:
    """Exercises construction and derived properties of ``StrategyMetrics``."""

    def test_create_basic_metrics(self):
        """Fields passed to the constructor are readable back unchanged."""
        supplied = {
            "strategy_id": "strat1",
            "strategy_name": "Momentum",
            "total_return": Decimal("0.25"),
            "sharpe_ratio": Decimal("1.5"),
        }
        metrics = StrategyMetrics(**supplied)
        # Dicts preserve insertion order, so the checks run id, name,
        # total_return, sharpe_ratio.
        for field_name, expected in supplied.items():
            assert getattr(metrics, field_name) == expected

    def test_risk_adjusted_return(self):
        """A 0.20 return over 0.10 volatility compares equal to Decimal 2.0000."""
        metrics = StrategyMetrics(
            strategy_id="strat1",
            strategy_name="Test",
            total_return=Decimal("0.20"),
            volatility=Decimal("0.10"),
        )
        ratio = metrics.risk_adjusted_return
        assert ratio == Decimal("2.0000")

    def test_risk_adjusted_return_zero_volatility(self):
        """With zero volatility the property is expected to equal Decimal 0."""
        metrics = StrategyMetrics(
            strategy_id="strat1",
            strategy_name="Test",
            total_return=Decimal("0.20"),
            volatility=Decimal("0"),
        )
        ratio = metrics.risk_adjusted_return
        assert ratio == Decimal("0")

    def test_has_sufficient_data_trades(self):
        """Fifteen trades with no return series count as sufficient data."""
        metrics = StrategyMetrics(
            strategy_id="strat1",
            strategy_name="Test",
            total_trades=15,
        )
        enough = metrics.has_sufficient_data
        assert enough is True

    def test_has_sufficient_data_returns(self):
        """Five trades plus a 35-point return series count as sufficient data."""
        series = [Decimal("0.01") for _ in range(35)]
        metrics = StrategyMetrics(
            strategy_id="strat1",
            strategy_name="Test",
            total_trades=5,
            returns_series=series,
        )
        enough = metrics.has_sufficient_data
        assert enough is True

    def test_insufficient_data(self):
        """Five trades plus a 10-point return series are flagged insufficient."""
        series = [Decimal("0.01") for _ in range(10)]
        metrics = StrategyMetrics(
            strategy_id="strat1",
            strategy_name="Test",
            total_trades=5,
            returns_series=series,
        )
        enough = metrics.has_sufficient_data
        assert enough is False
|
|
|
|
|
|
# ==============================================================================
|
|
# PairwiseComparison Tests
|
|
# ==============================================================================
|
|
|
|
|
|
class TestPairwiseComparison:
    """Exercises construction of the ``PairwiseComparison`` dataclass."""

    def test_create_comparison(self):
        """Ids, winner and return difference are stored as given."""
        first_id, second_id = "strat1", "strat2"
        return_gap = Decimal("0.05")
        pair = PairwiseComparison(
            strategy_a_id=first_id,
            strategy_b_id=second_id,
            winner=first_id,
            return_difference=return_gap,
            sharpe_difference=Decimal("0.3"),
        )
        assert pair.strategy_a_id == "strat1"
        assert pair.strategy_b_id == "strat2"
        assert pair.winner == "strat1"
        assert pair.return_difference == Decimal("0.05")

    def test_no_winner(self):
        """A comparison built with ``winner=None`` reports no winner."""
        pair = PairwiseComparison(
            strategy_a_id="strat1",
            strategy_b_id="strat2",
            winner=None,
            return_difference=Decimal("0.01"),
        )
        declared_winner = pair.winner
        assert declared_winner is None
|
|
|
|
|
|
# ==============================================================================
|
|
# ComparisonResult Tests
|
|
# ==============================================================================
|
|
|
|
|
|
class TestComparisonResult:
    """Exercises the ``ComparisonResult`` dataclass."""

    def test_strategy_count(self):
        """``strategy_count`` is 2 when two strategies are supplied."""
        members = [
            StrategyMetrics(strategy_id="s1", strategy_name="S1"),
            StrategyMetrics(strategy_id="s2", strategy_name="S2"),
        ]
        outcome = ComparisonResult(
            status=ComparisonStatus.VALID,
            strategies=members,
        )
        assert outcome.strategy_count == 2

    def test_empty_result(self):
        """A result built from a status alone has no strategies and no best."""
        status = ComparisonStatus.INSUFFICIENT_DATA
        outcome = ComparisonResult(status=status)
        assert outcome.strategy_count == 0
        assert outcome.best_overall is None
|
|
|
|
|
|
# ==============================================================================
|
|
# StrategyComparator Tests - Basic Operations
|
|
# ==============================================================================
|
|
|
|
|
|
class TestStrategyComparatorBasic:
    """Exercises add / remove / get / clear on ``StrategyComparator``."""

    def test_add_strategy(self):
        """A strategy that was added can be looked up by its id."""
        registry = StrategyComparator()
        entry = StrategyMetrics(
            strategy_id="strat1",
            strategy_name="Test",
            total_return=Decimal("0.20"),
        )
        registry.add_strategy(entry)
        found = registry.get_strategy("strat1")
        assert found is not None

    def test_remove_strategy(self):
        """Removing a known id returns True and the lookup then yields None."""
        registry = StrategyComparator()
        entry = StrategyMetrics(strategy_id="strat1", strategy_name="Test")
        registry.add_strategy(entry)
        removed = registry.remove_strategy("strat1")
        assert removed is True
        assert registry.get_strategy("strat1") is None

    def test_remove_nonexistent_strategy(self):
        """Removing an unknown id returns False."""
        registry = StrategyComparator()
        removed = registry.remove_strategy("nonexistent")
        assert removed is False

    def test_get_all_strategies(self):
        """``get_all_strategies`` reports both registered strategies."""
        registry = StrategyComparator()
        for ident, label in (("s1", "S1"), ("s2", "S2")):
            registry.add_strategy(
                StrategyMetrics(strategy_id=ident, strategy_name=label)
            )
        everything = registry.get_all_strategies()
        assert len(everything) == 2

    def test_clear(self):
        """After ``clear`` no strategies remain registered."""
        registry = StrategyComparator()
        entry = StrategyMetrics(strategy_id="s1", strategy_name="S1")
        registry.add_strategy(entry)
        registry.clear()
        remaining = registry.get_all_strategies()
        assert len(remaining) == 0
|
|
|
|
|
|
# ==============================================================================
|
|
# StrategyComparator Tests - Comparison
|
|
# ==============================================================================
|
|
|
|
|
|
class TestStrategyComparatorComparison:
    """Exercises ``StrategyComparator.compare`` with zero, one and two strategies."""

    def test_compare_empty(self):
        """Comparing with nothing registered reports INSUFFICIENT_DATA."""
        outcome = StrategyComparator().compare()
        assert outcome.status == ComparisonStatus.INSUFFICIENT_DATA

    def test_compare_single_strategy(self):
        """A lone strategy is both best and worst, with INSUFFICIENT_DATA status."""
        engine = StrategyComparator()
        only = StrategyMetrics(
            strategy_id="strat1",
            strategy_name="Test",
            total_return=Decimal("0.20"),
        )
        engine.add_strategy(only)
        outcome = engine.compare()
        assert outcome.status == ComparisonStatus.INSUFFICIENT_DATA
        assert outcome.best_overall == "strat1"
        assert outcome.worst_overall == "strat1"

    def test_compare_two_strategies(self):
        """Two strategies give a VALID result with one pairwise comparison."""
        engine = StrategyComparator()
        contenders = [
            StrategyMetrics(
                strategy_id="strat1",
                strategy_name="Momentum",
                total_return=Decimal("0.25"),
                sharpe_ratio=Decimal("1.5"),
                volatility=Decimal("0.15"),
                max_drawdown=Decimal("-0.12"),
            ),
            StrategyMetrics(
                strategy_id="strat2",
                strategy_name="Value",
                total_return=Decimal("0.18"),
                sharpe_ratio=Decimal("1.2"),
                volatility=Decimal("0.12"),
                max_drawdown=Decimal("-0.08"),
            ),
        ]
        for contender in contenders:
            engine.add_strategy(contender)
        outcome = engine.compare()
        assert outcome.status == ComparisonStatus.VALID
        assert outcome.strategy_count == 2
        assert len(outcome.pairwise_comparisons) == 1

    def test_compare_by_sharpe(self):
        """With SHARPE_RATIO as primary criterion the higher Sharpe wins."""
        engine = StrategyComparator()
        contenders = [
            StrategyMetrics(
                strategy_id="high_sharpe",
                strategy_name="High Sharpe",
                total_return=Decimal("0.15"),
                sharpe_ratio=Decimal("2.0"),
            ),
            StrategyMetrics(
                strategy_id="low_sharpe",
                strategy_name="Low Sharpe",
                total_return=Decimal("0.25"),
                sharpe_ratio=Decimal("1.0"),
            ),
        ]
        for contender in contenders:
            engine.add_strategy(contender)
        criterion = RankingCriteria.SHARPE_RATIO
        outcome = engine.compare(primary_criteria=criterion)
        assert outcome.best_overall == "high_sharpe"

    def test_compare_by_return(self):
        """With TOTAL_RETURN as primary criterion the higher return wins."""
        engine = StrategyComparator()
        contenders = [
            StrategyMetrics(
                strategy_id="high_return",
                strategy_name="High Return",
                total_return=Decimal("0.30"),
                sharpe_ratio=Decimal("1.0"),
            ),
            StrategyMetrics(
                strategy_id="low_return",
                strategy_name="Low Return",
                total_return=Decimal("0.10"),
                sharpe_ratio=Decimal("2.0"),
            ),
        ]
        for contender in contenders:
            engine.add_strategy(contender)
        criterion = RankingCriteria.TOTAL_RETURN
        outcome = engine.compare(primary_criteria=criterion)
        assert outcome.best_overall == "high_return"

    def test_rankings_all_criteria(self):
        """A default comparison produces a ranking entry for eight criteria."""
        engine = StrategyComparator()
        contenders = [
            StrategyMetrics(
                strategy_id="s1",
                strategy_name="Strategy 1",
                total_return=Decimal("0.20"),
                sharpe_ratio=Decimal("1.5"),
                sortino_ratio=Decimal("2.0"),
                max_drawdown=Decimal("-0.10"),
                win_rate=Decimal("0.55"),
                profit_factor=Decimal("1.5"),
                calmar_ratio=Decimal("2.0"),
                volatility=Decimal("0.10"),
            ),
            StrategyMetrics(
                strategy_id="s2",
                strategy_name="Strategy 2",
                total_return=Decimal("0.15"),
                sharpe_ratio=Decimal("1.8"),
                sortino_ratio=Decimal("2.5"),
                max_drawdown=Decimal("-0.08"),
                win_rate=Decimal("0.60"),
                profit_factor=Decimal("1.8"),
                calmar_ratio=Decimal("1.8"),
                volatility=Decimal("0.08"),
            ),
        ]
        for contender in contenders:
            engine.add_strategy(contender)
        outcome = engine.compare()
        assert outcome.status == ComparisonStatus.VALID
        # One rankings entry is expected per criterion.
        assert len(outcome.rankings) == 8
|
|
|
|
|
|
# ==============================================================================
|
|
# StrategyComparator Tests - Statistical Testing
|
|
# ==============================================================================
|
|
|
|
|
|
class TestStrategyComparatorStatistics:
    """Exercises the p-value reported on pairwise comparisons."""

    def test_pairwise_with_return_series(self):
        """Two 50-point return series yield one comparison with a p-value."""
        engine = StrategyComparator()

        # Each series is a constant value repeated 50 times; the "high"
        # series has the larger per-period return.
        stronger_series = [Decimal("0.02") for _ in range(50)]
        weaker_series = [Decimal("0.01") for _ in range(50)]

        stronger = StrategyMetrics(
            strategy_id="high",
            strategy_name="High Returns",
            total_return=Decimal("0.50"),
            returns_series=stronger_series,
        )
        weaker = StrategyMetrics(
            strategy_id="low",
            strategy_name="Low Returns",
            total_return=Decimal("0.25"),
            returns_series=weaker_series,
        )
        engine.add_strategy(stronger)
        engine.add_strategy(weaker)

        outcome = engine.compare()
        assert len(outcome.pairwise_comparisons) == 1
        pair = outcome.pairwise_comparisons[0]
        # Only the presence of a p-value is checked, not its magnitude.
        assert pair.p_value is not None

    def test_pairwise_insufficient_data(self):
        """Two 10-point return series yield a comparison without a p-value."""
        engine = StrategyComparator()
        # Ten observations per strategy: too few for the statistical test.
        first = StrategyMetrics(
            strategy_id="s1",
            strategy_name="S1",
            total_return=Decimal("0.20"),
            returns_series=[Decimal("0.01") for _ in range(10)],
        )
        second = StrategyMetrics(
            strategy_id="s2",
            strategy_name="S2",
            total_return=Decimal("0.15"),
            returns_series=[Decimal("0.01") for _ in range(10)],
        )
        engine.add_strategy(first)
        engine.add_strategy(second)

        outcome = engine.compare()
        pair = outcome.pairwise_comparisons[0]
        # No statistical test is expected to have been performed.
        assert pair.p_value is None
|
|
|
|
|
|
# ==============================================================================
|
|
# StrategyComparator Tests - Summary Statistics
|
|
# ==============================================================================
|
|
|
|
|
|
class TestStrategyComparatorSummary:
    """Exercises the ``summary_statistics`` mapping of a comparison result."""

    def test_summary_statistics(self):
        """Three strategies yield the expected summary keys and return bounds."""
        engine = StrategyComparator()
        contenders = [
            StrategyMetrics(
                strategy_id="s1",
                strategy_name="S1",
                total_return=Decimal("0.20"),
                volatility=Decimal("0.15"),
                max_drawdown=Decimal("-0.12"),
                sharpe_ratio=Decimal("1.5"),
            ),
            StrategyMetrics(
                strategy_id="s2",
                strategy_name="S2",
                total_return=Decimal("0.10"),
                volatility=Decimal("0.10"),
                max_drawdown=Decimal("-0.08"),
                sharpe_ratio=Decimal("1.2"),
            ),
            StrategyMetrics(
                strategy_id="s3",
                strategy_name="S3",
                total_return=Decimal("0.30"),
                volatility=Decimal("0.20"),
                max_drawdown=Decimal("-0.15"),
                sharpe_ratio=Decimal("1.8"),
            ),
        ]
        for contender in contenders:
            engine.add_strategy(contender)

        digest = engine.compare().summary_statistics

        assert digest["strategy_count"] == 3
        for section in ("return", "volatility", "max_drawdown", "sharpe_ratio"):
            assert section in digest

        # Bounds of the return section: lowest is s2, highest is s3.
        return_stats = digest["return"]
        assert return_stats["min"] == pytest.approx(0.10)
        assert return_stats["max"] == pytest.approx(0.30)
|
|
|
|
|
|
# ==============================================================================
|
|
# StrategyComparator Tests - Recommendations
|
|
# ==============================================================================
|
|
|
|
|
|
class TestStrategyComparatorRecommendations:
    """Exercises the recommendation strings attached to a comparison result."""

    def test_high_volatility_warning(self):
        """A 0.35-volatility strategy triggers a recommendation naming volatility."""
        engine = StrategyComparator()
        turbulent = StrategyMetrics(
            strategy_id="high_vol",
            strategy_name="High Vol Strategy",
            total_return=Decimal("0.40"),
            volatility=Decimal("0.35"),  # deliberately very high
            sharpe_ratio=Decimal("1.2"),
        )
        calm = StrategyMetrics(
            strategy_id="low_vol",
            strategy_name="Low Vol Strategy",
            total_return=Decimal("0.15"),
            volatility=Decimal("0.08"),
            sharpe_ratio=Decimal("1.8"),
        )
        engine.add_strategy(turbulent)
        engine.add_strategy(calm)
        outcome = engine.compare()
        matching = [
            note for note in outcome.recommendations
            if "volatility" in note.lower()
        ]
        assert len(matching) >= 1

    def test_drawdown_warning(self):
        """A -0.35 max drawdown triggers a recommendation naming drawdown."""
        engine = StrategyComparator()
        deep = StrategyMetrics(
            strategy_id="big_dd",
            strategy_name="Big Drawdown",
            total_return=Decimal("0.30"),
            max_drawdown=Decimal("-0.35"),  # deliberately very large
            sharpe_ratio=Decimal("1.0"),
        )
        shallow = StrategyMetrics(
            strategy_id="small_dd",
            strategy_name="Small Drawdown",
            total_return=Decimal("0.15"),
            max_drawdown=Decimal("-0.08"),
            sharpe_ratio=Decimal("1.5"),
        )
        engine.add_strategy(deep)
        engine.add_strategy(shallow)
        outcome = engine.compare()
        matching = [
            note for note in outcome.recommendations
            if "drawdown" in note.lower()
        ]
        assert len(matching) >= 1

    def test_low_trades_warning(self):
        """A strategy with only 15 trades triggers a recommendation naming trades."""
        engine = StrategyComparator()
        sparse = StrategyMetrics(
            strategy_id="few_trades",
            strategy_name="Few Trades",
            total_return=Decimal("0.50"),
            total_trades=15,  # deliberately low
            sharpe_ratio=Decimal("2.0"),
        )
        busy = StrategyMetrics(
            strategy_id="many_trades",
            strategy_name="Many Trades",
            total_return=Decimal("0.20"),
            total_trades=500,
            sharpe_ratio=Decimal("1.5"),
        )
        engine.add_strategy(sparse)
        engine.add_strategy(busy)
        outcome = engine.compare()
        matching = [
            note for note in outcome.recommendations
            if "trade" in note.lower()
        ]
        assert len(matching) >= 1
|
|
|
|
|
|
# ==============================================================================
|
|
# StrategyComparator Tests - Compare Returns
|
|
# ==============================================================================
|
|
|
|
|
|
class TestStrategyComparatorReturns:
    """Exercises ``StrategyComparator.compare_returns``."""

    def test_compare_returns_basic(self):
        """The report has ``strategies`` (two entries) and ``returns`` keys."""
        engine = StrategyComparator()
        contenders = [
            StrategyMetrics(
                strategy_id="s1",
                strategy_name="Strategy 1",
                total_return=Decimal("0.25"),
                annualized_return=Decimal("0.20"),
                best_trade=Decimal("0.05"),
                worst_trade=Decimal("-0.03"),
                avg_trade_return=Decimal("0.01"),
            ),
            StrategyMetrics(
                strategy_id="s2",
                strategy_name="Strategy 2",
                total_return=Decimal("0.18"),
                annualized_return=Decimal("0.15"),
                best_trade=Decimal("0.03"),
                worst_trade=Decimal("-0.02"),
                avg_trade_return=Decimal("0.008"),
            ),
        ]
        for contender in contenders:
            engine.add_strategy(contender)

        report = engine.compare_returns()
        assert "strategies" in report
        assert len(report["strategies"]) == 2
        assert "returns" in report

    def test_compare_returns_with_series(self):
        """A 40-point return series is summarised under ``series_stats``."""
        engine = StrategyComparator()
        pattern = [
            Decimal("0.01"),
            Decimal("0.02"),
            Decimal("-0.01"),
            Decimal("0.015"),
        ]
        # Four-value pattern repeated ten times gives 40 observations.
        long_series = pattern * 10

        engine.add_strategy(
            StrategyMetrics(
                strategy_id="s1",
                strategy_name="Strategy 1",
                total_return=Decimal("0.25"),
                returns_series=long_series,
            )
        )

        report = engine.compare_returns()
        # The per-strategy entry is keyed by the strategy name.
        assert "series_stats" in report["returns"]["Strategy 1"]
        series_summary = report["returns"]["Strategy 1"]["series_stats"]
        assert series_summary["count"] == 40
|
|
|
|
|
|
# ==============================================================================
|
|
# StrategyComparator Tests - Ranking Table
|
|
# ==============================================================================
|
|
|
|
|
|
class TestStrategyComparatorRankingTable:
    """Exercises ``StrategyComparator.get_ranking_table``."""

    def test_get_ranking_table_empty(self):
        """With nothing registered the ranking table is an empty list."""
        rows = StrategyComparator().get_ranking_table()
        assert rows == []

    def test_get_ranking_table(self):
        """Three strategies give three rows, led by the strategy named Best."""
        engine = StrategyComparator()
        contenders = [
            StrategyMetrics(
                strategy_id="s1",
                strategy_name="Best",
                total_return=Decimal("0.30"),
                sharpe_ratio=Decimal("2.0"),
                max_drawdown=Decimal("-0.05"),
                win_rate=Decimal("0.65"),
            ),
            StrategyMetrics(
                strategy_id="s2",
                strategy_name="Worst",
                total_return=Decimal("0.10"),
                sharpe_ratio=Decimal("0.8"),
                max_drawdown=Decimal("-0.20"),
                win_rate=Decimal("0.45"),
            ),
            StrategyMetrics(
                strategy_id="s3",
                strategy_name="Middle",
                total_return=Decimal("0.20"),
                sharpe_ratio=Decimal("1.5"),
                max_drawdown=Decimal("-0.10"),
                win_rate=Decimal("0.55"),
            ),
        ]
        for contender in contenders:
            engine.add_strategy(contender)

        rows = engine.get_ranking_table()
        assert len(rows) == 3

        # Each row is expected to expose its average rank and per-criterion ranks.
        assert "avg_rank" in rows[0]
        assert "rankings" in rows[0]

        # The leading row should be the strategy that dominates every metric.
        assert rows[0]["strategy_name"] == "Best"
|
|
|
|
|
|
# ==============================================================================
|
|
# Module Import Tests
|
|
# ==============================================================================
|
|
|
|
|
|
class TestModuleImports:
    """Checks that the comparator API is re-exported by the simulation package."""

    def test_import_from_simulation_module(self):
        """All six public names import from ``tradingagents.simulation``."""
        from tradingagents.simulation import (
            ComparisonResult,
            ComparisonStatus,
            PairwiseComparison,
            RankingCriteria,
            StrategyComparator,
            StrategyMetrics,
        )

        exported = (
            RankingCriteria,
            ComparisonStatus,
            StrategyMetrics,
            PairwiseComparison,
            ComparisonResult,
            StrategyComparator,
        )
        for symbol in exported:
            assert symbol is not None
|
|
|
|
|
|
# ==============================================================================
|
|
# Integration Tests
|
|
# ==============================================================================
|
|
|
|
|
|
class TestStrategyComparatorIntegration:
    """End-to-end run of ``StrategyComparator`` over three full strategies."""

    def test_full_comparison_workflow(self):
        """Register three strategies, compare by Sharpe, then inspect every output."""
        engine = StrategyComparator()

        # Three strategies with fully populated metrics.
        momentum = StrategyMetrics(
            strategy_id="momentum",
            strategy_name="Momentum",
            total_return=Decimal("0.28"),
            annualized_return=Decimal("0.22"),
            volatility=Decimal("0.18"),
            sharpe_ratio=Decimal("1.55"),
            sortino_ratio=Decimal("2.10"),
            max_drawdown=Decimal("-0.14"),
            calmar_ratio=Decimal("1.57"),
            win_rate=Decimal("0.52"),
            profit_factor=Decimal("1.45"),
            total_trades=150,
            avg_trade_return=Decimal("0.0018"),
        )
        value = StrategyMetrics(
            strategy_id="value",
            strategy_name="Value",
            total_return=Decimal("0.18"),
            annualized_return=Decimal("0.15"),
            volatility=Decimal("0.12"),
            sharpe_ratio=Decimal("1.25"),
            sortino_ratio=Decimal("1.60"),
            max_drawdown=Decimal("-0.10"),
            calmar_ratio=Decimal("1.50"),
            win_rate=Decimal("0.58"),
            profit_factor=Decimal("1.65"),
            total_trades=80,
            avg_trade_return=Decimal("0.0022"),
        )
        growth = StrategyMetrics(
            strategy_id="growth",
            strategy_name="Growth",
            total_return=Decimal("0.35"),
            annualized_return=Decimal("0.28"),
            volatility=Decimal("0.25"),
            sharpe_ratio=Decimal("1.40"),
            sortino_ratio=Decimal("1.80"),
            max_drawdown=Decimal("-0.22"),
            calmar_ratio=Decimal("1.27"),
            win_rate=Decimal("0.48"),
            profit_factor=Decimal("1.35"),
            total_trades=200,
            avg_trade_return=Decimal("0.0017"),
        )

        # Registration order matches the declaration order above.
        for entry in (momentum, value, growth):
            engine.add_strategy(entry)

        # Rank primarily by Sharpe ratio.
        outcome = engine.compare(primary_criteria=RankingCriteria.SHARPE_RATIO)

        # Overall result: momentum has the highest Sharpe ratio (1.55).
        assert outcome.status == ComparisonStatus.VALID
        assert outcome.strategy_count == 3
        assert outcome.best_overall == "momentum"

        # Leading entry of the Sharpe and total-return rankings.
        sharpe_leader = outcome.rankings[RankingCriteria.SHARPE_RATIO][:1]
        assert "momentum" in sharpe_leader
        return_leader = outcome.rankings[RankingCriteria.TOTAL_RETURN][:1]
        assert "growth" in return_leader

        # Three strategies give three unordered pairs.
        assert len(outcome.pairwise_comparisons) == 3

        # Summary statistics cover all three strategies.
        assert outcome.summary_statistics["strategy_count"] == 3

        # The ranking table has one row per strategy.
        rows = engine.get_ranking_table()
        assert len(rows) == 3
|