diff --git a/tests/unit/simulation/test_strategy_comparator.py b/tests/unit/simulation/test_strategy_comparator.py new file mode 100644 index 00000000..469b07fb --- /dev/null +++ b/tests/unit/simulation/test_strategy_comparator.py @@ -0,0 +1,778 @@ +"""Tests for Strategy Comparator. + +Issue #34: [SIM-33] Strategy comparator - performance comparison, stats + +Tests cover: +- RankingCriteria and ComparisonStatus enums +- StrategyMetrics dataclass +- PairwiseComparison and ComparisonResult dataclasses +- StrategyComparator comparison logic +- Statistical significance testing +- Ranking and recommendations +- Edge cases +""" + +import pytest +from datetime import date +from decimal import Decimal + +from tradingagents.simulation.strategy_comparator import ( + RankingCriteria, + ComparisonStatus, + StrategyMetrics, + PairwiseComparison, + ComparisonResult, + StrategyComparator, +) + + +# ============================================================================== +# RankingCriteria Enum Tests +# ============================================================================== + + +class TestRankingCriteria: + """Tests for RankingCriteria enum.""" + + def test_total_return_value(self): + """Test TOTAL_RETURN criterion value.""" + assert RankingCriteria.TOTAL_RETURN.value == "total_return" + + def test_sharpe_ratio_value(self): + """Test SHARPE_RATIO criterion value.""" + assert RankingCriteria.SHARPE_RATIO.value == "sharpe_ratio" + + def test_sortino_ratio_value(self): + """Test SORTINO_RATIO criterion value.""" + assert RankingCriteria.SORTINO_RATIO.value == "sortino_ratio" + + def test_max_drawdown_value(self): + """Test MAX_DRAWDOWN criterion value.""" + assert RankingCriteria.MAX_DRAWDOWN.value == "max_drawdown" + + def test_win_rate_value(self): + """Test WIN_RATE criterion value.""" + assert RankingCriteria.WIN_RATE.value == "win_rate" + + def test_profit_factor_value(self): + """Test PROFIT_FACTOR criterion value.""" + assert RankingCriteria.PROFIT_FACTOR.value == 
"profit_factor" + + def test_all_criteria_exist(self): + """Test all expected criteria exist.""" + criteria = [c for c in RankingCriteria] + assert len(criteria) == 8 + + +# ============================================================================== +# ComparisonStatus Enum Tests +# ============================================================================== + + +class TestComparisonStatus: + """Tests for ComparisonStatus enum.""" + + def test_valid_value(self): + """Test VALID status value.""" + assert ComparisonStatus.VALID.value == "valid" + + def test_insufficient_data_value(self): + """Test INSUFFICIENT_DATA status value.""" + assert ComparisonStatus.INSUFFICIENT_DATA.value == "insufficient_data" + + def test_incomparable_value(self): + """Test INCOMPARABLE status value.""" + assert ComparisonStatus.INCOMPARABLE.value == "incomparable" + + +# ============================================================================== +# StrategyMetrics Tests +# ============================================================================== + + +class TestStrategyMetrics: + """Tests for StrategyMetrics dataclass.""" + + def test_create_basic_metrics(self): + """Test creating basic strategy metrics.""" + metrics = StrategyMetrics( + strategy_id="strat1", + strategy_name="Momentum", + total_return=Decimal("0.25"), + sharpe_ratio=Decimal("1.5"), + ) + assert metrics.strategy_id == "strat1" + assert metrics.strategy_name == "Momentum" + assert metrics.total_return == Decimal("0.25") + assert metrics.sharpe_ratio == Decimal("1.5") + + def test_risk_adjusted_return(self): + """Test risk-adjusted return calculation.""" + metrics = StrategyMetrics( + strategy_id="strat1", + strategy_name="Test", + total_return=Decimal("0.20"), + volatility=Decimal("0.10"), + ) + assert metrics.risk_adjusted_return == Decimal("2.0000") + + def test_risk_adjusted_return_zero_volatility(self): + """Test risk-adjusted return with zero volatility.""" + metrics = StrategyMetrics( + 
strategy_id="strat1", + strategy_name="Test", + total_return=Decimal("0.20"), + volatility=Decimal("0"), + ) + assert metrics.risk_adjusted_return == Decimal("0") + + def test_has_sufficient_data_trades(self): + """Test sufficient data check with trades.""" + metrics = StrategyMetrics( + strategy_id="strat1", + strategy_name="Test", + total_trades=15, + ) + assert metrics.has_sufficient_data is True + + def test_has_sufficient_data_returns(self): + """Test sufficient data check with return series.""" + metrics = StrategyMetrics( + strategy_id="strat1", + strategy_name="Test", + total_trades=5, + returns_series=[Decimal("0.01")] * 35, + ) + assert metrics.has_sufficient_data is True + + def test_insufficient_data(self): + """Test insufficient data detection.""" + metrics = StrategyMetrics( + strategy_id="strat1", + strategy_name="Test", + total_trades=5, + returns_series=[Decimal("0.01")] * 10, + ) + assert metrics.has_sufficient_data is False + + +# ============================================================================== +# PairwiseComparison Tests +# ============================================================================== + + +class TestPairwiseComparison: + """Tests for PairwiseComparison dataclass.""" + + def test_create_comparison(self): + """Test creating a pairwise comparison.""" + comparison = PairwiseComparison( + strategy_a_id="strat1", + strategy_b_id="strat2", + winner="strat1", + return_difference=Decimal("0.05"), + sharpe_difference=Decimal("0.3"), + ) + assert comparison.strategy_a_id == "strat1" + assert comparison.strategy_b_id == "strat2" + assert comparison.winner == "strat1" + assert comparison.return_difference == Decimal("0.05") + + def test_no_winner(self): + """Test comparison with no clear winner.""" + comparison = PairwiseComparison( + strategy_a_id="strat1", + strategy_b_id="strat2", + winner=None, + return_difference=Decimal("0.01"), + ) + assert comparison.winner is None + + +# 
============================================================================== +# ComparisonResult Tests +# ============================================================================== + + +class TestComparisonResult: + """Tests for ComparisonResult dataclass.""" + + def test_strategy_count(self): + """Test strategy count property.""" + result = ComparisonResult( + status=ComparisonStatus.VALID, + strategies=[ + StrategyMetrics(strategy_id="s1", strategy_name="S1"), + StrategyMetrics(strategy_id="s2", strategy_name="S2"), + ], + ) + assert result.strategy_count == 2 + + def test_empty_result(self): + """Test empty comparison result.""" + result = ComparisonResult(status=ComparisonStatus.INSUFFICIENT_DATA) + assert result.strategy_count == 0 + assert result.best_overall is None + + +# ============================================================================== +# StrategyComparator Tests - Basic Operations +# ============================================================================== + + +class TestStrategyComparatorBasic: + """Tests for StrategyComparator basic operations.""" + + def test_add_strategy(self): + """Test adding a strategy.""" + comparator = StrategyComparator() + strategy = StrategyMetrics( + strategy_id="strat1", + strategy_name="Test", + total_return=Decimal("0.20"), + ) + comparator.add_strategy(strategy) + assert comparator.get_strategy("strat1") is not None + + def test_remove_strategy(self): + """Test removing a strategy.""" + comparator = StrategyComparator() + strategy = StrategyMetrics(strategy_id="strat1", strategy_name="Test") + comparator.add_strategy(strategy) + assert comparator.remove_strategy("strat1") is True + assert comparator.get_strategy("strat1") is None + + def test_remove_nonexistent_strategy(self): + """Test removing a nonexistent strategy.""" + comparator = StrategyComparator() + assert comparator.remove_strategy("nonexistent") is False + + def test_get_all_strategies(self): + """Test getting all strategies.""" + 
comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics(strategy_id="s1", strategy_name="S1")) + comparator.add_strategy(StrategyMetrics(strategy_id="s2", strategy_name="S2")) + strategies = comparator.get_all_strategies() + assert len(strategies) == 2 + + def test_clear(self): + """Test clearing all strategies.""" + comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics(strategy_id="s1", strategy_name="S1")) + comparator.clear() + assert len(comparator.get_all_strategies()) == 0 + + +# ============================================================================== +# StrategyComparator Tests - Comparison +# ============================================================================== + + +class TestStrategyComparatorComparison: + """Tests for StrategyComparator comparison logic.""" + + def test_compare_empty(self): + """Test comparing with no strategies.""" + comparator = StrategyComparator() + result = comparator.compare() + assert result.status == ComparisonStatus.INSUFFICIENT_DATA + + def test_compare_single_strategy(self): + """Test comparing with single strategy.""" + comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics( + strategy_id="strat1", + strategy_name="Test", + total_return=Decimal("0.20"), + )) + result = comparator.compare() + assert result.status == ComparisonStatus.INSUFFICIENT_DATA + assert result.best_overall == "strat1" + assert result.worst_overall == "strat1" + + def test_compare_two_strategies(self): + """Test comparing two strategies.""" + comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics( + strategy_id="strat1", + strategy_name="Momentum", + total_return=Decimal("0.25"), + sharpe_ratio=Decimal("1.5"), + volatility=Decimal("0.15"), + max_drawdown=Decimal("-0.12"), + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="strat2", + strategy_name="Value", + total_return=Decimal("0.18"), + sharpe_ratio=Decimal("1.2"), + volatility=Decimal("0.12"), 
+ max_drawdown=Decimal("-0.08"), + )) + result = comparator.compare() + assert result.status == ComparisonStatus.VALID + assert result.strategy_count == 2 + assert len(result.pairwise_comparisons) == 1 + + def test_compare_by_sharpe(self): + """Test comparison by Sharpe ratio.""" + comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics( + strategy_id="high_sharpe", + strategy_name="High Sharpe", + total_return=Decimal("0.15"), + sharpe_ratio=Decimal("2.0"), + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="low_sharpe", + strategy_name="Low Sharpe", + total_return=Decimal("0.25"), + sharpe_ratio=Decimal("1.0"), + )) + result = comparator.compare(primary_criteria=RankingCriteria.SHARPE_RATIO) + assert result.best_overall == "high_sharpe" + + def test_compare_by_return(self): + """Test comparison by total return.""" + comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics( + strategy_id="high_return", + strategy_name="High Return", + total_return=Decimal("0.30"), + sharpe_ratio=Decimal("1.0"), + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="low_return", + strategy_name="Low Return", + total_return=Decimal("0.10"), + sharpe_ratio=Decimal("2.0"), + )) + result = comparator.compare(primary_criteria=RankingCriteria.TOTAL_RETURN) + assert result.best_overall == "high_return" + + def test_rankings_all_criteria(self): + """Test rankings for all criteria.""" + comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics( + strategy_id="s1", + strategy_name="Strategy 1", + total_return=Decimal("0.20"), + sharpe_ratio=Decimal("1.5"), + sortino_ratio=Decimal("2.0"), + max_drawdown=Decimal("-0.10"), + win_rate=Decimal("0.55"), + profit_factor=Decimal("1.5"), + calmar_ratio=Decimal("2.0"), + volatility=Decimal("0.10"), + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="s2", + strategy_name="Strategy 2", + total_return=Decimal("0.15"), + sharpe_ratio=Decimal("1.8"), + 
sortino_ratio=Decimal("2.5"), + max_drawdown=Decimal("-0.08"), + win_rate=Decimal("0.60"), + profit_factor=Decimal("1.8"), + calmar_ratio=Decimal("1.8"), + volatility=Decimal("0.08"), + )) + result = comparator.compare() + assert result.status == ComparisonStatus.VALID + assert len(result.rankings) == 8 # All criteria + + +# ============================================================================== +# StrategyComparator Tests - Statistical Testing +# ============================================================================== + + +class TestStrategyComparatorStatistics: + """Tests for statistical significance testing.""" + + def test_pairwise_with_return_series(self): + """Test pairwise comparison with return series.""" + comparator = StrategyComparator() + # Strategy with higher mean returns + high_returns = [Decimal("0.02")] * 50 + low_returns = [Decimal("0.01")] * 50 + + comparator.add_strategy(StrategyMetrics( + strategy_id="high", + strategy_name="High Returns", + total_return=Decimal("0.50"), + returns_series=high_returns, + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="low", + strategy_name="Low Returns", + total_return=Decimal("0.25"), + returns_series=low_returns, + )) + result = comparator.compare() + assert len(result.pairwise_comparisons) == 1 + comparison = result.pairwise_comparisons[0] + # With identical values, should be highly significant + assert comparison.p_value is not None + + def test_pairwise_insufficient_data(self): + """Test pairwise comparison with insufficient data.""" + comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics( + strategy_id="s1", + strategy_name="S1", + total_return=Decimal("0.20"), + returns_series=[Decimal("0.01")] * 10, # Not enough + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="s2", + strategy_name="S2", + total_return=Decimal("0.15"), + returns_series=[Decimal("0.01")] * 10, + )) + result = comparator.compare() + comparison = result.pairwise_comparisons[0] + # 
With insufficient data, no statistical test performed + assert comparison.p_value is None + + +# ============================================================================== +# StrategyComparator Tests - Summary Statistics +# ============================================================================== + + +class TestStrategyComparatorSummary: + """Tests for summary statistics.""" + + def test_summary_statistics(self): + """Test summary statistics calculation.""" + comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics( + strategy_id="s1", + strategy_name="S1", + total_return=Decimal("0.20"), + volatility=Decimal("0.15"), + max_drawdown=Decimal("-0.12"), + sharpe_ratio=Decimal("1.5"), + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="s2", + strategy_name="S2", + total_return=Decimal("0.10"), + volatility=Decimal("0.10"), + max_drawdown=Decimal("-0.08"), + sharpe_ratio=Decimal("1.2"), + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="s3", + strategy_name="S3", + total_return=Decimal("0.30"), + volatility=Decimal("0.20"), + max_drawdown=Decimal("-0.15"), + sharpe_ratio=Decimal("1.8"), + )) + + result = comparator.compare() + summary = result.summary_statistics + + assert summary["strategy_count"] == 3 + assert "return" in summary + assert "volatility" in summary + assert "max_drawdown" in summary + assert "sharpe_ratio" in summary + + # Return stats + assert summary["return"]["min"] == pytest.approx(0.10) + assert summary["return"]["max"] == pytest.approx(0.30) + + +# ============================================================================== +# StrategyComparator Tests - Recommendations +# ============================================================================== + + +class TestStrategyComparatorRecommendations: + """Tests for recommendation generation.""" + + def test_high_volatility_warning(self): + """Test high volatility warning.""" + comparator = StrategyComparator() + 
comparator.add_strategy(StrategyMetrics( + strategy_id="high_vol", + strategy_name="High Vol Strategy", + total_return=Decimal("0.40"), + volatility=Decimal("0.35"), # Very high + sharpe_ratio=Decimal("1.2"), + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="low_vol", + strategy_name="Low Vol Strategy", + total_return=Decimal("0.15"), + volatility=Decimal("0.08"), + sharpe_ratio=Decimal("1.8"), + )) + result = comparator.compare() + vol_warnings = [r for r in result.recommendations if "volatility" in r.lower()] + assert len(vol_warnings) >= 1 + + def test_drawdown_warning(self): + """Test significant drawdown warning.""" + comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics( + strategy_id="big_dd", + strategy_name="Big Drawdown", + total_return=Decimal("0.30"), + max_drawdown=Decimal("-0.35"), # Very large + sharpe_ratio=Decimal("1.0"), + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="small_dd", + strategy_name="Small Drawdown", + total_return=Decimal("0.15"), + max_drawdown=Decimal("-0.08"), + sharpe_ratio=Decimal("1.5"), + )) + result = comparator.compare() + dd_warnings = [r for r in result.recommendations if "drawdown" in r.lower()] + assert len(dd_warnings) >= 1 + + def test_low_trades_warning(self): + """Test limited trade history warning.""" + comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics( + strategy_id="few_trades", + strategy_name="Few Trades", + total_return=Decimal("0.50"), + total_trades=15, # Low + sharpe_ratio=Decimal("2.0"), + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="many_trades", + strategy_name="Many Trades", + total_return=Decimal("0.20"), + total_trades=500, + sharpe_ratio=Decimal("1.5"), + )) + result = comparator.compare() + trade_warnings = [r for r in result.recommendations if "trade" in r.lower()] + assert len(trade_warnings) >= 1 + + +# ============================================================================== +# StrategyComparator Tests 
- Compare Returns +# ============================================================================== + + +class TestStrategyComparatorReturns: + """Tests for return comparison.""" + + def test_compare_returns_basic(self): + """Test basic return comparison.""" + comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics( + strategy_id="s1", + strategy_name="Strategy 1", + total_return=Decimal("0.25"), + annualized_return=Decimal("0.20"), + best_trade=Decimal("0.05"), + worst_trade=Decimal("-0.03"), + avg_trade_return=Decimal("0.01"), + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="s2", + strategy_name="Strategy 2", + total_return=Decimal("0.18"), + annualized_return=Decimal("0.15"), + best_trade=Decimal("0.03"), + worst_trade=Decimal("-0.02"), + avg_trade_return=Decimal("0.008"), + )) + + comparison = comparator.compare_returns() + assert "strategies" in comparison + assert len(comparison["strategies"]) == 2 + assert "returns" in comparison + + def test_compare_returns_with_series(self): + """Test return comparison with return series.""" + comparator = StrategyComparator() + returns = [Decimal("0.01"), Decimal("0.02"), Decimal("-0.01"), Decimal("0.015")] + returns_extended = returns * 10 # 40 values + + comparator.add_strategy(StrategyMetrics( + strategy_id="s1", + strategy_name="Strategy 1", + total_return=Decimal("0.25"), + returns_series=returns_extended, + )) + + comparison = comparator.compare_returns() + assert "series_stats" in comparison["returns"]["Strategy 1"] + stats = comparison["returns"]["Strategy 1"]["series_stats"] + assert stats["count"] == 40 + + +# ============================================================================== +# StrategyComparator Tests - Ranking Table +# ============================================================================== + + +class TestStrategyComparatorRankingTable: + """Tests for ranking table generation.""" + + def test_get_ranking_table_empty(self): + """Test ranking table with no 
strategies.""" + comparator = StrategyComparator() + table = comparator.get_ranking_table() + assert table == [] + + def test_get_ranking_table(self): + """Test ranking table generation.""" + comparator = StrategyComparator() + comparator.add_strategy(StrategyMetrics( + strategy_id="s1", + strategy_name="Best", + total_return=Decimal("0.30"), + sharpe_ratio=Decimal("2.0"), + max_drawdown=Decimal("-0.05"), + win_rate=Decimal("0.65"), + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="s2", + strategy_name="Worst", + total_return=Decimal("0.10"), + sharpe_ratio=Decimal("0.8"), + max_drawdown=Decimal("-0.20"), + win_rate=Decimal("0.45"), + )) + comparator.add_strategy(StrategyMetrics( + strategy_id="s3", + strategy_name="Middle", + total_return=Decimal("0.20"), + sharpe_ratio=Decimal("1.5"), + max_drawdown=Decimal("-0.10"), + win_rate=Decimal("0.55"), + )) + + table = comparator.get_ranking_table() + assert len(table) == 3 + + # Table should be sorted by average rank + assert "avg_rank" in table[0] + assert "rankings" in table[0] + + # First should be the best overall + assert table[0]["strategy_name"] == "Best" + + +# ============================================================================== +# Module Import Tests +# ============================================================================== + + +class TestModuleImports: + """Tests for module imports.""" + + def test_import_from_simulation_module(self): + """Test importing from simulation module.""" + from tradingagents.simulation import ( + RankingCriteria, + ComparisonStatus, + StrategyMetrics, + PairwiseComparison, + ComparisonResult, + StrategyComparator, + ) + assert RankingCriteria is not None + assert ComparisonStatus is not None + assert StrategyMetrics is not None + assert PairwiseComparison is not None + assert ComparisonResult is not None + assert StrategyComparator is not None + + +# ============================================================================== +# Integration Tests +# 
============================================================================== + + +class TestStrategyComparatorIntegration: + """Integration tests for StrategyComparator.""" + + def test_full_comparison_workflow(self): + """Test complete comparison workflow.""" + comparator = StrategyComparator() + + # Add multiple strategies with various characteristics + strategies = [ + StrategyMetrics( + strategy_id="momentum", + strategy_name="Momentum", + total_return=Decimal("0.28"), + annualized_return=Decimal("0.22"), + volatility=Decimal("0.18"), + sharpe_ratio=Decimal("1.55"), + sortino_ratio=Decimal("2.10"), + max_drawdown=Decimal("-0.14"), + calmar_ratio=Decimal("1.57"), + win_rate=Decimal("0.52"), + profit_factor=Decimal("1.45"), + total_trades=150, + avg_trade_return=Decimal("0.0018"), + ), + StrategyMetrics( + strategy_id="value", + strategy_name="Value", + total_return=Decimal("0.18"), + annualized_return=Decimal("0.15"), + volatility=Decimal("0.12"), + sharpe_ratio=Decimal("1.25"), + sortino_ratio=Decimal("1.60"), + max_drawdown=Decimal("-0.10"), + calmar_ratio=Decimal("1.50"), + win_rate=Decimal("0.58"), + profit_factor=Decimal("1.65"), + total_trades=80, + avg_trade_return=Decimal("0.0022"), + ), + StrategyMetrics( + strategy_id="growth", + strategy_name="Growth", + total_return=Decimal("0.35"), + annualized_return=Decimal("0.28"), + volatility=Decimal("0.25"), + sharpe_ratio=Decimal("1.40"), + sortino_ratio=Decimal("1.80"), + max_drawdown=Decimal("-0.22"), + calmar_ratio=Decimal("1.27"), + win_rate=Decimal("0.48"), + profit_factor=Decimal("1.35"), + total_trades=200, + avg_trade_return=Decimal("0.0017"), + ), + ] + + for s in strategies: + comparator.add_strategy(s) + + # Compare by Sharpe ratio + result = comparator.compare(primary_criteria=RankingCriteria.SHARPE_RATIO) + + # Verify result + assert result.status == ComparisonStatus.VALID + assert result.strategy_count == 3 + assert result.best_overall == "momentum" # Highest Sharpe + + # Verify rankings + 
assert "momentum" in result.rankings[RankingCriteria.SHARPE_RATIO][:1] + assert "growth" in result.rankings[RankingCriteria.TOTAL_RETURN][:1] + + # Verify pairwise comparisons + assert len(result.pairwise_comparisons) == 3 # 3 choose 2 + + # Verify summary statistics + assert result.summary_statistics["strategy_count"] == 3 + + # Get ranking table + table = comparator.get_ranking_table() + assert len(table) == 3 diff --git a/tradingagents/simulation/__init__.py b/tradingagents/simulation/__init__.py index 0981dc3f..2223f978 100644 --- a/tradingagents/simulation/__init__.py +++ b/tradingagents/simulation/__init__.py @@ -7,23 +7,31 @@ This module provides simulation capabilities including: - Economic regime simulation Issue #33: [SIM-32] Scenario runner - parallel portfolio simulations +Issue #34: [SIM-33] Strategy comparator - performance comparison, stats Submodules: scenario_runner: Core scenario execution framework + strategy_comparator: Strategy comparison and statistical analysis Classes: Enums: - ExecutionMode: Parallel execution mode (sequential, threaded, process) - ScenarioStatus: Status of a scenario run + - RankingCriteria: Criteria for ranking strategies + - ComparisonStatus: Status of strategy comparison Data Classes: - ScenarioConfig: Configuration for a simulation scenario - ScenarioResult: Result from a scenario simulation - RunnerProgress: Progress information for batch runs + - StrategyMetrics: Performance metrics for a strategy + - PairwiseComparison: Comparison between two strategies + - ComparisonResult: Complete result of strategy comparison Main Classes: - ScenarioRunner: Runner for parallel portfolio simulations - ScenarioBatchBuilder: Builder for creating scenario batches + - StrategyComparator: Compares multiple trading strategies Protocols: - ScenarioExecutor: Protocol for scenario execution functions @@ -76,21 +84,42 @@ from .scenario_runner import ( aggregate_results, ) -__all__ = [ +from .strategy_comparator import ( # Enums + 
RankingCriteria, + ComparisonStatus, + # Data Classes + StrategyMetrics, + PairwiseComparison, + ComparisonResult, + # Main Class + StrategyComparator, +) + +__all__ = [ + # Scenario Runner Enums "ExecutionMode", "ScenarioStatus", - # Data Classes + # Scenario Runner Data Classes "ScenarioConfig", "ScenarioResult", "RunnerProgress", - # Main Classes + # Scenario Runner Main Classes "ScenarioRunner", "ScenarioBatchBuilder", - # Protocols + # Scenario Runner Protocols "ScenarioExecutor", - # Types + # Scenario Runner Types "ProgressCallback", - # Utility Functions + # Scenario Runner Utility Functions "aggregate_results", + # Strategy Comparator Enums + "RankingCriteria", + "ComparisonStatus", + # Strategy Comparator Data Classes + "StrategyMetrics", + "PairwiseComparison", + "ComparisonResult", + # Strategy Comparator Main Class + "StrategyComparator", ] diff --git a/tradingagents/simulation/strategy_comparator.py b/tradingagents/simulation/strategy_comparator.py new file mode 100644 index 00000000..3c2035c7 --- /dev/null +++ b/tradingagents/simulation/strategy_comparator.py @@ -0,0 +1,737 @@ +"""Strategy Comparator for performance comparison and statistical analysis. 
+ +This module provides strategy comparison capabilities including: +- Performance metrics comparison across strategies +- Statistical significance testing +- Ranking and scoring +- Visualization data preparation + +Issue #34: [SIM-33] Strategy comparator - performance comparison, stats + +Design Principles: + - Comprehensive performance metrics + - Statistical rigor (hypothesis testing) + - Flexible ranking criteria + - Clear comparison outputs +""" + +from dataclasses import dataclass, field +from datetime import date +from decimal import Decimal, ROUND_HALF_UP +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple +import math +import statistics + + +class RankingCriteria(Enum): + """Criteria for ranking strategies.""" + TOTAL_RETURN = "total_return" + SHARPE_RATIO = "sharpe_ratio" + SORTINO_RATIO = "sortino_ratio" + MAX_DRAWDOWN = "max_drawdown" + WIN_RATE = "win_rate" + PROFIT_FACTOR = "profit_factor" + CALMAR_RATIO = "calmar_ratio" + RISK_ADJUSTED_RETURN = "risk_adjusted_return" + + +class ComparisonStatus(Enum): + """Status of strategy comparison.""" + VALID = "valid" + INSUFFICIENT_DATA = "insufficient_data" + INCOMPARABLE = "incomparable" + + +@dataclass +class StrategyMetrics: + """Performance metrics for a strategy. 
+ + Attributes: + strategy_id: Unique identifier + strategy_name: Human-readable name + start_date: First date of performance data + end_date: Last date of performance data + total_return: Total return as decimal (0.10 = 10%) + annualized_return: Annualized return + volatility: Annualized standard deviation of returns + sharpe_ratio: Sharpe ratio (risk-free rate assumed 0) + sortino_ratio: Sortino ratio (downside deviation) + max_drawdown: Maximum drawdown (negative) + calmar_ratio: Annualized return / max drawdown + win_rate: Percentage of winning trades + profit_factor: Gross profit / gross loss + total_trades: Number of trades executed + avg_trade_return: Average return per trade + avg_win: Average winning trade + avg_loss: Average losing trade + best_trade: Best single trade return + worst_trade: Worst single trade return + returns_series: Time series of periodic returns + equity_curve: Time series of equity values + metadata: Additional strategy data + """ + strategy_id: str + strategy_name: str + start_date: Optional[date] = None + end_date: Optional[date] = None + total_return: Decimal = Decimal("0") + annualized_return: Decimal = Decimal("0") + volatility: Decimal = Decimal("0") + sharpe_ratio: Optional[Decimal] = None + sortino_ratio: Optional[Decimal] = None + max_drawdown: Decimal = Decimal("0") + calmar_ratio: Optional[Decimal] = None + win_rate: Decimal = Decimal("0") + profit_factor: Optional[Decimal] = None + total_trades: int = 0 + avg_trade_return: Decimal = Decimal("0") + avg_win: Decimal = Decimal("0") + avg_loss: Decimal = Decimal("0") + best_trade: Decimal = Decimal("0") + worst_trade: Decimal = Decimal("0") + returns_series: List[Decimal] = field(default_factory=list) + equity_curve: List[Tuple[date, Decimal]] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def risk_adjusted_return(self) -> Decimal: + """Calculate risk-adjusted return (return / volatility).""" + if self.volatility == 0: + 
return Decimal("0") + return (self.total_return / self.volatility).quantize( + Decimal("0.0001"), rounding=ROUND_HALF_UP + ) + + @property + def has_sufficient_data(self) -> bool: + """Check if strategy has sufficient data for comparison.""" + return self.total_trades >= 10 or len(self.returns_series) >= 30 + + +@dataclass +class PairwiseComparison: + """Comparison between two strategies. + + Attributes: + strategy_a_id: First strategy ID + strategy_b_id: Second strategy ID + winner: ID of the winning strategy (or None if tie) + return_difference: Difference in total returns (A - B) + sharpe_difference: Difference in Sharpe ratios + volatility_difference: Difference in volatility + statistically_significant: Whether difference is significant + p_value: P-value from statistical test + confidence_interval: 95% CI for return difference + notes: Additional comparison notes + """ + strategy_a_id: str + strategy_b_id: str + winner: Optional[str] = None + return_difference: Decimal = Decimal("0") + sharpe_difference: Optional[Decimal] = None + volatility_difference: Decimal = Decimal("0") + statistically_significant: bool = False + p_value: Optional[float] = None + confidence_interval: Optional[Tuple[float, float]] = None + notes: str = "" + + +@dataclass +class ComparisonResult: + """Complete result of strategy comparison. 
+ + Attributes: + status: Comparison status + strategies: List of strategies compared + rankings: Strategies ranked by criteria + best_overall: Best overall strategy ID + worst_overall: Worst overall strategy ID + pairwise_comparisons: Pairwise comparison results + summary_statistics: Summary statistics across all strategies + recommendations: Analysis recommendations + metadata: Additional result data + """ + status: ComparisonStatus + strategies: List[StrategyMetrics] = field(default_factory=list) + rankings: Dict[RankingCriteria, List[str]] = field(default_factory=dict) + best_overall: Optional[str] = None + worst_overall: Optional[str] = None + pairwise_comparisons: List[PairwiseComparison] = field(default_factory=list) + summary_statistics: Dict[str, Any] = field(default_factory=dict) + recommendations: List[str] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def strategy_count(self) -> int: + """Number of strategies compared.""" + return len(self.strategies) + + +class StrategyComparator: + """Compares multiple trading strategies. + + Provides comprehensive comparison of strategy performance including + statistical tests, rankings, and recommendations. + + Example: + >>> comparator = StrategyComparator() + >>> comparator.add_strategy(StrategyMetrics( + ... strategy_id="strat1", + ... strategy_name="Momentum", + ... total_return=Decimal("0.25"), + ... sharpe_ratio=Decimal("1.5"), + ... )) + >>> comparator.add_strategy(StrategyMetrics( + ... strategy_id="strat2", + ... strategy_name="Value", + ... total_return=Decimal("0.18"), + ... sharpe_ratio=Decimal("1.2"), + ... )) + >>> result = comparator.compare() + >>> print(f"Best: {result.best_overall}") + """ + + def __init__( + self, + risk_free_rate: Decimal = Decimal("0"), + min_data_points: int = 30, + significance_level: float = 0.05, + ): + """Initialize the comparator. 
+ + Args: + risk_free_rate: Risk-free rate for Sharpe calculations + min_data_points: Minimum data points for valid comparison + significance_level: Significance level for statistical tests + """ + self.risk_free_rate = risk_free_rate + self.min_data_points = min_data_points + self.significance_level = significance_level + self._strategies: Dict[str, StrategyMetrics] = {} + + def add_strategy(self, strategy: StrategyMetrics) -> None: + """Add a strategy for comparison. + + Args: + strategy: Strategy metrics to add + """ + self._strategies[strategy.strategy_id] = strategy + + def remove_strategy(self, strategy_id: str) -> bool: + """Remove a strategy from comparison. + + Args: + strategy_id: ID of strategy to remove + + Returns: + True if removed, False if not found + """ + if strategy_id in self._strategies: + del self._strategies[strategy_id] + return True + return False + + def get_strategy(self, strategy_id: str) -> Optional[StrategyMetrics]: + """Get a strategy by ID. + + Args: + strategy_id: Strategy ID + + Returns: + Strategy metrics or None if not found + """ + return self._strategies.get(strategy_id) + + def get_all_strategies(self) -> List[StrategyMetrics]: + """Get all strategies. + + Returns: + List of all strategy metrics + """ + return list(self._strategies.values()) + + def clear(self) -> None: + """Remove all strategies.""" + self._strategies.clear() + + def _rank_by_criteria( + self, criteria: RankingCriteria + ) -> List[str]: + """Rank strategies by a specific criterion. 
+ + Args: + criteria: Ranking criterion + + Returns: + List of strategy IDs in ranked order (best first) + """ + strategies = list(self._strategies.values()) + + if criteria == RankingCriteria.TOTAL_RETURN: + key = lambda s: float(s.total_return) + reverse = True + elif criteria == RankingCriteria.SHARPE_RATIO: + key = lambda s: float(s.sharpe_ratio or Decimal("-999")) + reverse = True + elif criteria == RankingCriteria.SORTINO_RATIO: + key = lambda s: float(s.sortino_ratio or Decimal("-999")) + reverse = True + elif criteria == RankingCriteria.MAX_DRAWDOWN: + # Less negative is better + key = lambda s: float(s.max_drawdown) + reverse = True + elif criteria == RankingCriteria.WIN_RATE: + key = lambda s: float(s.win_rate) + reverse = True + elif criteria == RankingCriteria.PROFIT_FACTOR: + key = lambda s: float(s.profit_factor or Decimal("0")) + reverse = True + elif criteria == RankingCriteria.CALMAR_RATIO: + key = lambda s: float(s.calmar_ratio or Decimal("-999")) + reverse = True + elif criteria == RankingCriteria.RISK_ADJUSTED_RETURN: + key = lambda s: float(s.risk_adjusted_return) + reverse = True + else: + key = lambda s: float(s.total_return) + reverse = True + + sorted_strategies = sorted(strategies, key=key, reverse=reverse) + return [s.strategy_id for s in sorted_strategies] + + def _calculate_pairwise_comparison( + self, + strategy_a: StrategyMetrics, + strategy_b: StrategyMetrics, + ) -> PairwiseComparison: + """Calculate pairwise comparison between two strategies. 
+ + Args: + strategy_a: First strategy + strategy_b: Second strategy + + Returns: + Pairwise comparison result + """ + return_diff = strategy_a.total_return - strategy_b.total_return + + sharpe_diff = None + if strategy_a.sharpe_ratio is not None and strategy_b.sharpe_ratio is not None: + sharpe_diff = strategy_a.sharpe_ratio - strategy_b.sharpe_ratio + + vol_diff = strategy_a.volatility - strategy_b.volatility + + # Determine winner based on Sharpe ratio (or return if no Sharpe) + winner = None + if sharpe_diff is not None: + if sharpe_diff > Decimal("0.1"): # Meaningful difference + winner = strategy_a.strategy_id + elif sharpe_diff < Decimal("-0.1"): + winner = strategy_b.strategy_id + elif return_diff > Decimal("0.05"): + winner = strategy_a.strategy_id + elif return_diff < Decimal("-0.05"): + winner = strategy_b.strategy_id + + # Statistical significance test + significant = False + p_value = None + ci = None + + # Perform t-test if we have return series + if (len(strategy_a.returns_series) >= self.min_data_points and + len(strategy_b.returns_series) >= self.min_data_points): + try: + t_stat, p_value, ci = self._welch_t_test( + [float(r) for r in strategy_a.returns_series], + [float(r) for r in strategy_b.returns_series], + ) + significant = p_value < self.significance_level + except Exception: + pass + + notes = "" + if significant: + notes = f"Statistically significant difference (p={p_value:.4f})" + elif p_value is not None: + notes = f"Not statistically significant (p={p_value:.4f})" + + return PairwiseComparison( + strategy_a_id=strategy_a.strategy_id, + strategy_b_id=strategy_b.strategy_id, + winner=winner, + return_difference=return_diff.quantize(Decimal("0.0001")), + sharpe_difference=sharpe_diff.quantize(Decimal("0.0001")) if sharpe_diff else None, + volatility_difference=vol_diff.quantize(Decimal("0.0001")), + statistically_significant=significant, + p_value=p_value, + confidence_interval=ci, + notes=notes, + ) + + def _welch_t_test( + self, + 
        sample_a: List[float],
        sample_b: List[float],
    ) -> Tuple[float, float, Tuple[float, float]]:
        """Perform Welch's t-test for unequal variances.

        Args:
            sample_a: First sample
            sample_b: Second sample

        Returns:
            Tuple of (t-statistic, p-value, 95% CI)
        """
        n_a = len(sample_a)
        n_b = len(sample_b)

        mean_a = statistics.mean(sample_a)
        mean_b = statistics.mean(sample_b)
        # Sample variance needs >= 2 observations; fall back to 0 otherwise.
        var_a = statistics.variance(sample_a) if n_a > 1 else 0
        var_b = statistics.variance(sample_b) if n_b > 1 else 0

        # Pooled standard error
        # Tiny floor (0.0001) guards the t-statistic division when both
        # samples are constant (zero variance).
        se = math.sqrt(var_a / n_a + var_b / n_b) if (var_a + var_b) > 0 else 0.0001

        # t-statistic
        t_stat = (mean_a - mean_b) / se

        # Degrees of freedom (Welch-Satterthwaite)
        # NOTE(review): df is computed but not used below — the p-value uses
        # a normal approximation; kept for a future exact t-distribution
        # implementation.  Also note: a one-element sample combined with a
        # varying other sample would divide by (n - 1) == 0 here; callers
        # guard with min_data_points and wrap calls in try/except.
        if var_a > 0 or var_b > 0:
            num = (var_a / n_a + var_b / n_b) ** 2
            denom = (
                (var_a / n_a) ** 2 / (n_a - 1)
                + (var_b / n_b) ** 2 / (n_b - 1)
            )
            df = num / denom if denom > 0 else n_a + n_b - 2
        else:
            df = n_a + n_b - 2

        # Approximate p-value using normal distribution (good for large samples)
        # For more accuracy, would use t-distribution
        p_value = 2 * (1 - self._normal_cdf(abs(t_stat)))

        # 95% confidence interval
        z = 1.96  # For 95% CI
        ci_lower = (mean_a - mean_b) - z * se
        ci_upper = (mean_a - mean_b) + z * se

        return t_stat, p_value, (ci_lower, ci_upper)

    @staticmethod
    def _normal_cdf(x: float) -> float:
        """Approximate standard normal CDF.

        Uses the exact identity Phi(x) = (1 + erf(x / sqrt(2))) / 2.

        Args:
            x: Value

        Returns:
            Cumulative probability
        """
        return (1 + math.erf(x / math.sqrt(2))) / 2

    def _calculate_summary_statistics(
        self, strategies: List[StrategyMetrics]
    ) -> Dict[str, Any]:
        """Calculate summary statistics across all strategies.

        Args:
            strategies: List of strategies

        Returns:
            Dictionary of summary statistics
        """
        if not strategies:
            return {}

        returns = [float(s.total_return) for s in strategies]
        vols = [float(s.volatility) for s in strategies]
        # Sharpe ratios are optional per strategy; aggregate only those present.
        sharpes = [
            float(s.sharpe_ratio) for s in strategies
            if s.sharpe_ratio is not None
        ]
        drawdowns = [float(s.max_drawdown) for s in strategies]

        summary = {
            "strategy_count": len(strategies),
            "return": {
                "mean": statistics.mean(returns),
                "median": statistics.median(returns),
                "min": min(returns),
                "max": max(returns),
                # stdev needs >= 2 values; a single strategy gets 0.
                "stdev": statistics.stdev(returns) if len(returns) > 1 else 0,
            },
            "volatility": {
                "mean": statistics.mean(vols),
                "median": statistics.median(vols),
                "min": min(vols),
                "max": max(vols),
            },
            "max_drawdown": {
                "mean": statistics.mean(drawdowns),
                # Drawdowns are negative: min() is the deepest (worst) one.
                "worst": min(drawdowns),
                "best": max(drawdowns),
            },
        }

        # Only emitted when at least one strategy reported a Sharpe ratio.
        if sharpes:
            summary["sharpe_ratio"] = {
                "mean": statistics.mean(sharpes),
                "median": statistics.median(sharpes),
                "min": min(sharpes),
                "max": max(sharpes),
            }

        return summary

    def _generate_recommendations(
        self,
        strategies: List[StrategyMetrics],
        rankings: Dict[RankingCriteria, List[str]],
    ) -> List[str]:
        """Generate analysis recommendations.

        Args:
            strategies: List of strategies
            rankings: Rankings by criteria

        Returns:
            List of recommendation strings
        """
        recommendations = []

        if len(strategies) < 2:
            recommendations.append(
                "Add more strategies for meaningful comparison."
            )
            return recommendations

        # Consistency check: flag when the return leader is not also the
        # risk-adjusted (Sharpe) leader.
        best_return = rankings[RankingCriteria.TOTAL_RETURN][0]
        best_sharpe = rankings.get(RankingCriteria.SHARPE_RATIO, [None])[0]

        if best_return != best_sharpe and best_sharpe:
            recommendations.append(
                f"'{self._strategies[best_return].strategy_name}' has highest "
                f"return but '{self._strategies[best_sharpe].strategy_name}' "
                "has best risk-adjusted performance."
            )

        # Volatility warning
        for s in strategies:
            if float(s.volatility) > 0.30:  # 30% annual volatility
                recommendations.append(
                    f"'{s.strategy_name}' has high volatility "
                    f"({float(s.volatility)*100:.1f}%). Consider risk reduction."
                )

        # Drawdown warning (drawdowns are stored as negative fractions)
        for s in strategies:
            if float(s.max_drawdown) < -0.25:  # 25% drawdown
                recommendations.append(
                    f"'{s.strategy_name}' experienced significant drawdown "
                    f"({float(s.max_drawdown)*100:.1f}%). Review risk management."
                )

        # Insufficient trades warning
        for s in strategies:
            if s.total_trades < 30:
                recommendations.append(
                    f"'{s.strategy_name}' has limited trade history "
                    f"({s.total_trades} trades). Results may not be reliable."
                )

        return recommendations

    def compare(
        self,
        primary_criteria: RankingCriteria = RankingCriteria.SHARPE_RATIO,
    ) -> ComparisonResult:
        """Compare all added strategies.

        Args:
            primary_criteria: Primary criterion for determining best/worst

        Returns:
            Complete comparison result
        """
        strategies = list(self._strategies.values())

        # Degenerate cases: zero or one strategy cannot be compared.
        if len(strategies) == 0:
            return ComparisonResult(
                status=ComparisonStatus.INSUFFICIENT_DATA,
                recommendations=["No strategies to compare."],
            )

        if len(strategies) == 1:
            return ComparisonResult(
                status=ComparisonStatus.INSUFFICIENT_DATA,
                strategies=strategies,
                best_overall=strategies[0].strategy_id,
                worst_overall=strategies[0].strategy_id,
                recommendations=["Only one strategy provided. Add more for comparison."],
            )

        # Calculate rankings for each criterion
        rankings = {}
        for criteria in RankingCriteria:
            rankings[criteria] = self._rank_by_criteria(criteria)

        # Determine best and worst overall from the primary criterion's order
        primary_ranking = rankings[primary_criteria]
        best_overall = primary_ranking[0] if primary_ranking else None
        worst_overall = primary_ranking[-1] if primary_ranking else None

        # Calculate pairwise comparisons (each unordered pair exactly once)
        pairwise = []
        for i, s_a in enumerate(strategies):
            for s_b in strategies[i+1:]:
                comparison = self._calculate_pairwise_comparison(s_a, s_b)
                pairwise.append(comparison)

        # Calculate summary statistics
        summary = self._calculate_summary_statistics(strategies)

        # Generate recommendations
        recommendations = self._generate_recommendations(strategies, rankings)

        return ComparisonResult(
            status=ComparisonStatus.VALID,
            strategies=strategies,
            rankings=rankings,
            best_overall=best_overall,
            worst_overall=worst_overall,
            pairwise_comparisons=pairwise,
            summary_statistics=summary,
            recommendations=recommendations,
        )

    def compare_returns(
        self,
        strategy_ids: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Compare return distributions.

        Args:
            strategy_ids: Specific strategies to compare (None = all)

        Returns:
            Return comparison data
        """
        # Unknown IDs are silently skipped rather than raising.
        if strategy_ids:
            strategies = [
                self._strategies[sid] for sid in strategy_ids
                if sid in self._strategies
            ]
        else:
            strategies = list(self._strategies.values())

        # NOTE(review): output is keyed by strategy_name — two strategies
        # sharing a display name would overwrite each other; confirm names
        # are unique upstream.
        comparison = {
            "strategies": [s.strategy_name for s in strategies],
            "returns": {
                s.strategy_name: {
                    "total": str(s.total_return),
                    "annualized": str(s.annualized_return),
                    "best_trade": str(s.best_trade),
                    "worst_trade": str(s.worst_trade),
                    "avg_trade": str(s.avg_trade_return),
                }
                for s in strategies
            },
        }

        # Return series statistics if available
        for s in strategies:
            if s.returns_series:
                returns = [float(r) for r in s.returns_series]
                comparison["returns"][s.strategy_name]["series_stats"] = {
                    "count": len(returns),
                    "mean": statistics.mean(returns),
                    "median": statistics.median(returns),
                    "stdev": statistics.stdev(returns) if len(returns) > 1 else 0,
                    "skew": self._calculate_skew(returns),
                    "kurtosis": self._calculate_kurtosis(returns),
                }

        return comparison

    @staticmethod
    def _calculate_skew(data: List[float]) -> float:
        """Calculate skewness of a distribution.

        Uses a simplified (n - 1)-normalized third-moment estimator, not the
        bias-adjusted formula.

        Args:
            data: List of values

        Returns:
            Skewness value (0.0 for fewer than 3 points or zero spread)
        """
        if len(data) < 3:
            return 0.0
        n = len(data)
        mean = statistics.mean(data)
        std = statistics.stdev(data)
        if std == 0:
            # Constant series has no defined skew; report 0 by convention.
            return 0.0
        return sum((x - mean) ** 3 for x in data) / ((n - 1) * std ** 3)

    @staticmethod
    def _calculate_kurtosis(data: List[float]) -> float:
        """Calculate excess kurtosis of a distribution.
+ + Args: + data: List of values + + Returns: + Excess kurtosis value + """ + if len(data) < 4: + return 0.0 + n = len(data) + mean = statistics.mean(data) + std = statistics.stdev(data) + if std == 0: + return 0.0 + return sum((x - mean) ** 4 for x in data) / ((n - 1) * std ** 4) - 3 + + def get_ranking_table(self) -> List[Dict[str, Any]]: + """Generate a ranking table for all strategies. + + Returns: + List of strategy data with rankings + """ + if not self._strategies: + return [] + + # Get all rankings + rankings = {} + for criteria in RankingCriteria: + for rank, sid in enumerate(self._rank_by_criteria(criteria), 1): + if sid not in rankings: + rankings[sid] = {} + rankings[sid][criteria.value] = rank + + # Build table + table = [] + for sid, strategy in self._strategies.items(): + row = { + "strategy_id": sid, + "strategy_name": strategy.strategy_name, + "total_return": str(strategy.total_return), + "sharpe_ratio": str(strategy.sharpe_ratio) if strategy.sharpe_ratio else "N/A", + "max_drawdown": str(strategy.max_drawdown), + "win_rate": str(strategy.win_rate), + "rankings": rankings.get(sid, {}), + } + + # Calculate average rank + ranks = list(rankings.get(sid, {}).values()) + row["avg_rank"] = sum(ranks) / len(ranks) if ranks else 0 + + table.append(row) + + # Sort by average rank + table.sort(key=lambda x: x["avg_rank"]) + + return table