feat(simulation): add Scenario Runner for parallel portfolio simulations - Issue #33 (45 tests)

Implements parallel scenario execution framework:
- ScenarioRunner with sequential, threaded, and process execution modes
- ScenarioConfig for configuring simulation parameters
- ScenarioResult for capturing simulation outcomes
- RunnerProgress for tracking execution progress
- Progress callbacks for real-time updates
- Cancellation support for long-running batches
- ScenarioBatchBuilder for parameter sweeps and variations
- Result aggregation with best/worst scenario identification

Features:
- Thread-safe parallel execution with configurable worker count
- FIFO result ordering preserved regardless of completion order
- Exception handling with graceful degradation
- Timeout support per scenario
- Cartesian product generation for parameter variations

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Andrew Kaszubski 2025-12-26 21:59:12 +11:00
parent 13f2bba0b3
commit e7bff2c4cf
4 changed files with 1621 additions and 0 deletions

View File

@ -0,0 +1 @@
"""Unit tests for the simulation module."""

View File

@ -0,0 +1,835 @@
"""Tests for Scenario Runner.
Issue #33: [SIM-32] Scenario runner - parallel portfolio simulations
Tests cover:
- ExecutionMode and ScenarioStatus enums
- ScenarioConfig and ScenarioResult dataclasses
- ScenarioRunner sequential execution
- ScenarioRunner parallel execution
- Progress tracking and callbacks
- ScenarioBatchBuilder variations
- Result aggregation
- Error handling
- Cancellation
"""
import pytest
import time
from datetime import datetime, date
from decimal import Decimal
from unittest.mock import Mock, patch
import threading
from tradingagents.simulation.scenario_runner import (
ExecutionMode,
ScenarioStatus,
ScenarioConfig,
ScenarioResult,
RunnerProgress,
ScenarioRunner,
ScenarioBatchBuilder,
aggregate_results,
)
# ==============================================================================
# ExecutionMode Enum Tests
# ==============================================================================
class TestExecutionMode:
"""Tests for ExecutionMode enum."""
def test_sequential_value(self):
"""Test SEQUENTIAL mode value."""
assert ExecutionMode.SEQUENTIAL.value == "sequential"
def test_threaded_value(self):
"""Test THREADED mode value."""
assert ExecutionMode.THREADED.value == "threaded"
def test_process_value(self):
"""Test PROCESS mode value."""
assert ExecutionMode.PROCESS.value == "process"
def test_all_modes_exist(self):
"""Test all expected modes exist."""
modes = [m for m in ExecutionMode]
assert len(modes) == 3
# ==============================================================================
# ScenarioStatus Enum Tests
# ==============================================================================
class TestScenarioStatus:
"""Tests for ScenarioStatus enum."""
def test_pending_value(self):
"""Test PENDING status value."""
assert ScenarioStatus.PENDING.value == "pending"
def test_running_value(self):
"""Test RUNNING status value."""
assert ScenarioStatus.RUNNING.value == "running"
def test_completed_value(self):
"""Test COMPLETED status value."""
assert ScenarioStatus.COMPLETED.value == "completed"
def test_failed_value(self):
"""Test FAILED status value."""
assert ScenarioStatus.FAILED.value == "failed"
def test_cancelled_value(self):
"""Test CANCELLED status value."""
assert ScenarioStatus.CANCELLED.value == "cancelled"
def test_all_statuses_exist(self):
"""Test all expected statuses exist."""
statuses = [s for s in ScenarioStatus]
assert len(statuses) == 5
# ==============================================================================
# ScenarioConfig Tests
# ==============================================================================
class TestScenarioConfig:
"""Tests for ScenarioConfig dataclass."""
def test_default_config(self):
"""Test creating config with defaults."""
config = ScenarioConfig()
assert config.scenario_id is not None
assert config.name.startswith("Scenario-")
assert config.initial_capital == Decimal("100000")
assert config.symbols == []
assert config.strategy_params == {}
assert config.risk_params == {}
def test_custom_config(self):
"""Test creating config with custom values."""
config = ScenarioConfig(
name="Bull Market Test",
start_date=date(2023, 1, 1),
end_date=date(2023, 12, 31),
initial_capital=Decimal("50000"),
symbols=["AAPL", "GOOGL"],
strategy_params={"leverage": 1.5},
risk_params={"max_drawdown": 0.2},
)
assert config.name == "Bull Market Test"
assert config.start_date == date(2023, 1, 1)
assert config.end_date == date(2023, 12, 31)
assert config.initial_capital == Decimal("50000")
assert config.symbols == ["AAPL", "GOOGL"]
assert config.strategy_params["leverage"] == 1.5
assert config.risk_params["max_drawdown"] == 0.2
def test_auto_generated_name(self):
"""Test auto-generated name from scenario_id."""
config = ScenarioConfig()
# Name should start with "Scenario-" and contain part of the ID
assert config.name.startswith("Scenario-")
assert len(config.name) > 8
def test_explicit_scenario_id(self):
"""Test explicit scenario ID."""
config = ScenarioConfig(scenario_id="test-123")
assert config.scenario_id == "test-123"
# ==============================================================================
# ScenarioResult Tests
# ==============================================================================
class TestScenarioResult:
"""Tests for ScenarioResult dataclass."""
def test_successful_result(self):
"""Test creating a successful result."""
result = ScenarioResult(
scenario_id="test-1",
scenario_name="Test",
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
final_value=Decimal("110000"),
total_return=Decimal("0.10"),
)
assert result.is_successful is True
assert result.is_finished is True
def test_failed_result(self):
"""Test creating a failed result."""
result = ScenarioResult(
scenario_id="test-1",
scenario_name="Test",
status=ScenarioStatus.FAILED,
start_time=datetime.now(),
error_message="Simulation error",
)
assert result.is_successful is False
assert result.is_finished is True
assert result.error_message == "Simulation error"
def test_pending_result(self):
"""Test pending result properties."""
result = ScenarioResult(
scenario_id="test-1",
scenario_name="Test",
status=ScenarioStatus.PENDING,
start_time=datetime.now(),
)
assert result.is_successful is False
assert result.is_finished is False
def test_running_result(self):
"""Test running result properties."""
result = ScenarioResult(
scenario_id="test-1",
scenario_name="Test",
status=ScenarioStatus.RUNNING,
start_time=datetime.now(),
)
assert result.is_successful is False
assert result.is_finished is False
def test_cancelled_result(self):
"""Test cancelled result properties."""
result = ScenarioResult(
scenario_id="test-1",
scenario_name="Test",
status=ScenarioStatus.CANCELLED,
start_time=datetime.now(),
)
assert result.is_successful is False
assert result.is_finished is True
# ==============================================================================
# RunnerProgress Tests
# ==============================================================================
class TestRunnerProgress:
"""Tests for RunnerProgress dataclass."""
def test_progress_percent(self):
"""Test progress percentage calculation."""
progress = RunnerProgress(
total_scenarios=10,
completed=3,
failed=2,
)
assert progress.progress_percent == 50.0
def test_progress_percent_zero_total(self):
"""Test progress with zero total scenarios."""
progress = RunnerProgress(total_scenarios=0)
assert progress.progress_percent == 100.0
def test_is_complete(self):
"""Test completion detection."""
progress = RunnerProgress(
total_scenarios=5,
completed=3,
failed=2,
)
assert progress.is_complete is True
def test_not_complete(self):
"""Test incomplete detection."""
progress = RunnerProgress(
total_scenarios=5,
completed=2,
failed=1,
)
assert progress.is_complete is False
# ==============================================================================
# ScenarioRunner Tests - Sequential Execution
# ==============================================================================
class TestScenarioRunnerSequential:
"""Tests for ScenarioRunner sequential execution."""
def test_run_empty_list(self):
"""Test running with empty scenario list."""
def executor(config):
return ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
)
runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL)
results = runner.run([])
assert results == []
def test_run_single_scenario(self):
"""Test running a single scenario."""
def executor(config):
return ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
final_value=Decimal("110000"),
total_return=Decimal("0.10"),
)
runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL)
scenarios = [ScenarioConfig(name="Test1")]
results = runner.run(scenarios)
assert len(results) == 1
assert results[0].is_successful
assert results[0].scenario_name == "Test1"
def test_run_multiple_scenarios(self):
"""Test running multiple scenarios sequentially."""
call_order = []
def executor(config):
call_order.append(config.name)
return ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
)
runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL)
scenarios = [
ScenarioConfig(name="Test1"),
ScenarioConfig(name="Test2"),
ScenarioConfig(name="Test3"),
]
results = runner.run(scenarios)
assert len(results) == 3
assert call_order == ["Test1", "Test2", "Test3"]
assert all(r.is_successful for r in results)
def test_run_with_executor_exception(self):
"""Test handling executor exceptions."""
def executor(config):
if config.name == "FailMe":
raise ValueError("Simulated error")
return ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
)
runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL)
scenarios = [
ScenarioConfig(name="Test1"),
ScenarioConfig(name="FailMe"),
ScenarioConfig(name="Test3"),
]
results = runner.run(scenarios)
assert len(results) == 3
assert results[0].is_successful
assert results[1].status == ScenarioStatus.FAILED
assert "Simulated error" in results[1].error_message
assert results[2].is_successful
# ==============================================================================
# ScenarioRunner Tests - Parallel Execution
# ==============================================================================
class TestScenarioRunnerParallel:
"""Tests for ScenarioRunner parallel execution."""
def test_run_threaded(self):
"""Test running scenarios in parallel using threads."""
def executor(config):
time.sleep(0.01) # Small delay to ensure parallelism
return ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
)
runner = ScenarioRunner(
executor=executor,
mode=ExecutionMode.THREADED,
max_workers=4,
)
scenarios = [ScenarioConfig(name=f"Test{i}") for i in range(8)]
start = time.time()
results = runner.run(scenarios)
elapsed = time.time() - start
assert len(results) == 8
assert all(r.is_successful for r in results)
# With 4 workers and 8 scenarios at 0.01s each,
# parallel should be faster than sequential (0.08s)
# Allow some overhead
assert elapsed < 0.1
def test_results_order_preserved(self):
"""Test that results match input order."""
import random
def executor(config):
# Random delay to encourage out-of-order completion
time.sleep(random.uniform(0.001, 0.01))
return ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
final_value=Decimal(config.metadata.get("index", 0)),
)
runner = ScenarioRunner(
executor=executor,
mode=ExecutionMode.THREADED,
max_workers=4,
)
scenarios = [
ScenarioConfig(name=f"Test{i}", metadata={"index": i})
for i in range(10)
]
results = runner.run(scenarios)
assert len(results) == 10
for i, result in enumerate(results):
assert result.scenario_name == f"Test{i}"
assert result.final_value == Decimal(i)
# ==============================================================================
# ScenarioRunner Tests - Progress Tracking
# ==============================================================================
class TestScenarioRunnerProgress:
"""Tests for ScenarioRunner progress tracking."""
def test_progress_callback(self):
"""Test progress callback invocation."""
progress_updates = []
def on_progress(progress):
progress_updates.append({
"completed": progress.completed,
"total": progress.total_scenarios,
})
def executor(config):
return ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
)
runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL)
scenarios = [ScenarioConfig(name=f"Test{i}") for i in range(3)]
runner.run(scenarios, progress_callback=on_progress)
# Should have received progress updates
assert len(progress_updates) > 0
# Final update should show all completed
final = progress_updates[-1]
assert final["completed"] == 3
def test_get_progress(self):
"""Test get_progress method."""
def executor(config):
time.sleep(0.01)
return ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
)
runner = ScenarioRunner(
executor=executor,
mode=ExecutionMode.THREADED,
max_workers=2,
)
scenarios = [ScenarioConfig(name=f"Test{i}") for i in range(4)]
# Start in background
results = []
def run_in_thread():
nonlocal results
results = runner.run(scenarios)
thread = threading.Thread(target=run_in_thread)
thread.start()
# Give it time to start
time.sleep(0.005)
progress = runner.get_progress()
assert progress.total_scenarios == 4
thread.join()
assert len(results) == 4
# ==============================================================================
# ScenarioRunner Tests - Cancellation
# ==============================================================================
class TestScenarioRunnerCancellation:
"""Tests for ScenarioRunner cancellation."""
def test_cancel_pending_scenarios(self):
"""Test cancelling pending scenarios."""
executed = []
def executor(config):
executed.append(config.name)
time.sleep(0.02)
return ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
)
runner = ScenarioRunner(
executor=executor,
mode=ExecutionMode.SEQUENTIAL,
)
scenarios = [ScenarioConfig(name=f"Test{i}") for i in range(5)]
# Cancel after short delay
def cancel_after_delay():
time.sleep(0.03)
runner.cancel()
cancel_thread = threading.Thread(target=cancel_after_delay)
cancel_thread.start()
results = runner.run(scenarios)
cancel_thread.join()
# Some should be completed, some cancelled
completed = [r for r in results if r.status == ScenarioStatus.COMPLETED]
cancelled = [r for r in results if r.status == ScenarioStatus.CANCELLED]
assert len(completed) + len(cancelled) == 5
# At least the first one should have completed
assert len(completed) >= 1
# ==============================================================================
# ScenarioBatchBuilder Tests
# ==============================================================================
class TestScenarioBatchBuilder:
"""Tests for ScenarioBatchBuilder."""
def test_empty_build(self):
"""Test building with no configuration."""
builder = ScenarioBatchBuilder()
scenarios = builder.build()
assert len(scenarios) == 1 # One default scenario
def test_base_config(self):
"""Test setting base configuration."""
builder = ScenarioBatchBuilder()
scenarios = (
builder
.with_base_config(
initial_capital=Decimal("50000"),
symbols=["AAPL", "GOOGL"],
)
.build()
)
assert len(scenarios) == 1
assert scenarios[0].initial_capital == Decimal("50000")
assert scenarios[0].symbols == ["AAPL", "GOOGL"]
def test_parameter_variation(self):
"""Test varying a single parameter."""
builder = ScenarioBatchBuilder()
scenarios = (
builder
.vary_parameter("leverage", [0.5, 1.0, 1.5])
.build()
)
assert len(scenarios) == 3
leverages = [s.strategy_params["leverage"] for s in scenarios]
assert leverages == [0.5, 1.0, 1.5]
def test_multiple_parameter_variations(self):
"""Test varying multiple parameters (Cartesian product)."""
builder = ScenarioBatchBuilder()
scenarios = (
builder
.vary_parameter("leverage", [1.0, 2.0])
.vary_parameter("stop_loss", [0.05, 0.10])
.build()
)
assert len(scenarios) == 4 # 2 x 2
combinations = [
(s.strategy_params["leverage"], s.strategy_params["stop_loss"])
for s in scenarios
]
expected = [
(1.0, 0.05),
(1.0, 0.10),
(2.0, 0.05),
(2.0, 0.10),
]
assert sorted(combinations) == sorted(expected)
def test_explicit_date_ranges(self):
"""Test setting explicit date ranges."""
builder = ScenarioBatchBuilder()
scenarios = (
builder
.with_date_ranges([
(date(2020, 1, 1), date(2020, 12, 31)),
(date(2021, 1, 1), date(2021, 12, 31)),
])
.build()
)
assert len(scenarios) == 2
assert scenarios[0].start_date == date(2020, 1, 1)
assert scenarios[1].start_date == date(2021, 1, 1)
def test_clear_builder(self):
"""Test clearing builder configuration."""
builder = ScenarioBatchBuilder()
builder.with_base_config(initial_capital=Decimal("50000"))
builder.vary_parameter("leverage", [1.0, 2.0])
builder.clear()
scenarios = builder.build()
assert len(scenarios) == 1
assert scenarios[0].initial_capital == Decimal("100000") # Default
def test_scenario_names_generated(self):
"""Test that scenario names are auto-generated."""
builder = ScenarioBatchBuilder()
scenarios = (
builder
.vary_parameter("leverage", [1.0, 2.0])
.build()
)
assert "leverage=1.0" in scenarios[0].name
assert "leverage=2.0" in scenarios[1].name
# ==============================================================================
# Result Aggregation Tests
# ==============================================================================
class TestResultAggregation:
"""Tests for result aggregation."""
def test_aggregate_empty_results(self):
"""Test aggregating empty results list."""
agg = aggregate_results([])
assert agg["total_scenarios"] == 0
assert agg["successful"] == 0
assert agg["failed"] == 0
def test_aggregate_successful_results(self):
"""Test aggregating successful results."""
results = [
ScenarioResult(
scenario_id=f"test-{i}",
scenario_name=f"Test{i}",
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
total_return=Decimal(f"0.{i+1}"),
max_drawdown=Decimal(f"-0.0{i+1}"),
duration_seconds=1.0,
)
for i in range(3)
]
agg = aggregate_results(results)
assert agg["total_scenarios"] == 3
assert agg["successful"] == 3
assert agg["failed"] == 0
assert agg["success_rate"] == 1.0
assert "avg_return" in agg
assert "min_return" in agg
assert "max_return" in agg
def test_aggregate_mixed_results(self):
"""Test aggregating mixed success/failure results."""
results = [
ScenarioResult(
scenario_id="test-1",
scenario_name="Success1",
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
total_return=Decimal("0.10"),
),
ScenarioResult(
scenario_id="test-2",
scenario_name="Failed1",
status=ScenarioStatus.FAILED,
start_time=datetime.now(),
error_message="Error",
),
ScenarioResult(
scenario_id="test-3",
scenario_name="Success2",
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
total_return=Decimal("0.20"),
),
]
agg = aggregate_results(results)
assert agg["total_scenarios"] == 3
assert agg["successful"] == 2
assert agg["failed"] == 1
assert agg["success_rate"] == pytest.approx(0.667, rel=0.01)
def test_aggregate_best_worst_scenarios(self):
"""Test best/worst scenario identification."""
results = [
ScenarioResult(
scenario_id="test-1",
scenario_name="Low",
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
total_return=Decimal("0.05"),
),
ScenarioResult(
scenario_id="test-2",
scenario_name="High",
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
total_return=Decimal("0.25"),
),
ScenarioResult(
scenario_id="test-3",
scenario_name="Mid",
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
total_return=Decimal("0.15"),
),
]
agg = aggregate_results(results)
assert agg["best_scenario"]["name"] == "High"
assert agg["worst_scenario"]["name"] == "Low"
# ==============================================================================
# Module Import Tests
# ==============================================================================
class TestModuleImports:
"""Tests for module imports."""
def test_import_from_simulation_module(self):
"""Test importing from simulation module."""
from tradingagents.simulation import (
ExecutionMode,
ScenarioStatus,
ScenarioConfig,
ScenarioResult,
RunnerProgress,
ScenarioRunner,
ScenarioBatchBuilder,
aggregate_results,
)
assert ExecutionMode is not None
assert ScenarioStatus is not None
assert ScenarioConfig is not None
assert ScenarioResult is not None
assert RunnerProgress is not None
assert ScenarioRunner is not None
assert ScenarioBatchBuilder is not None
assert aggregate_results is not None
# ==============================================================================
# Integration Tests
# ==============================================================================
class TestScenarioRunnerIntegration:
"""Integration tests for ScenarioRunner."""
def test_full_simulation_workflow(self):
"""Test complete simulation workflow."""
# Create executor that simulates trading
def trading_simulator(config: ScenarioConfig) -> ScenarioResult:
# Simulate based on strategy params
leverage = config.strategy_params.get("leverage", 1.0)
base_return = 0.08 # 8% base return
final_return = base_return * leverage
final_value = config.initial_capital * Decimal(str(1 + final_return))
return ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.COMPLETED,
start_time=datetime.now(),
end_time=datetime.now(),
final_value=final_value.quantize(Decimal("0.01")),
total_return=Decimal(str(final_return)).quantize(Decimal("0.0001")),
max_drawdown=Decimal("-0.10"),
trades_executed=25,
)
# Build scenarios with varying leverage
scenarios = (
ScenarioBatchBuilder()
.with_base_config(initial_capital=Decimal("100000"))
.vary_parameter("leverage", [0.5, 1.0, 1.5, 2.0])
.build()
)
# Run simulations
runner = ScenarioRunner(
executor=trading_simulator,
mode=ExecutionMode.THREADED,
max_workers=4,
)
progress_updates = []
results = runner.run(
scenarios,
progress_callback=lambda p: progress_updates.append(p.completed),
)
# Verify results
assert len(results) == 4
assert all(r.is_successful for r in results)
# Aggregate and analyze
agg = aggregate_results(results)
assert agg["total_scenarios"] == 4
assert agg["successful"] == 4
assert agg["best_scenario"]["name"] == "leverage=2.0"
assert agg["worst_scenario"]["name"] == "leverage=0.5"

