feat(simulation): add Strategy Comparator for performance comparison - Issue #34 (43 tests)

Implements comprehensive strategy comparison framework:
- StrategyMetrics dataclass for performance data
- PairwiseComparison for head-to-head analysis
- ComparisonResult with rankings and recommendations
- StrategyComparator main class

Features:
- Multi-criteria ranking (Sharpe, Sortino, returns, drawdown, etc.)
- Welch's t-test for statistical significance
- Summary statistics across all strategies
- Automated recommendations (volatility, drawdown, trade count warnings)
- Return distribution analysis with skew/kurtosis
- Ranking table generation with average rank calculation

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Andrew Kaszubski 2025-12-26 22:05:35 +11:00
parent e7bff2c4cf
commit 76eac65eb3
3 changed files with 1550 additions and 6 deletions

View File

@ -0,0 +1,778 @@
"""Tests for Strategy Comparator.
Issue #34: [SIM-33] Strategy comparator - performance comparison, stats
Tests cover:
- RankingCriteria and ComparisonStatus enums
- StrategyMetrics dataclass
- PairwiseComparison and ComparisonResult dataclasses
- StrategyComparator comparison logic
- Statistical significance testing
- Ranking and recommendations
- Edge cases
"""
import pytest
from datetime import date
from decimal import Decimal
from tradingagents.simulation.strategy_comparator import (
RankingCriteria,
ComparisonStatus,
StrategyMetrics,
PairwiseComparison,
ComparisonResult,
StrategyComparator,
)
# ==============================================================================
# RankingCriteria Enum Tests
# ==============================================================================
class TestRankingCriteria:
"""Tests for RankingCriteria enum."""
def test_total_return_value(self):
"""Test TOTAL_RETURN criterion value."""
assert RankingCriteria.TOTAL_RETURN.value == "total_return"
def test_sharpe_ratio_value(self):
"""Test SHARPE_RATIO criterion value."""
assert RankingCriteria.SHARPE_RATIO.value == "sharpe_ratio"
def test_sortino_ratio_value(self):
"""Test SORTINO_RATIO criterion value."""
assert RankingCriteria.SORTINO_RATIO.value == "sortino_ratio"
def test_max_drawdown_value(self):
"""Test MAX_DRAWDOWN criterion value."""
assert RankingCriteria.MAX_DRAWDOWN.value == "max_drawdown"
def test_win_rate_value(self):
"""Test WIN_RATE criterion value."""
assert RankingCriteria.WIN_RATE.value == "win_rate"
def test_profit_factor_value(self):
"""Test PROFIT_FACTOR criterion value."""
assert RankingCriteria.PROFIT_FACTOR.value == "profit_factor"
def test_all_criteria_exist(self):
"""Test all expected criteria exist."""
criteria = [c for c in RankingCriteria]
assert len(criteria) == 8
# ==============================================================================
# ComparisonStatus Enum Tests
# ==============================================================================
class TestComparisonStatus:
"""Tests for ComparisonStatus enum."""
def test_valid_value(self):
"""Test VALID status value."""
assert ComparisonStatus.VALID.value == "valid"
def test_insufficient_data_value(self):
"""Test INSUFFICIENT_DATA status value."""
assert ComparisonStatus.INSUFFICIENT_DATA.value == "insufficient_data"
def test_incomparable_value(self):
"""Test INCOMPARABLE status value."""
assert ComparisonStatus.INCOMPARABLE.value == "incomparable"
# ==============================================================================
# StrategyMetrics Tests
# ==============================================================================
class TestStrategyMetrics:
"""Tests for StrategyMetrics dataclass."""
def test_create_basic_metrics(self):
"""Test creating basic strategy metrics."""
metrics = StrategyMetrics(
strategy_id="strat1",
strategy_name="Momentum",
total_return=Decimal("0.25"),
sharpe_ratio=Decimal("1.5"),
)
assert metrics.strategy_id == "strat1"
assert metrics.strategy_name == "Momentum"
assert metrics.total_return == Decimal("0.25")
assert metrics.sharpe_ratio == Decimal("1.5")
def test_risk_adjusted_return(self):
"""Test risk-adjusted return calculation."""
metrics = StrategyMetrics(
strategy_id="strat1",
strategy_name="Test",
total_return=Decimal("0.20"),
volatility=Decimal("0.10"),
)
assert metrics.risk_adjusted_return == Decimal("2.0000")
def test_risk_adjusted_return_zero_volatility(self):
"""Test risk-adjusted return with zero volatility."""
metrics = StrategyMetrics(
strategy_id="strat1",
strategy_name="Test",
total_return=Decimal("0.20"),
volatility=Decimal("0"),
)
assert metrics.risk_adjusted_return == Decimal("0")
def test_has_sufficient_data_trades(self):
"""Test sufficient data check with trades."""
metrics = StrategyMetrics(
strategy_id="strat1",
strategy_name="Test",
total_trades=15,
)
assert metrics.has_sufficient_data is True
def test_has_sufficient_data_returns(self):
"""Test sufficient data check with return series."""
metrics = StrategyMetrics(
strategy_id="strat1",
strategy_name="Test",
total_trades=5,
returns_series=[Decimal("0.01")] * 35,
)
assert metrics.has_sufficient_data is True
def test_insufficient_data(self):
"""Test insufficient data detection."""
metrics = StrategyMetrics(
strategy_id="strat1",
strategy_name="Test",
total_trades=5,
returns_series=[Decimal("0.01")] * 10,
)
assert metrics.has_sufficient_data is False
# ==============================================================================
# PairwiseComparison Tests
# ==============================================================================
class TestPairwiseComparison:
"""Tests for PairwiseComparison dataclass."""
def test_create_comparison(self):
"""Test creating a pairwise comparison."""
comparison = PairwiseComparison(
strategy_a_id="strat1",
strategy_b_id="strat2",
winner="strat1",
return_difference=Decimal("0.05"),
sharpe_difference=Decimal("0.3"),
)
assert comparison.strategy_a_id == "strat1"
assert comparison.strategy_b_id == "strat2"
assert comparison.winner == "strat1"
assert comparison.return_difference == Decimal("0.05")
def test_no_winner(self):
"""Test comparison with no clear winner."""
comparison = PairwiseComparison(
strategy_a_id="strat1",
strategy_b_id="strat2",
winner=None,
return_difference=Decimal("0.01"),
)
assert comparison.winner is None
# ==============================================================================
# ComparisonResult Tests
# ==============================================================================
class TestComparisonResult:
"""Tests for ComparisonResult dataclass."""
def test_strategy_count(self):
"""Test strategy count property."""
result = ComparisonResult(
status=ComparisonStatus.VALID,
strategies=[
StrategyMetrics(strategy_id="s1", strategy_name="S1"),
StrategyMetrics(strategy_id="s2", strategy_name="S2"),
],
)
assert result.strategy_count == 2
def test_empty_result(self):
"""Test empty comparison result."""
result = ComparisonResult(status=ComparisonStatus.INSUFFICIENT_DATA)
assert result.strategy_count == 0
assert result.best_overall is None
# ==============================================================================
# StrategyComparator Tests - Basic Operations
# ==============================================================================
class TestStrategyComparatorBasic:
"""Tests for StrategyComparator basic operations."""
def test_add_strategy(self):
"""Test adding a strategy."""
comparator = StrategyComparator()
strategy = StrategyMetrics(
strategy_id="strat1",
strategy_name="Test",
total_return=Decimal("0.20"),
)
comparator.add_strategy(strategy)
assert comparator.get_strategy("strat1") is not None
def test_remove_strategy(self):
"""Test removing a strategy."""
comparator = StrategyComparator()
strategy = StrategyMetrics(strategy_id="strat1", strategy_name="Test")
comparator.add_strategy(strategy)
assert comparator.remove_strategy("strat1") is True
assert comparator.get_strategy("strat1") is None
def test_remove_nonexistent_strategy(self):
"""Test removing a nonexistent strategy."""
comparator = StrategyComparator()
assert comparator.remove_strategy("nonexistent") is False
def test_get_all_strategies(self):
"""Test getting all strategies."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(strategy_id="s1", strategy_name="S1"))
comparator.add_strategy(StrategyMetrics(strategy_id="s2", strategy_name="S2"))
strategies = comparator.get_all_strategies()
assert len(strategies) == 2
def test_clear(self):
"""Test clearing all strategies."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(strategy_id="s1", strategy_name="S1"))
comparator.clear()
assert len(comparator.get_all_strategies()) == 0
# ==============================================================================
# StrategyComparator Tests - Comparison
# ==============================================================================
class TestStrategyComparatorComparison:
"""Tests for StrategyComparator comparison logic."""
def test_compare_empty(self):
"""Test comparing with no strategies."""
comparator = StrategyComparator()
result = comparator.compare()
assert result.status == ComparisonStatus.INSUFFICIENT_DATA
def test_compare_single_strategy(self):
"""Test comparing with single strategy."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(
strategy_id="strat1",
strategy_name="Test",
total_return=Decimal("0.20"),
))
result = comparator.compare()
assert result.status == ComparisonStatus.INSUFFICIENT_DATA
assert result.best_overall == "strat1"
assert result.worst_overall == "strat1"
def test_compare_two_strategies(self):
"""Test comparing two strategies."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(
strategy_id="strat1",
strategy_name="Momentum",
total_return=Decimal("0.25"),
sharpe_ratio=Decimal("1.5"),
volatility=Decimal("0.15"),
max_drawdown=Decimal("-0.12"),
))
comparator.add_strategy(StrategyMetrics(
strategy_id="strat2",
strategy_name="Value",
total_return=Decimal("0.18"),
sharpe_ratio=Decimal("1.2"),
volatility=Decimal("0.12"),
max_drawdown=Decimal("-0.08"),
))
result = comparator.compare()
assert result.status == ComparisonStatus.VALID
assert result.strategy_count == 2
assert len(result.pairwise_comparisons) == 1
def test_compare_by_sharpe(self):
"""Test comparison by Sharpe ratio."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(
strategy_id="high_sharpe",
strategy_name="High Sharpe",
total_return=Decimal("0.15"),
sharpe_ratio=Decimal("2.0"),
))
comparator.add_strategy(StrategyMetrics(
strategy_id="low_sharpe",
strategy_name="Low Sharpe",
total_return=Decimal("0.25"),
sharpe_ratio=Decimal("1.0"),
))
result = comparator.compare(primary_criteria=RankingCriteria.SHARPE_RATIO)
assert result.best_overall == "high_sharpe"
def test_compare_by_return(self):
"""Test comparison by total return."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(
strategy_id="high_return",
strategy_name="High Return",
total_return=Decimal("0.30"),
sharpe_ratio=Decimal("1.0"),
))
comparator.add_strategy(StrategyMetrics(
strategy_id="low_return",
strategy_name="Low Return",
total_return=Decimal("0.10"),
sharpe_ratio=Decimal("2.0"),
))
result = comparator.compare(primary_criteria=RankingCriteria.TOTAL_RETURN)
assert result.best_overall == "high_return"
def test_rankings_all_criteria(self):
"""Test rankings for all criteria."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(
strategy_id="s1",
strategy_name="Strategy 1",
total_return=Decimal("0.20"),
sharpe_ratio=Decimal("1.5"),
sortino_ratio=Decimal("2.0"),
max_drawdown=Decimal("-0.10"),
win_rate=Decimal("0.55"),
profit_factor=Decimal("1.5"),
calmar_ratio=Decimal("2.0"),
volatility=Decimal("0.10"),
))
comparator.add_strategy(StrategyMetrics(
strategy_id="s2",
strategy_name="Strategy 2",
total_return=Decimal("0.15"),
sharpe_ratio=Decimal("1.8"),
sortino_ratio=Decimal("2.5"),
max_drawdown=Decimal("-0.08"),
win_rate=Decimal("0.60"),
profit_factor=Decimal("1.8"),
calmar_ratio=Decimal("1.8"),
volatility=Decimal("0.08"),
))
result = comparator.compare()
assert result.status == ComparisonStatus.VALID
assert len(result.rankings) == 8 # All criteria
# ==============================================================================
# StrategyComparator Tests - Statistical Testing
# ==============================================================================
class TestStrategyComparatorStatistics:
"""Tests for statistical significance testing."""
def test_pairwise_with_return_series(self):
"""Test pairwise comparison with return series."""
comparator = StrategyComparator()
# Strategy with higher mean returns
high_returns = [Decimal("0.02")] * 50
low_returns = [Decimal("0.01")] * 50
comparator.add_strategy(StrategyMetrics(
strategy_id="high",
strategy_name="High Returns",
total_return=Decimal("0.50"),
returns_series=high_returns,
))
comparator.add_strategy(StrategyMetrics(
strategy_id="low",
strategy_name="Low Returns",
total_return=Decimal("0.25"),
returns_series=low_returns,
))
result = comparator.compare()
assert len(result.pairwise_comparisons) == 1
comparison = result.pairwise_comparisons[0]
# With identical values, should be highly significant
assert comparison.p_value is not None
def test_pairwise_insufficient_data(self):
"""Test pairwise comparison with insufficient data."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(
strategy_id="s1",
strategy_name="S1",
total_return=Decimal("0.20"),
returns_series=[Decimal("0.01")] * 10, # Not enough
))
comparator.add_strategy(StrategyMetrics(
strategy_id="s2",
strategy_name="S2",
total_return=Decimal("0.15"),
returns_series=[Decimal("0.01")] * 10,
))
result = comparator.compare()
comparison = result.pairwise_comparisons[0]
# With insufficient data, no statistical test performed
assert comparison.p_value is None
# ==============================================================================
# StrategyComparator Tests - Summary Statistics
# ==============================================================================
class TestStrategyComparatorSummary:
"""Tests for summary statistics."""
def test_summary_statistics(self):
"""Test summary statistics calculation."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(
strategy_id="s1",
strategy_name="S1",
total_return=Decimal("0.20"),
volatility=Decimal("0.15"),
max_drawdown=Decimal("-0.12"),
sharpe_ratio=Decimal("1.5"),
))
comparator.add_strategy(StrategyMetrics(
strategy_id="s2",
strategy_name="S2",
total_return=Decimal("0.10"),
volatility=Decimal("0.10"),
max_drawdown=Decimal("-0.08"),
sharpe_ratio=Decimal("1.2"),
))
comparator.add_strategy(StrategyMetrics(
strategy_id="s3",
strategy_name="S3",
total_return=Decimal("0.30"),
volatility=Decimal("0.20"),
max_drawdown=Decimal("-0.15"),
sharpe_ratio=Decimal("1.8"),
))
result = comparator.compare()
summary = result.summary_statistics
assert summary["strategy_count"] == 3
assert "return" in summary
assert "volatility" in summary
assert "max_drawdown" in summary
assert "sharpe_ratio" in summary
# Return stats
assert summary["return"]["min"] == pytest.approx(0.10)
assert summary["return"]["max"] == pytest.approx(0.30)
# ==============================================================================
# StrategyComparator Tests - Recommendations
# ==============================================================================
class TestStrategyComparatorRecommendations:
"""Tests for recommendation generation."""
def test_high_volatility_warning(self):
"""Test high volatility warning."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(
strategy_id="high_vol",
strategy_name="High Vol Strategy",
total_return=Decimal("0.40"),
volatility=Decimal("0.35"), # Very high
sharpe_ratio=Decimal("1.2"),
))
comparator.add_strategy(StrategyMetrics(
strategy_id="low_vol",
strategy_name="Low Vol Strategy",
total_return=Decimal("0.15"),
volatility=Decimal("0.08"),
sharpe_ratio=Decimal("1.8"),
))
result = comparator.compare()
vol_warnings = [r for r in result.recommendations if "volatility" in r.lower()]
assert len(vol_warnings) >= 1
def test_drawdown_warning(self):
"""Test significant drawdown warning."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(
strategy_id="big_dd",
strategy_name="Big Drawdown",
total_return=Decimal("0.30"),
max_drawdown=Decimal("-0.35"), # Very large
sharpe_ratio=Decimal("1.0"),
))
comparator.add_strategy(StrategyMetrics(
strategy_id="small_dd",
strategy_name="Small Drawdown",
total_return=Decimal("0.15"),
max_drawdown=Decimal("-0.08"),
sharpe_ratio=Decimal("1.5"),
))
result = comparator.compare()
dd_warnings = [r for r in result.recommendations if "drawdown" in r.lower()]
assert len(dd_warnings) >= 1
def test_low_trades_warning(self):
"""Test limited trade history warning."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(
strategy_id="few_trades",
strategy_name="Few Trades",
total_return=Decimal("0.50"),
total_trades=15, # Low
sharpe_ratio=Decimal("2.0"),
))
comparator.add_strategy(StrategyMetrics(
strategy_id="many_trades",
strategy_name="Many Trades",
total_return=Decimal("0.20"),
total_trades=500,
sharpe_ratio=Decimal("1.5"),
))
result = comparator.compare()
trade_warnings = [r for r in result.recommendations if "trade" in r.lower()]
assert len(trade_warnings) >= 1
# ==============================================================================
# StrategyComparator Tests - Compare Returns
# ==============================================================================
class TestStrategyComparatorReturns:
"""Tests for return comparison."""
def test_compare_returns_basic(self):
"""Test basic return comparison."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(
strategy_id="s1",
strategy_name="Strategy 1",
total_return=Decimal("0.25"),
annualized_return=Decimal("0.20"),
best_trade=Decimal("0.05"),
worst_trade=Decimal("-0.03"),
avg_trade_return=Decimal("0.01"),
))
comparator.add_strategy(StrategyMetrics(
strategy_id="s2",
strategy_name="Strategy 2",
total_return=Decimal("0.18"),
annualized_return=Decimal("0.15"),
best_trade=Decimal("0.03"),
worst_trade=Decimal("-0.02"),
avg_trade_return=Decimal("0.008"),
))
comparison = comparator.compare_returns()
assert "strategies" in comparison
assert len(comparison["strategies"]) == 2
assert "returns" in comparison
def test_compare_returns_with_series(self):
"""Test return comparison with return series."""
comparator = StrategyComparator()
returns = [Decimal("0.01"), Decimal("0.02"), Decimal("-0.01"), Decimal("0.015")]
returns_extended = returns * 10 # 40 values
comparator.add_strategy(StrategyMetrics(
strategy_id="s1",
strategy_name="Strategy 1",
total_return=Decimal("0.25"),
returns_series=returns_extended,
))
comparison = comparator.compare_returns()
assert "series_stats" in comparison["returns"]["Strategy 1"]
stats = comparison["returns"]["Strategy 1"]["series_stats"]
assert stats["count"] == 40
# ==============================================================================
# StrategyComparator Tests - Ranking Table
# ==============================================================================
class TestStrategyComparatorRankingTable:
"""Tests for ranking table generation."""
def test_get_ranking_table_empty(self):
"""Test ranking table with no strategies."""
comparator = StrategyComparator()
table = comparator.get_ranking_table()
assert table == []
def test_get_ranking_table(self):
"""Test ranking table generation."""
comparator = StrategyComparator()
comparator.add_strategy(StrategyMetrics(
strategy_id="s1",
strategy_name="Best",
total_return=Decimal("0.30"),
sharpe_ratio=Decimal("2.0"),
max_drawdown=Decimal("-0.05"),
win_rate=Decimal("0.65"),
))
comparator.add_strategy(StrategyMetrics(
strategy_id="s2",
strategy_name="Worst",
total_return=Decimal("0.10"),
sharpe_ratio=Decimal("0.8"),
max_drawdown=Decimal("-0.20"),
win_rate=Decimal("0.45"),
))
comparator.add_strategy(StrategyMetrics(
strategy_id="s3",
strategy_name="Middle",
total_return=Decimal("0.20"),
sharpe_ratio=Decimal("1.5"),
max_drawdown=Decimal("-0.10"),
win_rate=Decimal("0.55"),
))
table = comparator.get_ranking_table()
assert len(table) == 3
# Table should be sorted by average rank
assert "avg_rank" in table[0]
assert "rankings" in table[0]
# First should be the best overall
assert table[0]["strategy_name"] == "Best"
# ==============================================================================
# Module Import Tests
# ==============================================================================
class TestModuleImports:
"""Tests for module imports."""
def test_import_from_simulation_module(self):
"""Test importing from simulation module."""
from tradingagents.simulation import (
RankingCriteria,
ComparisonStatus,
StrategyMetrics,
PairwiseComparison,
ComparisonResult,
StrategyComparator,
)
assert RankingCriteria is not None
assert ComparisonStatus is not None
assert StrategyMetrics is not None
assert PairwiseComparison is not None
assert ComparisonResult is not None
assert StrategyComparator is not None
# ==============================================================================
# Integration Tests
# ==============================================================================
class TestStrategyComparatorIntegration:
"""Integration tests for StrategyComparator."""
def test_full_comparison_workflow(self):
"""Test complete comparison workflow."""
comparator = StrategyComparator()
# Add multiple strategies with various characteristics
strategies = [
StrategyMetrics(
strategy_id="momentum",
strategy_name="Momentum",
total_return=Decimal("0.28"),
annualized_return=Decimal("0.22"),
volatility=Decimal("0.18"),
sharpe_ratio=Decimal("1.55"),
sortino_ratio=Decimal("2.10"),
max_drawdown=Decimal("-0.14"),
calmar_ratio=Decimal("1.57"),
win_rate=Decimal("0.52"),
profit_factor=Decimal("1.45"),
total_trades=150,
avg_trade_return=Decimal("0.0018"),
),
StrategyMetrics(
strategy_id="value",
strategy_name="Value",
total_return=Decimal("0.18"),
annualized_return=Decimal("0.15"),
volatility=Decimal("0.12"),
sharpe_ratio=Decimal("1.25"),
sortino_ratio=Decimal("1.60"),
max_drawdown=Decimal("-0.10"),
calmar_ratio=Decimal("1.50"),
win_rate=Decimal("0.58"),
profit_factor=Decimal("1.65"),
total_trades=80,
avg_trade_return=Decimal("0.0022"),
),
StrategyMetrics(
strategy_id="growth",
strategy_name="Growth",
total_return=Decimal("0.35"),
annualized_return=Decimal("0.28"),
volatility=Decimal("0.25"),
sharpe_ratio=Decimal("1.40"),
sortino_ratio=Decimal("1.80"),
max_drawdown=Decimal("-0.22"),
calmar_ratio=Decimal("1.27"),
win_rate=Decimal("0.48"),
profit_factor=Decimal("1.35"),
total_trades=200,
avg_trade_return=Decimal("0.0017"),
),
]
for s in strategies:
comparator.add_strategy(s)
# Compare by Sharpe ratio
result = comparator.compare(primary_criteria=RankingCriteria.SHARPE_RATIO)
# Verify result
assert result.status == ComparisonStatus.VALID
assert result.strategy_count == 3
assert result.best_overall == "momentum" # Highest Sharpe
# Verify rankings
assert "momentum" in result.rankings[RankingCriteria.SHARPE_RATIO][:1]
assert "growth" in result.rankings[RankingCriteria.TOTAL_RETURN][:1]
# Verify pairwise comparisons
assert len(result.pairwise_comparisons) == 3 # 3 choose 2
# Verify summary statistics
assert result.summary_statistics["strategy_count"] == 3
# Get ranking table
table = comparator.get_ranking_table()
assert len(table) == 3

View File

@ -7,23 +7,31 @@ This module provides simulation capabilities including:
- Economic regime simulation
Issue #33: [SIM-32] Scenario runner - parallel portfolio simulations
Issue #34: [SIM-33] Strategy comparator - performance comparison, stats
Submodules:
scenario_runner: Core scenario execution framework
strategy_comparator: Strategy comparison and statistical analysis
Classes:
Enums:
- ExecutionMode: Parallel execution mode (sequential, threaded, process)
- ScenarioStatus: Status of a scenario run
- RankingCriteria: Criteria for ranking strategies
- ComparisonStatus: Status of strategy comparison
Data Classes:
- ScenarioConfig: Configuration for a simulation scenario
- ScenarioResult: Result from a scenario simulation
- RunnerProgress: Progress information for batch runs
- StrategyMetrics: Performance metrics for a strategy
- PairwiseComparison: Comparison between two strategies
- ComparisonResult: Complete result of strategy comparison
Main Classes:
- ScenarioRunner: Runner for parallel portfolio simulations
- ScenarioBatchBuilder: Builder for creating scenario batches
- StrategyComparator: Compares multiple trading strategies
Protocols:
- ScenarioExecutor: Protocol for scenario execution functions
@ -76,21 +84,42 @@ from .scenario_runner import (
aggregate_results,
)
__all__ = [
from .strategy_comparator import (
# Enums
RankingCriteria,
ComparisonStatus,
# Data Classes
StrategyMetrics,
PairwiseComparison,
ComparisonResult,
# Main Class
StrategyComparator,
)
__all__ = [
# Scenario Runner Enums
"ExecutionMode",
"ScenarioStatus",
# Data Classes
# Scenario Runner Data Classes
"ScenarioConfig",
"ScenarioResult",
"RunnerProgress",
# Main Classes
# Scenario Runner Main Classes
"ScenarioRunner",
"ScenarioBatchBuilder",
# Protocols
# Scenario Runner Protocols
"ScenarioExecutor",
# Types
# Scenario Runner Types
"ProgressCallback",
# Utility Functions
# Scenario Runner Utility Functions
"aggregate_results",
# Strategy Comparator Enums
"RankingCriteria",
"ComparisonStatus",
# Strategy Comparator Data Classes
"StrategyMetrics",
"PairwiseComparison",
"ComparisonResult",
# Strategy Comparator Main Class
"StrategyComparator",
]

View File

@ -0,0 +1,737 @@
"""Strategy Comparator for performance comparison and statistical analysis.
This module provides strategy comparison capabilities including:
- Performance metrics comparison across strategies
- Statistical significance testing
- Ranking and scoring
- Visualization data preparation
Issue #34: [SIM-33] Strategy comparator - performance comparison, stats
Design Principles:
- Comprehensive performance metrics
- Statistical rigor (hypothesis testing)
- Flexible ranking criteria
- Clear comparison outputs
"""
from dataclasses import dataclass, field
from datetime import date
from decimal import Decimal, ROUND_HALF_UP
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
import math
import statistics
class RankingCriteria(Enum):
"""Criteria for ranking strategies."""
TOTAL_RETURN = "total_return"
SHARPE_RATIO = "sharpe_ratio"
SORTINO_RATIO = "sortino_ratio"
MAX_DRAWDOWN = "max_drawdown"
WIN_RATE = "win_rate"
PROFIT_FACTOR = "profit_factor"
CALMAR_RATIO = "calmar_ratio"
RISK_ADJUSTED_RETURN = "risk_adjusted_return"
class ComparisonStatus(Enum):
"""Status of strategy comparison."""
VALID = "valid"
INSUFFICIENT_DATA = "insufficient_data"
INCOMPARABLE = "incomparable"
@dataclass
class StrategyMetrics:
"""Performance metrics for a strategy.
Attributes:
strategy_id: Unique identifier
strategy_name: Human-readable name
start_date: First date of performance data
end_date: Last date of performance data
total_return: Total return as decimal (0.10 = 10%)
annualized_return: Annualized return
volatility: Annualized standard deviation of returns
sharpe_ratio: Sharpe ratio (risk-free rate assumed 0)
sortino_ratio: Sortino ratio (downside deviation)
max_drawdown: Maximum drawdown (negative)
calmar_ratio: Annualized return / max drawdown
win_rate: Percentage of winning trades
profit_factor: Gross profit / gross loss
total_trades: Number of trades executed
avg_trade_return: Average return per trade
avg_win: Average winning trade
avg_loss: Average losing trade
best_trade: Best single trade return
worst_trade: Worst single trade return
returns_series: Time series of periodic returns
equity_curve: Time series of equity values
metadata: Additional strategy data
"""
strategy_id: str
strategy_name: str
start_date: Optional[date] = None
end_date: Optional[date] = None
total_return: Decimal = Decimal("0")
annualized_return: Decimal = Decimal("0")
volatility: Decimal = Decimal("0")
sharpe_ratio: Optional[Decimal] = None
sortino_ratio: Optional[Decimal] = None
max_drawdown: Decimal = Decimal("0")
calmar_ratio: Optional[Decimal] = None
win_rate: Decimal = Decimal("0")
profit_factor: Optional[Decimal] = None
total_trades: int = 0
avg_trade_return: Decimal = Decimal("0")
avg_win: Decimal = Decimal("0")
avg_loss: Decimal = Decimal("0")
best_trade: Decimal = Decimal("0")
worst_trade: Decimal = Decimal("0")
returns_series: List[Decimal] = field(default_factory=list)
equity_curve: List[Tuple[date, Decimal]] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
@property
def risk_adjusted_return(self) -> Decimal:
"""Calculate risk-adjusted return (return / volatility)."""
if self.volatility == 0:
return Decimal("0")
return (self.total_return / self.volatility).quantize(
Decimal("0.0001"), rounding=ROUND_HALF_UP
)
@property
def has_sufficient_data(self) -> bool:
"""Check if strategy has sufficient data for comparison."""
return self.total_trades >= 10 or len(self.returns_series) >= 30
@dataclass
class PairwiseComparison:
"""Comparison between two strategies.
Attributes:
strategy_a_id: First strategy ID
strategy_b_id: Second strategy ID
winner: ID of the winning strategy (or None if tie)
return_difference: Difference in total returns (A - B)
sharpe_difference: Difference in Sharpe ratios
volatility_difference: Difference in volatility
statistically_significant: Whether difference is significant
p_value: P-value from statistical test
confidence_interval: 95% CI for return difference
notes: Additional comparison notes
"""
strategy_a_id: str
strategy_b_id: str
winner: Optional[str] = None
return_difference: Decimal = Decimal("0")
sharpe_difference: Optional[Decimal] = None
volatility_difference: Decimal = Decimal("0")
statistically_significant: bool = False
p_value: Optional[float] = None
confidence_interval: Optional[Tuple[float, float]] = None
notes: str = ""
@dataclass
class ComparisonResult:
"""Complete result of strategy comparison.
Attributes:
status: Comparison status
strategies: List of strategies compared
rankings: Strategies ranked by criteria
best_overall: Best overall strategy ID
worst_overall: Worst overall strategy ID
pairwise_comparisons: Pairwise comparison results
summary_statistics: Summary statistics across all strategies
recommendations: Analysis recommendations
metadata: Additional result data
"""
status: ComparisonStatus
strategies: List[StrategyMetrics] = field(default_factory=list)
rankings: Dict[RankingCriteria, List[str]] = field(default_factory=dict)
best_overall: Optional[str] = None
worst_overall: Optional[str] = None
pairwise_comparisons: List[PairwiseComparison] = field(default_factory=list)
summary_statistics: Dict[str, Any] = field(default_factory=dict)
recommendations: List[str] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
@property
def strategy_count(self) -> int:
"""Number of strategies compared."""
return len(self.strategies)
class StrategyComparator:
"""Compares multiple trading strategies.
Provides comprehensive comparison of strategy performance including
statistical tests, rankings, and recommendations.
Example:
>>> comparator = StrategyComparator()
>>> comparator.add_strategy(StrategyMetrics(
... strategy_id="strat1",
... strategy_name="Momentum",
... total_return=Decimal("0.25"),
... sharpe_ratio=Decimal("1.5"),
... ))
>>> comparator.add_strategy(StrategyMetrics(
... strategy_id="strat2",
... strategy_name="Value",
... total_return=Decimal("0.18"),
... sharpe_ratio=Decimal("1.2"),
... ))
>>> result = comparator.compare()
>>> print(f"Best: {result.best_overall}")
"""
def __init__(
self,
risk_free_rate: Decimal = Decimal("0"),
min_data_points: int = 30,
significance_level: float = 0.05,
):
"""Initialize the comparator.
Args:
risk_free_rate: Risk-free rate for Sharpe calculations
min_data_points: Minimum data points for valid comparison
significance_level: Significance level for statistical tests
"""
self.risk_free_rate = risk_free_rate
self.min_data_points = min_data_points
self.significance_level = significance_level
self._strategies: Dict[str, StrategyMetrics] = {}
def add_strategy(self, strategy: StrategyMetrics) -> None:
"""Add a strategy for comparison.
Args:
strategy: Strategy metrics to add
"""
self._strategies[strategy.strategy_id] = strategy
def remove_strategy(self, strategy_id: str) -> bool:
"""Remove a strategy from comparison.
Args:
strategy_id: ID of strategy to remove
Returns:
True if removed, False if not found
"""
if strategy_id in self._strategies:
del self._strategies[strategy_id]
return True
return False
def get_strategy(self, strategy_id: str) -> Optional[StrategyMetrics]:
"""Get a strategy by ID.
Args:
strategy_id: Strategy ID
Returns:
Strategy metrics or None if not found
"""
return self._strategies.get(strategy_id)
def get_all_strategies(self) -> List[StrategyMetrics]:
"""Get all strategies.
Returns:
List of all strategy metrics
"""
return list(self._strategies.values())
def clear(self) -> None:
"""Remove all strategies."""
self._strategies.clear()
def _rank_by_criteria(
self, criteria: RankingCriteria
) -> List[str]:
"""Rank strategies by a specific criterion.
Args:
criteria: Ranking criterion
Returns:
List of strategy IDs in ranked order (best first)
"""
strategies = list(self._strategies.values())
if criteria == RankingCriteria.TOTAL_RETURN:
key = lambda s: float(s.total_return)
reverse = True
elif criteria == RankingCriteria.SHARPE_RATIO:
key = lambda s: float(s.sharpe_ratio or Decimal("-999"))
reverse = True
elif criteria == RankingCriteria.SORTINO_RATIO:
key = lambda s: float(s.sortino_ratio or Decimal("-999"))
reverse = True
elif criteria == RankingCriteria.MAX_DRAWDOWN:
# Less negative is better
key = lambda s: float(s.max_drawdown)
reverse = True
elif criteria == RankingCriteria.WIN_RATE:
key = lambda s: float(s.win_rate)
reverse = True
elif criteria == RankingCriteria.PROFIT_FACTOR:
key = lambda s: float(s.profit_factor or Decimal("0"))
reverse = True
elif criteria == RankingCriteria.CALMAR_RATIO:
key = lambda s: float(s.calmar_ratio or Decimal("-999"))
reverse = True
elif criteria == RankingCriteria.RISK_ADJUSTED_RETURN:
key = lambda s: float(s.risk_adjusted_return)
reverse = True
else:
key = lambda s: float(s.total_return)
reverse = True
sorted_strategies = sorted(strategies, key=key, reverse=reverse)
return [s.strategy_id for s in sorted_strategies]
def _calculate_pairwise_comparison(
self,
strategy_a: StrategyMetrics,
strategy_b: StrategyMetrics,
) -> PairwiseComparison:
"""Calculate pairwise comparison between two strategies.
Args:
strategy_a: First strategy
strategy_b: Second strategy
Returns:
Pairwise comparison result
"""
return_diff = strategy_a.total_return - strategy_b.total_return
sharpe_diff = None
if strategy_a.sharpe_ratio is not None and strategy_b.sharpe_ratio is not None:
sharpe_diff = strategy_a.sharpe_ratio - strategy_b.sharpe_ratio
vol_diff = strategy_a.volatility - strategy_b.volatility
# Determine winner based on Sharpe ratio (or return if no Sharpe)
winner = None
if sharpe_diff is not None:
if sharpe_diff > Decimal("0.1"): # Meaningful difference
winner = strategy_a.strategy_id
elif sharpe_diff < Decimal("-0.1"):
winner = strategy_b.strategy_id
elif return_diff > Decimal("0.05"):
winner = strategy_a.strategy_id
elif return_diff < Decimal("-0.05"):
winner = strategy_b.strategy_id
# Statistical significance test
significant = False
p_value = None
ci = None
# Perform t-test if we have return series
if (len(strategy_a.returns_series) >= self.min_data_points and
len(strategy_b.returns_series) >= self.min_data_points):
try:
t_stat, p_value, ci = self._welch_t_test(
[float(r) for r in strategy_a.returns_series],
[float(r) for r in strategy_b.returns_series],
)
significant = p_value < self.significance_level
except Exception:
pass
notes = ""
if significant:
notes = f"Statistically significant difference (p={p_value:.4f})"
elif p_value is not None:
notes = f"Not statistically significant (p={p_value:.4f})"
return PairwiseComparison(
strategy_a_id=strategy_a.strategy_id,
strategy_b_id=strategy_b.strategy_id,
winner=winner,
return_difference=return_diff.quantize(Decimal("0.0001")),
sharpe_difference=sharpe_diff.quantize(Decimal("0.0001")) if sharpe_diff else None,
volatility_difference=vol_diff.quantize(Decimal("0.0001")),
statistically_significant=significant,
p_value=p_value,
confidence_interval=ci,
notes=notes,
)
def _welch_t_test(
self,
sample_a: List[float],
sample_b: List[float],
) -> Tuple[float, float, Tuple[float, float]]:
"""Perform Welch's t-test for unequal variances.
Args:
sample_a: First sample
sample_b: Second sample
Returns:
Tuple of (t-statistic, p-value, 95% CI)
"""
n_a = len(sample_a)
n_b = len(sample_b)
mean_a = statistics.mean(sample_a)
mean_b = statistics.mean(sample_b)
var_a = statistics.variance(sample_a) if n_a > 1 else 0
var_b = statistics.variance(sample_b) if n_b > 1 else 0
# Pooled standard error
se = math.sqrt(var_a / n_a + var_b / n_b) if (var_a + var_b) > 0 else 0.0001
# t-statistic
t_stat = (mean_a - mean_b) / se
# Degrees of freedom (Welch-Satterthwaite)
if var_a > 0 or var_b > 0:
num = (var_a / n_a + var_b / n_b) ** 2
denom = (
(var_a / n_a) ** 2 / (n_a - 1) +
(var_b / n_b) ** 2 / (n_b - 1)
)
df = num / denom if denom > 0 else n_a + n_b - 2
else:
df = n_a + n_b - 2
# Approximate p-value using normal distribution (good for large samples)
# For more accuracy, would use t-distribution
p_value = 2 * (1 - self._normal_cdf(abs(t_stat)))
# 95% confidence interval
z = 1.96 # For 95% CI
ci_lower = (mean_a - mean_b) - z * se
ci_upper = (mean_a - mean_b) + z * se
return t_stat, p_value, (ci_lower, ci_upper)
@staticmethod
def _normal_cdf(x: float) -> float:
"""Approximate standard normal CDF.
Args:
x: Value
Returns:
Cumulative probability
"""
return (1 + math.erf(x / math.sqrt(2))) / 2
def _calculate_summary_statistics(
self, strategies: List[StrategyMetrics]
) -> Dict[str, Any]:
"""Calculate summary statistics across all strategies.
Args:
strategies: List of strategies
Returns:
Dictionary of summary statistics
"""
if not strategies:
return {}
returns = [float(s.total_return) for s in strategies]
vols = [float(s.volatility) for s in strategies]
sharpes = [
float(s.sharpe_ratio) for s in strategies
if s.sharpe_ratio is not None
]
drawdowns = [float(s.max_drawdown) for s in strategies]
summary = {
"strategy_count": len(strategies),
"return": {
"mean": statistics.mean(returns),
"median": statistics.median(returns),
"min": min(returns),
"max": max(returns),
"stdev": statistics.stdev(returns) if len(returns) > 1 else 0,
},
"volatility": {
"mean": statistics.mean(vols),
"median": statistics.median(vols),
"min": min(vols),
"max": max(vols),
},
"max_drawdown": {
"mean": statistics.mean(drawdowns),
"worst": min(drawdowns),
"best": max(drawdowns),
},
}
if sharpes:
summary["sharpe_ratio"] = {
"mean": statistics.mean(sharpes),
"median": statistics.median(sharpes),
"min": min(sharpes),
"max": max(sharpes),
}
return summary
def _generate_recommendations(
self,
strategies: List[StrategyMetrics],
rankings: Dict[RankingCriteria, List[str]],
) -> List[str]:
"""Generate analysis recommendations.
Args:
strategies: List of strategies
rankings: Rankings by criteria
Returns:
List of recommendation strings
"""
recommendations = []
if len(strategies) < 2:
recommendations.append(
"Add more strategies for meaningful comparison."
)
return recommendations
# Consistency check
best_return = rankings[RankingCriteria.TOTAL_RETURN][0]
best_sharpe = rankings.get(RankingCriteria.SHARPE_RATIO, [None])[0]
if best_return != best_sharpe and best_sharpe:
recommendations.append(
f"'{self._strategies[best_return].strategy_name}' has highest "
f"return but '{self._strategies[best_sharpe].strategy_name}' "
"has best risk-adjusted performance."
)
# Volatility warning
for s in strategies:
if float(s.volatility) > 0.30: # 30% annual volatility
recommendations.append(
f"'{s.strategy_name}' has high volatility "
f"({float(s.volatility)*100:.1f}%). Consider risk reduction."
)
# Drawdown warning
for s in strategies:
if float(s.max_drawdown) < -0.25: # 25% drawdown
recommendations.append(
f"'{s.strategy_name}' experienced significant drawdown "
f"({float(s.max_drawdown)*100:.1f}%). Review risk management."
)
# Insufficient trades warning
for s in strategies:
if s.total_trades < 30:
recommendations.append(
f"'{s.strategy_name}' has limited trade history "
f"({s.total_trades} trades). Results may not be reliable."
)
return recommendations
def compare(
self,
primary_criteria: RankingCriteria = RankingCriteria.SHARPE_RATIO,
) -> ComparisonResult:
"""Compare all added strategies.
Args:
primary_criteria: Primary criterion for determining best/worst
Returns:
Complete comparison result
"""
strategies = list(self._strategies.values())
if len(strategies) == 0:
return ComparisonResult(
status=ComparisonStatus.INSUFFICIENT_DATA,
recommendations=["No strategies to compare."],
)
if len(strategies) == 1:
return ComparisonResult(
status=ComparisonStatus.INSUFFICIENT_DATA,
strategies=strategies,
best_overall=strategies[0].strategy_id,
worst_overall=strategies[0].strategy_id,
recommendations=["Only one strategy provided. Add more for comparison."],
)
# Calculate rankings for each criterion
rankings = {}
for criteria in RankingCriteria:
rankings[criteria] = self._rank_by_criteria(criteria)
# Determine best and worst overall
primary_ranking = rankings[primary_criteria]
best_overall = primary_ranking[0] if primary_ranking else None
worst_overall = primary_ranking[-1] if primary_ranking else None
# Calculate pairwise comparisons
pairwise = []
for i, s_a in enumerate(strategies):
for s_b in strategies[i+1:]:
comparison = self._calculate_pairwise_comparison(s_a, s_b)
pairwise.append(comparison)
# Calculate summary statistics
summary = self._calculate_summary_statistics(strategies)
# Generate recommendations
recommendations = self._generate_recommendations(strategies, rankings)
return ComparisonResult(
status=ComparisonStatus.VALID,
strategies=strategies,
rankings=rankings,
best_overall=best_overall,
worst_overall=worst_overall,
pairwise_comparisons=pairwise,
summary_statistics=summary,
recommendations=recommendations,
)
def compare_returns(
self,
strategy_ids: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Compare return distributions.
Args:
strategy_ids: Specific strategies to compare (None = all)
Returns:
Return comparison data
"""
if strategy_ids:
strategies = [
self._strategies[sid] for sid in strategy_ids
if sid in self._strategies
]
else:
strategies = list(self._strategies.values())
comparison = {
"strategies": [s.strategy_name for s in strategies],
"returns": {
s.strategy_name: {
"total": str(s.total_return),
"annualized": str(s.annualized_return),
"best_trade": str(s.best_trade),
"worst_trade": str(s.worst_trade),
"avg_trade": str(s.avg_trade_return),
}
for s in strategies
},
}
# Return series statistics if available
for s in strategies:
if s.returns_series:
returns = [float(r) for r in s.returns_series]
comparison["returns"][s.strategy_name]["series_stats"] = {
"count": len(returns),
"mean": statistics.mean(returns),
"median": statistics.median(returns),
"stdev": statistics.stdev(returns) if len(returns) > 1 else 0,
"skew": self._calculate_skew(returns),
"kurtosis": self._calculate_kurtosis(returns),
}
return comparison
@staticmethod
def _calculate_skew(data: List[float]) -> float:
"""Calculate skewness of a distribution.
Args:
data: List of values
Returns:
Skewness value
"""
if len(data) < 3:
return 0.0
n = len(data)
mean = statistics.mean(data)
std = statistics.stdev(data)
if std == 0:
return 0.0
return sum((x - mean) ** 3 for x in data) / ((n - 1) * std ** 3)
@staticmethod
def _calculate_kurtosis(data: List[float]) -> float:
"""Calculate excess kurtosis of a distribution.
Args:
data: List of values
Returns:
Excess kurtosis value
"""
if len(data) < 4:
return 0.0
n = len(data)
mean = statistics.mean(data)
std = statistics.stdev(data)
if std == 0:
return 0.0
return sum((x - mean) ** 4 for x in data) / ((n - 1) * std ** 4) - 3
def get_ranking_table(self) -> List[Dict[str, Any]]:
"""Generate a ranking table for all strategies.
Returns:
List of strategy data with rankings
"""
if not self._strategies:
return []
# Get all rankings
rankings = {}
for criteria in RankingCriteria:
for rank, sid in enumerate(self._rank_by_criteria(criteria), 1):
if sid not in rankings:
rankings[sid] = {}
rankings[sid][criteria.value] = rank
# Build table
table = []
for sid, strategy in self._strategies.items():
row = {
"strategy_id": sid,
"strategy_name": strategy.strategy_name,
"total_return": str(strategy.total_return),
"sharpe_ratio": str(strategy.sharpe_ratio) if strategy.sharpe_ratio else "N/A",
"max_drawdown": str(strategy.max_drawdown),
"win_rate": str(strategy.win_rate),
"rankings": rankings.get(sid, {}),
}
# Calculate average rank
ranks = list(rankings.get(sid, {}).values())
row["avg_rank"] = sum(ranks) / len(ranks) if ranks else 0
table.append(row)
# Sort by average rank
table.sort(key=lambda x: x["avg_rank"])
return table