View File

@ -0,0 +1,96 @@
"""Simulation module for portfolio simulations and backtesting.
This module provides simulation capabilities including:
- Parallel scenario execution
- Parameter sweep analysis
- Strategy comparison
- Economic regime simulation
Issue #33: [SIM-32] Scenario runner - parallel portfolio simulations
Submodules:
scenario_runner: Core scenario execution framework
Classes:
Enums:
- ExecutionMode: Parallel execution mode (sequential, threaded, process)
- ScenarioStatus: Status of a scenario run
Data Classes:
- ScenarioConfig: Configuration for a simulation scenario
- ScenarioResult: Result from a scenario simulation
- RunnerProgress: Progress information for batch runs
Main Classes:
- ScenarioRunner: Runner for parallel portfolio simulations
- ScenarioBatchBuilder: Builder for creating scenario batches
Protocols:
- ScenarioExecutor: Protocol for scenario execution functions
Utility Functions:
- aggregate_results: Aggregate results from multiple scenarios
Example:
>>> from tradingagents.simulation import (
... ScenarioRunner,
... ScenarioConfig,
... ScenarioResult,
... ScenarioStatus,
... ExecutionMode,
... )
>>> from datetime import datetime
>>> from decimal import Decimal
>>>
>>> def simple_executor(config: ScenarioConfig) -> ScenarioResult:
... return ScenarioResult(
... scenario_id=config.scenario_id,
... scenario_name=config.name,
... status=ScenarioStatus.COMPLETED,
... start_time=datetime.now(),
... final_value=config.initial_capital * Decimal("1.1"),
... total_return=Decimal("0.1"),
... )
>>>
>>> runner = ScenarioRunner(executor=simple_executor)
>>> scenarios = [ScenarioConfig(name="Test1"), ScenarioConfig(name="Test2")]
>>> results = runner.run(scenarios)
"""
from .scenario_runner import (
# Enums
ExecutionMode,
ScenarioStatus,
# Data Classes
ScenarioConfig,
ScenarioResult,
RunnerProgress,
# Main Classes
ScenarioRunner,
ScenarioBatchBuilder,
# Protocols
ScenarioExecutor,
# Types
ProgressCallback,
# Utility Functions
aggregate_results,
)
__all__ = [
# Enums
"ExecutionMode",
"ScenarioStatus",
# Data Classes
"ScenarioConfig",
"ScenarioResult",
"RunnerProgress",
# Main Classes
"ScenarioRunner",
"ScenarioBatchBuilder",
# Protocols
"ScenarioExecutor",
# Types
"ProgressCallback",
# Utility Functions
"aggregate_results",
]

View File

@ -0,0 +1,689 @@
"""Scenario Runner for parallel portfolio simulations.
This module provides scenario-based simulation capabilities including:
- Parallel execution of multiple portfolio scenarios
- Configurable simulation parameters
- Progress tracking and callbacks
- Result aggregation and comparison
Issue #33: [SIM-32] Scenario runner - parallel portfolio simulations
Design Principles:
- Thread-safe parallel execution
- Configurable concurrency limits
- Memory-efficient result handling
- Progress reporting callbacks
"""
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field
from datetime import datetime, date
from decimal import Decimal
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
import copy
import threading
import time
import uuid
class ExecutionMode(Enum):
"""Mode of parallel execution."""
SEQUENTIAL = "sequential" # Run scenarios one at a time
THREADED = "threaded" # Use thread pool (for I/O bound)
PROCESS = "process" # Use process pool (for CPU bound)
class ScenarioStatus(Enum):
"""Status of a scenario run."""
PENDING = "pending" # Not yet started
RUNNING = "running" # Currently executing
COMPLETED = "completed" # Finished successfully
FAILED = "failed" # Finished with error
CANCELLED = "cancelled" # Cancelled before completion
@dataclass
class ScenarioConfig:
"""Configuration for a simulation scenario.
Attributes:
scenario_id: Unique identifier for the scenario
name: Human-readable name
start_date: Simulation start date
end_date: Simulation end date
initial_capital: Starting capital
symbols: List of symbols to include
strategy_params: Strategy-specific parameters
risk_params: Risk management parameters
metadata: Additional scenario data
"""
scenario_id: str = field(default_factory=lambda: str(uuid.uuid4()))
name: str = ""
start_date: Optional[date] = None
end_date: Optional[date] = None
initial_capital: Decimal = Decimal("100000")
symbols: List[str] = field(default_factory=list)
strategy_params: Dict[str, Any] = field(default_factory=dict)
risk_params: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
"""Set default name if not provided."""
if not self.name:
self.name = f"Scenario-{self.scenario_id[:8]}"
@dataclass
class ScenarioResult:
"""Result from a scenario simulation.
Attributes:
scenario_id: ID of the scenario that was run
scenario_name: Name of the scenario
status: Final status of the run
start_time: When simulation started
end_time: When simulation ended
duration_seconds: Total runtime in seconds
final_value: Final portfolio value
total_return: Total return as decimal (0.10 = 10%)
trades_executed: Number of trades made
max_drawdown: Maximum drawdown experienced
sharpe_ratio: Sharpe ratio if calculable
error_message: Error message if failed
portfolio_history: Time series of portfolio values
trade_history: List of trades executed
metrics: Additional performance metrics
metadata: Additional result data
"""
scenario_id: str
scenario_name: str
status: ScenarioStatus
start_time: datetime
end_time: Optional[datetime] = None
duration_seconds: float = 0.0
final_value: Decimal = Decimal("0")
total_return: Decimal = Decimal("0")
trades_executed: int = 0
max_drawdown: Decimal = Decimal("0")
sharpe_ratio: Optional[Decimal] = None
error_message: Optional[str] = None
portfolio_history: List[Tuple[date, Decimal]] = field(default_factory=list)
trade_history: List[Dict[str, Any]] = field(default_factory=list)
metrics: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
@property
def is_successful(self) -> bool:
"""Check if scenario completed successfully."""
return self.status == ScenarioStatus.COMPLETED
@property
def is_finished(self) -> bool:
"""Check if scenario has finished (success or failure)."""
return self.status in (
ScenarioStatus.COMPLETED,
ScenarioStatus.FAILED,
ScenarioStatus.CANCELLED,
)
class ScenarioExecutor(Protocol):
"""Protocol for scenario execution functions.
Implementations should take a ScenarioConfig and return a ScenarioResult.
"""
def __call__(self, config: ScenarioConfig) -> ScenarioResult:
"""Execute a scenario and return results."""
...
@dataclass
class RunnerProgress:
"""Progress information for a scenario run batch.
Attributes:
total_scenarios: Total number of scenarios
completed: Number of completed scenarios
failed: Number of failed scenarios
running: Number of currently running scenarios
pending: Number of pending scenarios
start_time: When the batch started
estimated_remaining_seconds: Estimated time remaining
"""
total_scenarios: int
completed: int = 0
failed: int = 0
running: int = 0
pending: int = 0
start_time: Optional[datetime] = None
estimated_remaining_seconds: Optional[float] = None
@property
def progress_percent(self) -> float:
"""Calculate completion percentage."""
if self.total_scenarios == 0:
return 100.0
return (self.completed + self.failed) / self.total_scenarios * 100
@property
def is_complete(self) -> bool:
"""Check if all scenarios are finished."""
return (self.completed + self.failed) >= self.total_scenarios
ProgressCallback = Callable[[RunnerProgress], None]
class ScenarioRunner:
"""Runner for parallel portfolio simulations.
Executes multiple simulation scenarios in parallel using configurable
execution modes (sequential, threaded, or process-based).
Example:
>>> def simulate(config: ScenarioConfig) -> ScenarioResult:
... # Your simulation logic
... return ScenarioResult(
... scenario_id=config.scenario_id,
... scenario_name=config.name,
... status=ScenarioStatus.COMPLETED,
... start_time=datetime.now(),
... )
...
>>> runner = ScenarioRunner(executor=simulate)
>>> scenarios = [
... ScenarioConfig(name="Bull Market", strategy_params={"leverage": 1.5}),
... ScenarioConfig(name="Bear Market", strategy_params={"leverage": 0.5}),
... ]
>>> results = runner.run(scenarios)
>>> print(f"Completed: {len([r for r in results if r.is_successful])}")
"""
def __init__(
self,
executor: ScenarioExecutor,
mode: ExecutionMode = ExecutionMode.THREADED,
max_workers: Optional[int] = None,
timeout_seconds: Optional[float] = None,
):
"""Initialize the scenario runner.
Args:
executor: Function that executes a single scenario
mode: Execution mode (sequential, threaded, process)
max_workers: Maximum number of parallel workers (None = auto)
timeout_seconds: Timeout per scenario (None = no timeout)
"""
self.executor = executor
self.mode = mode
self.max_workers = max_workers
self.timeout_seconds = timeout_seconds
self._lock = threading.Lock()
self._cancelled = False
self._progress = RunnerProgress(total_scenarios=0)
self._progress_callbacks: List[ProgressCallback] = []
def add_progress_callback(self, callback: ProgressCallback) -> None:
"""Add a callback for progress updates.
Args:
callback: Function to call with progress updates
"""
with self._lock:
self._progress_callbacks.append(callback)
def remove_progress_callback(self, callback: ProgressCallback) -> None:
"""Remove a progress callback.
Args:
callback: Callback to remove
"""
with self._lock:
if callback in self._progress_callbacks:
self._progress_callbacks.remove(callback)
def _notify_progress(self) -> None:
"""Notify all registered progress callbacks."""
with self._lock:
callbacks = self._progress_callbacks.copy()
progress = copy.copy(self._progress)
for callback in callbacks:
try:
callback(progress)
except Exception:
pass # Don't let callback errors affect execution
def _update_progress(
self,
completed_delta: int = 0,
failed_delta: int = 0,
running_delta: int = 0,
pending_delta: int = 0,
) -> None:
"""Update progress counters thread-safely."""
with self._lock:
self._progress.completed += completed_delta
self._progress.failed += failed_delta
self._progress.running += running_delta
self._progress.pending += pending_delta
# Estimate remaining time
if self._progress.start_time and self._progress.completed > 0:
elapsed = (datetime.now() - self._progress.start_time).total_seconds()
avg_time = elapsed / self._progress.completed
remaining = self._progress.pending + self._progress.running
self._progress.estimated_remaining_seconds = avg_time * remaining
self._notify_progress()
def _execute_scenario(self, config: ScenarioConfig) -> ScenarioResult:
"""Execute a single scenario with error handling.
Args:
config: Scenario configuration
Returns:
ScenarioResult with outcome
"""
start_time = datetime.now()
# Check if cancelled
if self._cancelled:
return ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.CANCELLED,
start_time=start_time,
end_time=datetime.now(),
)
self._update_progress(running_delta=1, pending_delta=-1)
try:
result = self.executor(config)
result.end_time = datetime.now()
result.duration_seconds = (result.end_time - start_time).total_seconds()
if result.status == ScenarioStatus.COMPLETED:
self._update_progress(completed_delta=1, running_delta=-1)
else:
self._update_progress(failed_delta=1, running_delta=-1)
return result
except Exception as e:
end_time = datetime.now()
self._update_progress(failed_delta=1, running_delta=-1)
return ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.FAILED,
start_time=start_time,
end_time=end_time,
duration_seconds=(end_time - start_time).total_seconds(),
error_message=str(e),
)
def run(
self,
scenarios: List[ScenarioConfig],
progress_callback: Optional[ProgressCallback] = None,
) -> List[ScenarioResult]:
"""Run multiple scenarios.
Args:
scenarios: List of scenario configurations to run
progress_callback: Optional callback for progress updates
Returns:
List of results in the same order as input scenarios
"""
if not scenarios:
return []
# Reset state
self._cancelled = False
self._progress = RunnerProgress(
total_scenarios=len(scenarios),
pending=len(scenarios),
start_time=datetime.now(),
)
if progress_callback:
self.add_progress_callback(progress_callback)
try:
if self.mode == ExecutionMode.SEQUENTIAL:
results = self._run_sequential(scenarios)
elif self.mode == ExecutionMode.THREADED:
results = self._run_parallel(scenarios, ThreadPoolExecutor)
elif self.mode == ExecutionMode.PROCESS:
results = self._run_parallel(scenarios, ProcessPoolExecutor)
else:
raise ValueError(f"Unknown execution mode: {self.mode}")
return results
finally:
if progress_callback:
self.remove_progress_callback(progress_callback)
def _run_sequential(
self, scenarios: List[ScenarioConfig]
) -> List[ScenarioResult]:
"""Run scenarios sequentially.
Args:
scenarios: List of scenarios to run
Returns:
List of results in order
"""
results = []
for config in scenarios:
if self._cancelled:
results.append(ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.CANCELLED,
start_time=datetime.now(),
))
else:
results.append(self._execute_scenario(config))
return results
def _run_parallel(
self,
scenarios: List[ScenarioConfig],
executor_class: type,
) -> List[ScenarioResult]:
"""Run scenarios in parallel using a pool executor.
Args:
scenarios: List of scenarios to run
executor_class: ThreadPoolExecutor or ProcessPoolExecutor
Returns:
List of results in original order
"""
# Map scenario IDs to their original indices
id_to_index = {config.scenario_id: i for i, config in enumerate(scenarios)}
results = [None] * len(scenarios)
with executor_class(max_workers=self.max_workers) as pool:
# Submit all scenarios
future_to_id = {
pool.submit(self._execute_scenario, config): config.scenario_id
for config in scenarios
}
# Collect results as they complete
for future in as_completed(future_to_id, timeout=self.timeout_seconds):
scenario_id = future_to_id[future]
index = id_to_index[scenario_id]
try:
result = future.result()
results[index] = result
except Exception as e:
# Create error result for this scenario
config = scenarios[index]
results[index] = ScenarioResult(
scenario_id=config.scenario_id,
scenario_name=config.name,
status=ScenarioStatus.FAILED,
start_time=datetime.now(),
error_message=str(e),
)
return results
def cancel(self) -> None:
"""Cancel all pending scenarios.
Running scenarios will complete, but pending ones will be skipped.
"""
with self._lock:
self._cancelled = True
def get_progress(self) -> RunnerProgress:
"""Get current progress information.
Returns:
Current progress state
"""
with self._lock:
return copy.copy(self._progress)
class ScenarioBatchBuilder:
"""Builder for creating batches of scenario configurations.
Provides convenient methods for generating variations of scenarios
for sensitivity analysis, parameter sweeps, etc.
Example:
>>> builder = ScenarioBatchBuilder()
>>> scenarios = (
... builder
... .with_base_config(symbols=["AAPL", "GOOGL"])
... .vary_parameter("leverage", [0.5, 1.0, 1.5, 2.0])
... .vary_date_range(
... date(2020, 1, 1),
... date(2023, 12, 31),
... window_months=12,
... )
... .build()
... )
"""
def __init__(self):
"""Initialize the batch builder."""
self._base_config: Dict[str, Any] = {}
self._parameter_variations: Dict[str, List[Any]] = {}
self._date_ranges: List[Tuple[date, date]] = []
def with_base_config(self, **kwargs) -> "ScenarioBatchBuilder":
"""Set base configuration for all scenarios.
Args:
**kwargs: Configuration parameters
Returns:
Self for chaining
"""
self._base_config.update(kwargs)
return self
def vary_parameter(
self, name: str, values: List[Any]
) -> "ScenarioBatchBuilder":
"""Add a parameter to vary across scenarios.
Args:
name: Parameter name (in strategy_params or risk_params)
values: List of values to use
Returns:
Self for chaining
"""
self._parameter_variations[name] = values
return self
def vary_date_range(
self,
start: date,
end: date,
window_months: int = 12,
step_months: int = 3,
) -> "ScenarioBatchBuilder":
"""Add rolling date windows.
Args:
start: Overall start date
end: Overall end date
window_months: Size of each window in months
step_months: Step between windows in months
Returns:
Self for chaining
"""
from datetime import timedelta
from dateutil.relativedelta import relativedelta
current_start = start
while current_start < end:
current_end = current_start + relativedelta(months=window_months)
if current_end > end:
current_end = end
self._date_ranges.append((current_start, current_end))
current_start += relativedelta(months=step_months)
return self
def with_date_ranges(
self, ranges: List[Tuple[date, date]]
) -> "ScenarioBatchBuilder":
"""Set explicit date ranges.
Args:
ranges: List of (start_date, end_date) tuples
Returns:
Self for chaining
"""
self._date_ranges.extend(ranges)
return self
def build(self) -> List[ScenarioConfig]:
"""Build all scenario configurations.
Creates the Cartesian product of all variations.
Returns:
List of ScenarioConfig objects
"""
import itertools
scenarios = []
# Get all parameter combinations
param_names = list(self._parameter_variations.keys())
param_values = [self._parameter_variations[name] for name in param_names]
if param_values:
param_combinations = list(itertools.product(*param_values))
else:
param_combinations = [()]
# Get date ranges (use single None tuple if no ranges)
date_ranges = self._date_ranges if self._date_ranges else [(None, None)]
# Generate all combinations
for param_combo in param_combinations:
for start_date, end_date in date_ranges:
# Build scenario config
config_dict = copy.deepcopy(self._base_config)
# Set dates
if start_date is not None:
config_dict["start_date"] = start_date
if end_date is not None:
config_dict["end_date"] = end_date
# Set parameter variations
strategy_params = config_dict.get("strategy_params", {})
for name, value in zip(param_names, param_combo):
strategy_params[name] = value
config_dict["strategy_params"] = strategy_params
# Generate descriptive name
name_parts = []
for name, value in zip(param_names, param_combo):
name_parts.append(f"{name}={value}")
if start_date:
name_parts.append(f"{start_date.year}")
if name_parts:
config_dict["name"] = " | ".join(name_parts)
scenarios.append(ScenarioConfig(**config_dict))
return scenarios
def clear(self) -> "ScenarioBatchBuilder":
"""Clear all configuration.
Returns:
Self for chaining
"""
self._base_config = {}
self._parameter_variations = {}
self._date_ranges = []
return self
def aggregate_results(results: List[ScenarioResult]) -> Dict[str, Any]:
"""Aggregate results from multiple scenarios.
Args:
results: List of scenario results
Returns:
Dictionary with aggregated statistics
"""
if not results:
return {
"total_scenarios": 0,
"successful": 0,
"failed": 0,
}
successful = [r for r in results if r.is_successful]
failed = [r for r in results if r.status == ScenarioStatus.FAILED]
# Calculate aggregate metrics
returns = [float(r.total_return) for r in successful if r.total_return]
drawdowns = [float(r.max_drawdown) for r in successful if r.max_drawdown]
durations = [r.duration_seconds for r in results if r.duration_seconds]
aggregate = {
"total_scenarios": len(results),
"successful": len(successful),
"failed": len(failed),
"success_rate": len(successful) / len(results) if results else 0,
"total_duration_seconds": sum(durations),
"avg_duration_seconds": sum(durations) / len(durations) if durations else 0,
}
if returns:
aggregate.update({
"avg_return": sum(returns) / len(returns),
"min_return": min(returns),
"max_return": max(returns),
"median_return": sorted(returns)[len(returns) // 2],
})
if drawdowns:
aggregate.update({
"avg_max_drawdown": sum(drawdowns) / len(drawdowns),
"worst_drawdown": min(drawdowns), # More negative is worse
})
# Best and worst scenarios
if successful:
best = max(successful, key=lambda r: float(r.total_return or 0))
worst = min(successful, key=lambda r: float(r.total_return or 0))
aggregate["best_scenario"] = {
"name": best.scenario_name,
"return": str(best.total_return),
}
aggregate["worst_scenario"] = {
"name": worst.scenario_name,
"return": str(worst.total_return),
}
return aggregate