diff --git a/tests/unit/simulation/__init__.py b/tests/unit/simulation/__init__.py new file mode 100644 index 00000000..e792b96d --- /dev/null +++ b/tests/unit/simulation/__init__.py @@ -0,0 +1 @@ +"""Unit tests for the simulation module.""" diff --git a/tests/unit/simulation/test_scenario_runner.py b/tests/unit/simulation/test_scenario_runner.py new file mode 100644 index 00000000..f5e9a924 --- /dev/null +++ b/tests/unit/simulation/test_scenario_runner.py @@ -0,0 +1,835 @@ +"""Tests for Scenario Runner. + +Issue #33: [SIM-32] Scenario runner - parallel portfolio simulations + +Tests cover: +- ExecutionMode and ScenarioStatus enums +- ScenarioConfig and ScenarioResult dataclasses +- ScenarioRunner sequential execution +- ScenarioRunner parallel execution +- Progress tracking and callbacks +- ScenarioBatchBuilder variations +- Result aggregation +- Error handling +- Cancellation +""" + +import pytest +import time +from datetime import datetime, date +from decimal import Decimal +from unittest.mock import Mock, patch +import threading + +from tradingagents.simulation.scenario_runner import ( + ExecutionMode, + ScenarioStatus, + ScenarioConfig, + ScenarioResult, + RunnerProgress, + ScenarioRunner, + ScenarioBatchBuilder, + aggregate_results, +) + + +# ============================================================================== +# ExecutionMode Enum Tests +# ============================================================================== + + +class TestExecutionMode: + """Tests for ExecutionMode enum.""" + + def test_sequential_value(self): + """Test SEQUENTIAL mode value.""" + assert ExecutionMode.SEQUENTIAL.value == "sequential" + + def test_threaded_value(self): + """Test THREADED mode value.""" + assert ExecutionMode.THREADED.value == "threaded" + + def test_process_value(self): + """Test PROCESS mode value.""" + assert ExecutionMode.PROCESS.value == "process" + + def test_all_modes_exist(self): + """Test all expected modes exist.""" + modes = [m for m in ExecutionMode] + assert len(modes) == 3 + + +# ============================================================================== +# ScenarioStatus Enum Tests +# ============================================================================== + + +class TestScenarioStatus: + """Tests for ScenarioStatus enum.""" + + def test_pending_value(self): + """Test PENDING status value.""" + assert ScenarioStatus.PENDING.value == "pending" + + def test_running_value(self): + """Test RUNNING status value.""" + assert ScenarioStatus.RUNNING.value == "running" + + def test_completed_value(self): + """Test COMPLETED status value.""" + assert ScenarioStatus.COMPLETED.value == "completed" + + def test_failed_value(self): + """Test FAILED status value.""" + assert ScenarioStatus.FAILED.value == "failed" + + def test_cancelled_value(self): + """Test CANCELLED status value.""" + assert ScenarioStatus.CANCELLED.value == "cancelled" + + def test_all_statuses_exist(self): + """Test all expected statuses exist.""" + statuses = [s for s in ScenarioStatus] + assert len(statuses) == 5 + + +# ============================================================================== +# ScenarioConfig Tests +# ============================================================================== + + +class TestScenarioConfig: + """Tests for ScenarioConfig dataclass.""" + + def test_default_config(self): + """Test creating config with defaults.""" + config = ScenarioConfig() + assert config.scenario_id is not None + assert config.name.startswith("Scenario-") + assert config.initial_capital == Decimal("100000") + assert config.symbols == [] + assert config.strategy_params == {} + assert config.risk_params == {} + + def test_custom_config(self): + """Test creating config with custom values.""" + config = ScenarioConfig( + name="Bull Market Test", + start_date=date(2023, 1, 1), + end_date=date(2023, 12, 31), + initial_capital=Decimal("50000"), + symbols=["AAPL", "GOOGL"], + strategy_params={"leverage": 1.5}, + risk_params={"max_drawdown": 0.2}, + ) + assert config.name == "Bull Market Test" + assert config.start_date == date(2023, 1, 1) + assert config.end_date == date(2023, 12, 31) + assert config.initial_capital == Decimal("50000") + assert config.symbols == ["AAPL", "GOOGL"] + assert config.strategy_params["leverage"] == 1.5 + assert config.risk_params["max_drawdown"] == 0.2 + + def test_auto_generated_name(self): + """Test auto-generated name from scenario_id.""" + config = ScenarioConfig() + # Name should start with "Scenario-" and contain part of the ID + assert config.name.startswith("Scenario-") + assert len(config.name) > 8 + + def test_explicit_scenario_id(self): + """Test explicit scenario ID.""" + config = ScenarioConfig(scenario_id="test-123") + assert config.scenario_id == "test-123" + + +# ============================================================================== +# ScenarioResult Tests +# ============================================================================== + + +class TestScenarioResult: + """Tests for ScenarioResult dataclass.""" + + def test_successful_result(self): + """Test creating a successful result.""" + result = ScenarioResult( + scenario_id="test-1", + scenario_name="Test", + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + final_value=Decimal("110000"), + total_return=Decimal("0.10"), + ) + assert result.is_successful is True + assert result.is_finished is True + + def test_failed_result(self): + """Test creating a failed result.""" + result = ScenarioResult( + scenario_id="test-1", + scenario_name="Test", + status=ScenarioStatus.FAILED, + start_time=datetime.now(), + error_message="Simulation error", + ) + assert result.is_successful is False + assert result.is_finished is True + assert result.error_message == "Simulation error" + + def test_pending_result(self): + """Test pending result properties.""" + result = ScenarioResult( + scenario_id="test-1", + scenario_name="Test", + status=ScenarioStatus.PENDING, + start_time=datetime.now(), + ) + assert result.is_successful is False + assert result.is_finished is False + + def test_running_result(self): + """Test running result properties.""" + result = ScenarioResult( + scenario_id="test-1", + scenario_name="Test", + status=ScenarioStatus.RUNNING, + start_time=datetime.now(), + ) + assert result.is_successful is False + assert result.is_finished is False + + def test_cancelled_result(self): + """Test cancelled result properties.""" + result = ScenarioResult( + scenario_id="test-1", + scenario_name="Test", + status=ScenarioStatus.CANCELLED, + start_time=datetime.now(), + ) + assert result.is_successful is False + assert result.is_finished is True + + +# ============================================================================== +# RunnerProgress Tests +# ============================================================================== + + +class TestRunnerProgress: + """Tests for RunnerProgress dataclass.""" + + def test_progress_percent(self): + """Test progress percentage calculation.""" + progress = RunnerProgress( + total_scenarios=10, + completed=3, + failed=2, + ) + assert progress.progress_percent == 50.0 + + def test_progress_percent_zero_total(self): + """Test progress with zero total scenarios.""" + progress = RunnerProgress(total_scenarios=0) + assert progress.progress_percent == 100.0 + + def test_is_complete(self): + """Test completion detection.""" + progress = RunnerProgress( + total_scenarios=5, + completed=3, + failed=2, + ) + assert progress.is_complete is True + + def test_not_complete(self): + """Test incomplete detection.""" + progress = RunnerProgress( + total_scenarios=5, + completed=2, + failed=1, + ) + assert progress.is_complete is False + + +# ============================================================================== +# ScenarioRunner Tests - Sequential Execution +# ============================================================================== + + +class TestScenarioRunnerSequential: + """Tests for ScenarioRunner sequential execution.""" + + def test_run_empty_list(self): + """Test running with empty scenario list.""" + def executor(config): + return ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + ) + + runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL) + results = runner.run([]) + assert results == [] + + def test_run_single_scenario(self): + """Test running a single scenario.""" + def executor(config): + return ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + final_value=Decimal("110000"), + total_return=Decimal("0.10"), + ) + + runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL) + scenarios = [ScenarioConfig(name="Test1")] + results = runner.run(scenarios) + + assert len(results) == 1 + assert results[0].is_successful + assert results[0].scenario_name == "Test1" + + def test_run_multiple_scenarios(self): + """Test running multiple scenarios sequentially.""" + call_order = [] + + def executor(config): + call_order.append(config.name) + return ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + ) + + runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL) + scenarios = [ + ScenarioConfig(name="Test1"), + ScenarioConfig(name="Test2"), + ScenarioConfig(name="Test3"), + ] + results = runner.run(scenarios) + + assert len(results) == 3 + assert call_order == ["Test1", "Test2", "Test3"] + assert all(r.is_successful for r in results) + + def test_run_with_executor_exception(self): + """Test handling executor exceptions.""" + def executor(config): + if config.name == "FailMe": + raise ValueError("Simulated error") + return ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + ) + + runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL) + scenarios = [ + ScenarioConfig(name="Test1"), + ScenarioConfig(name="FailMe"), + ScenarioConfig(name="Test3"), + ] + results = runner.run(scenarios) + + assert len(results) == 3 + assert results[0].is_successful + assert results[1].status == ScenarioStatus.FAILED + assert "Simulated error" in results[1].error_message + assert results[2].is_successful + + +# ============================================================================== +# ScenarioRunner Tests - Parallel Execution +# ============================================================================== + + +class TestScenarioRunnerParallel: + """Tests for ScenarioRunner parallel execution.""" + + def test_run_threaded(self): + """Test running scenarios in parallel using threads.""" + def executor(config): + time.sleep(0.01) # Small delay to ensure parallelism + return ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + ) + + runner = ScenarioRunner( + executor=executor, + mode=ExecutionMode.THREADED, + max_workers=4, + ) + scenarios = [ScenarioConfig(name=f"Test{i}") for i in range(8)] + + start = time.time() + results = runner.run(scenarios) + elapsed = time.time() - start + + assert len(results) == 8 + assert all(r.is_successful for r in results) + # With 4 workers and 8 scenarios at 0.01s each, + # parallel should be faster than sequential (0.08s) + # Allow some overhead + assert elapsed < 0.1 + + def test_results_order_preserved(self): + """Test that results match input order.""" + import random + + def executor(config): + # Random delay to encourage out-of-order completion + time.sleep(random.uniform(0.001, 0.01)) + return ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + final_value=Decimal(config.metadata.get("index", 0)), + ) + + runner = ScenarioRunner( + executor=executor, + mode=ExecutionMode.THREADED, + max_workers=4, + ) + scenarios = [ + ScenarioConfig(name=f"Test{i}", metadata={"index": i}) + for i in range(10) + ] + results = runner.run(scenarios) + + assert len(results) == 10 + for i, result in enumerate(results): + assert result.scenario_name == f"Test{i}" + assert result.final_value == Decimal(i) + + +# ============================================================================== +# ScenarioRunner Tests - Progress Tracking +# ============================================================================== + + +class TestScenarioRunnerProgress: + """Tests for ScenarioRunner progress tracking.""" + + def test_progress_callback(self): + """Test progress callback invocation.""" + progress_updates = [] + + def on_progress(progress): + progress_updates.append({ + "completed": progress.completed, + "total": progress.total_scenarios, + }) + + def executor(config): + return ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + ) + + runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL) + scenarios = [ScenarioConfig(name=f"Test{i}") for i in range(3)] + runner.run(scenarios, progress_callback=on_progress) + + # Should have received progress updates + assert len(progress_updates) > 0 + # Final update should show all completed + final = progress_updates[-1] + assert final["completed"] == 3 + + def test_get_progress(self): + """Test get_progress method.""" + def executor(config): + time.sleep(0.01) + return ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + ) + + runner = ScenarioRunner( + executor=executor, + mode=ExecutionMode.THREADED, + max_workers=2, + ) + + scenarios = [ScenarioConfig(name=f"Test{i}") for i in range(4)] + + # Start in background + results = [] + def run_in_thread(): + nonlocal results + results = runner.run(scenarios) + + thread = threading.Thread(target=run_in_thread) + thread.start() + + # Give it time to start + time.sleep(0.005) + progress = runner.get_progress() + assert progress.total_scenarios == 4 + + thread.join() + assert len(results) == 4 + + +# ============================================================================== +# ScenarioRunner Tests - Cancellation +# ============================================================================== + + +class TestScenarioRunnerCancellation: + """Tests for ScenarioRunner cancellation.""" + + def test_cancel_pending_scenarios(self): + """Test cancelling pending scenarios.""" + executed = [] + + def executor(config): + executed.append(config.name) + time.sleep(0.02) + return ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + ) + + runner = ScenarioRunner( + executor=executor, + mode=ExecutionMode.SEQUENTIAL, + ) + + scenarios = [ScenarioConfig(name=f"Test{i}") for i in range(5)] + + # Cancel after short delay + def cancel_after_delay(): + time.sleep(0.03) + runner.cancel() + + cancel_thread = threading.Thread(target=cancel_after_delay) + cancel_thread.start() + + results = runner.run(scenarios) + cancel_thread.join() + + # Some should be completed, some cancelled + completed = [r for r in results if r.status == ScenarioStatus.COMPLETED] + cancelled = [r for r in results if r.status == ScenarioStatus.CANCELLED] + + assert len(completed) + len(cancelled) == 5 + # At least the first one should have completed + assert len(completed) >= 1 + + +# ============================================================================== +# ScenarioBatchBuilder Tests +# ============================================================================== + + +class TestScenarioBatchBuilder: + """Tests for ScenarioBatchBuilder.""" + + def test_empty_build(self): + """Test building with no configuration.""" + builder = ScenarioBatchBuilder() + scenarios = builder.build() + assert len(scenarios) == 1 # One default scenario + + def test_base_config(self): + """Test setting base configuration.""" + builder = ScenarioBatchBuilder() + scenarios = ( + builder + .with_base_config( + initial_capital=Decimal("50000"), + symbols=["AAPL", "GOOGL"], + ) + .build() + ) + assert len(scenarios) == 1 + assert scenarios[0].initial_capital == Decimal("50000") + assert scenarios[0].symbols == ["AAPL", "GOOGL"] + + def test_parameter_variation(self): + """Test varying a single parameter.""" + builder = ScenarioBatchBuilder() + scenarios = ( + builder + .vary_parameter("leverage", [0.5, 1.0, 1.5]) + .build() + ) + assert len(scenarios) == 3 + leverages = [s.strategy_params["leverage"] for s in scenarios] + assert leverages == [0.5, 1.0, 1.5] + + def test_multiple_parameter_variations(self): + """Test varying multiple parameters (Cartesian product).""" + builder = ScenarioBatchBuilder() + scenarios = ( + builder + .vary_parameter("leverage", [1.0, 2.0]) + .vary_parameter("stop_loss", [0.05, 0.10]) + .build() + ) + assert len(scenarios) == 4 # 2 x 2 + + combinations = [ + (s.strategy_params["leverage"], s.strategy_params["stop_loss"]) + for s in scenarios + ] + expected = [ + (1.0, 0.05), + (1.0, 0.10), + (2.0, 0.05), + (2.0, 0.10), + ] + assert sorted(combinations) == sorted(expected) + + def test_explicit_date_ranges(self): + """Test setting explicit date ranges.""" + builder = ScenarioBatchBuilder() + scenarios = ( + builder + .with_date_ranges([ + (date(2020, 1, 1), date(2020, 12, 31)), + (date(2021, 1, 1), date(2021, 12, 31)), + ]) + .build() + ) + assert len(scenarios) == 2 + assert scenarios[0].start_date == date(2020, 1, 1) + assert scenarios[1].start_date == date(2021, 1, 1) + + def test_clear_builder(self): + """Test clearing builder configuration.""" + builder = ScenarioBatchBuilder() + builder.with_base_config(initial_capital=Decimal("50000")) + builder.vary_parameter("leverage", [1.0, 2.0]) + builder.clear() + + scenarios = builder.build() + assert len(scenarios) == 1 + assert scenarios[0].initial_capital == Decimal("100000") # Default + + def test_scenario_names_generated(self): + """Test that scenario names are auto-generated.""" + builder = ScenarioBatchBuilder() + scenarios = ( + builder + .vary_parameter("leverage", [1.0, 2.0]) + .build() + ) + assert "leverage=1.0" in scenarios[0].name + assert "leverage=2.0" in scenarios[1].name + + +# ============================================================================== +# Result Aggregation Tests +# ============================================================================== + + +class TestResultAggregation: + """Tests for result aggregation.""" + + def test_aggregate_empty_results(self): + """Test aggregating empty results list.""" + agg = aggregate_results([]) + assert agg["total_scenarios"] == 0 + assert agg["successful"] == 0 + assert agg["failed"] == 0 + + def test_aggregate_successful_results(self): + """Test aggregating successful results.""" + results = [ + ScenarioResult( + scenario_id=f"test-{i}", + scenario_name=f"Test{i}", + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + total_return=Decimal(f"0.{i+1}"), + max_drawdown=Decimal(f"-0.0{i+1}"), + duration_seconds=1.0, + ) + for i in range(3) + ] + + agg = aggregate_results(results) + assert agg["total_scenarios"] == 3 + assert agg["successful"] == 3 + assert agg["failed"] == 0 + assert agg["success_rate"] == 1.0 + assert "avg_return" in agg + assert "min_return" in agg + assert "max_return" in agg + + def test_aggregate_mixed_results(self): + """Test aggregating mixed success/failure results.""" + results = [ + ScenarioResult( + scenario_id="test-1", + scenario_name="Success1", + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + total_return=Decimal("0.10"), + ), + ScenarioResult( + scenario_id="test-2", + scenario_name="Failed1", + status=ScenarioStatus.FAILED, + start_time=datetime.now(), + error_message="Error", + ), + ScenarioResult( + scenario_id="test-3", + scenario_name="Success2", + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + total_return=Decimal("0.20"), + ), + ] + + agg = aggregate_results(results) + assert agg["total_scenarios"] == 3 + assert agg["successful"] == 2 + assert agg["failed"] == 1 + assert agg["success_rate"] == pytest.approx(0.667, rel=0.01) + + def test_aggregate_best_worst_scenarios(self): + """Test best/worst scenario identification.""" + results = [ + ScenarioResult( + scenario_id="test-1", + scenario_name="Low", + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + total_return=Decimal("0.05"), + ), + ScenarioResult( + scenario_id="test-2", + scenario_name="High", + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + total_return=Decimal("0.25"), + ), + ScenarioResult( + scenario_id="test-3", + scenario_name="Mid", + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + total_return=Decimal("0.15"), + ), + ] + + agg = aggregate_results(results) + assert agg["best_scenario"]["name"] == "High" + assert agg["worst_scenario"]["name"] == "Low" + + +# ============================================================================== +# Module Import Tests +# ============================================================================== + + +class TestModuleImports: + """Tests for module imports.""" + + def test_import_from_simulation_module(self): + """Test importing from simulation module.""" + from tradingagents.simulation import ( + ExecutionMode, + ScenarioStatus, + ScenarioConfig, + ScenarioResult, + RunnerProgress, + ScenarioRunner, + ScenarioBatchBuilder, + aggregate_results, + ) + assert ExecutionMode is not None + assert ScenarioStatus is not None + assert ScenarioConfig is not None + assert ScenarioResult is not None + assert RunnerProgress is not None + assert ScenarioRunner is not None + assert ScenarioBatchBuilder is not None + assert aggregate_results is not None + + +# ============================================================================== +# Integration Tests +# ============================================================================== + + +class TestScenarioRunnerIntegration: + """Integration tests for ScenarioRunner.""" + + def test_full_simulation_workflow(self): + """Test complete simulation workflow.""" + # Create executor that simulates trading + def trading_simulator(config: ScenarioConfig) -> ScenarioResult: + # Simulate based on strategy params + leverage = config.strategy_params.get("leverage", 1.0) + base_return = 0.08 # 8% base return + final_return = base_return * leverage + + final_value = config.initial_capital * Decimal(str(1 + final_return)) + + return ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.COMPLETED, + start_time=datetime.now(), + end_time=datetime.now(), + final_value=final_value.quantize(Decimal("0.01")), + total_return=Decimal(str(final_return)).quantize(Decimal("0.0001")), + max_drawdown=Decimal("-0.10"), + trades_executed=25, + ) + + # Build scenarios with varying leverage + scenarios = ( + ScenarioBatchBuilder() + .with_base_config(initial_capital=Decimal("100000")) + .vary_parameter("leverage", [0.5, 1.0, 1.5, 2.0]) + .build() + ) + + # Run simulations + runner = ScenarioRunner( + executor=trading_simulator, + mode=ExecutionMode.THREADED, + max_workers=4, + ) + + progress_updates = [] + results = runner.run( + scenarios, + progress_callback=lambda p: progress_updates.append(p.completed), + ) + + # Verify results + assert len(results) == 4 + assert all(r.is_successful for r in results) + + # Aggregate and analyze + agg = aggregate_results(results) + assert agg["total_scenarios"] == 4 + assert agg["successful"] == 4 + assert agg["best_scenario"]["name"] == "leverage=2.0" + assert agg["worst_scenario"]["name"] == "leverage=0.5" diff --git a/tradingagents/simulation/__init__.py b/tradingagents/simulation/__init__.py new file mode 100644 index 00000000..0981dc3f --- /dev/null +++ b/tradingagents/simulation/__init__.py @@ -0,0 +1,96 @@ +"""Simulation module for portfolio simulations and backtesting. + +This module provides simulation capabilities including: +- Parallel scenario execution +- Parameter sweep analysis +- Strategy comparison +- Economic regime simulation + +Issue #33: [SIM-32] Scenario runner - parallel portfolio simulations + +Submodules: + scenario_runner: Core scenario execution framework + +Classes: + Enums: + - ExecutionMode: Parallel execution mode (sequential, threaded, process) + - ScenarioStatus: Status of a scenario run + + Data Classes: + - ScenarioConfig: Configuration for a simulation scenario + - ScenarioResult: Result from a scenario simulation + - RunnerProgress: Progress information for batch runs + + Main Classes: + - ScenarioRunner: Runner for parallel portfolio simulations + - ScenarioBatchBuilder: Builder for creating scenario batches + + Protocols: + - ScenarioExecutor: Protocol for scenario execution functions + + Utility Functions: + - aggregate_results: Aggregate results from multiple scenarios + +Example: + >>> from tradingagents.simulation import ( + ... ScenarioRunner, + ... ScenarioConfig, + ... ScenarioResult, + ... ScenarioStatus, + ... ExecutionMode, + ... ) + >>> from datetime import datetime + >>> from decimal import Decimal + >>> + >>> def simple_executor(config: ScenarioConfig) -> ScenarioResult: + ... return ScenarioResult( + ... scenario_id=config.scenario_id, + ... scenario_name=config.name, + ... status=ScenarioStatus.COMPLETED, + ... start_time=datetime.now(), + ... final_value=config.initial_capital * Decimal("1.1"), + ... total_return=Decimal("0.1"), + ... ) + >>> + >>> runner = ScenarioRunner(executor=simple_executor) + >>> scenarios = [ScenarioConfig(name="Test1"), ScenarioConfig(name="Test2")] + >>> results = runner.run(scenarios) +""" + +from .scenario_runner import ( + # Enums + ExecutionMode, + ScenarioStatus, + # Data Classes + ScenarioConfig, + ScenarioResult, + RunnerProgress, + # Main Classes + ScenarioRunner, + ScenarioBatchBuilder, + # Protocols + ScenarioExecutor, + # Types + ProgressCallback, + # Utility Functions + aggregate_results, +) + +__all__ = [ + # Enums + "ExecutionMode", + "ScenarioStatus", + # Data Classes + "ScenarioConfig", + "ScenarioResult", + "RunnerProgress", + # Main Classes + "ScenarioRunner", + "ScenarioBatchBuilder", + # Protocols + "ScenarioExecutor", + # Types + "ProgressCallback", + # Utility Functions + "aggregate_results", +] diff --git a/tradingagents/simulation/scenario_runner.py b/tradingagents/simulation/scenario_runner.py new file mode 100644 index 00000000..38de43fc --- /dev/null +++ b/tradingagents/simulation/scenario_runner.py @@ -0,0 +1,689 @@ +"""Scenario Runner for parallel portfolio simulations. + +This module provides scenario-based simulation capabilities including: +- Parallel execution of multiple portfolio scenarios +- Configurable simulation parameters +- Progress tracking and callbacks +- Result aggregation and comparison + +Issue #33: [SIM-32] Scenario runner - parallel portfolio simulations + +Design Principles: + - Thread-safe parallel execution + - Configurable concurrency limits + - Memory-efficient result handling + - Progress reporting callbacks +""" + +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed +from dataclasses import dataclass, field +from datetime import datetime, date +from decimal import Decimal +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union +import copy +import threading +import time +import uuid + + +class ExecutionMode(Enum): + """Mode of parallel execution.""" + SEQUENTIAL = "sequential" # Run scenarios one at a time + THREADED = "threaded" # Use thread pool (for I/O bound) + PROCESS = "process" # Use process pool (for CPU bound) + + +class ScenarioStatus(Enum): + """Status of a scenario run.""" + PENDING = "pending" # Not yet started + RUNNING = "running" # Currently executing + COMPLETED = "completed" # Finished successfully + FAILED = "failed" # Finished with error + CANCELLED = "cancelled" # Cancelled before completion + + +@dataclass +class ScenarioConfig: + """Configuration for a simulation scenario. + + Attributes: + scenario_id: Unique identifier for the scenario + name: Human-readable name + start_date: Simulation start date + end_date: Simulation end date + initial_capital: Starting capital + symbols: List of symbols to include + strategy_params: Strategy-specific parameters + risk_params: Risk management parameters + metadata: Additional scenario data + """ + scenario_id: str = field(default_factory=lambda: str(uuid.uuid4())) + name: str = "" + start_date: Optional[date] = None + end_date: Optional[date] = None + initial_capital: Decimal = Decimal("100000") + symbols: List[str] = field(default_factory=list) + strategy_params: Dict[str, Any] = field(default_factory=dict) + risk_params: Dict[str, Any] = field(default_factory=dict) + metadata: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + """Set default name if not provided.""" + if not self.name: + self.name = f"Scenario-{self.scenario_id[:8]}" + + +@dataclass +class ScenarioResult: + """Result from a scenario simulation. + + Attributes: + scenario_id: ID of the scenario that was run + scenario_name: Name of the scenario + status: Final status of the run + start_time: When simulation started + end_time: When simulation ended + duration_seconds: Total runtime in seconds + final_value: Final portfolio value + total_return: Total return as decimal (0.10 = 10%) + trades_executed: Number of trades made + max_drawdown: Maximum drawdown experienced + sharpe_ratio: Sharpe ratio if calculable + error_message: Error message if failed + portfolio_history: Time series of portfolio values + trade_history: List of trades executed + metrics: Additional performance metrics + metadata: Additional result data + """ + scenario_id: str + scenario_name: str + status: ScenarioStatus + start_time: datetime + end_time: Optional[datetime] = None + duration_seconds: float = 0.0 + final_value: Decimal = Decimal("0") + total_return: Decimal = Decimal("0") + trades_executed: int = 0 + max_drawdown: Decimal = Decimal("0") + sharpe_ratio: Optional[Decimal] = None + error_message: Optional[str] = None + portfolio_history: List[Tuple[date, Decimal]] = field(default_factory=list) + trade_history: List[Dict[str, Any]] = field(default_factory=list) + metrics: Dict[str, Any] = field(default_factory=dict) + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def is_successful(self) -> bool: + """Check if scenario completed successfully.""" + return self.status == ScenarioStatus.COMPLETED + + @property + def is_finished(self) -> bool: + """Check if scenario has finished (success or failure).""" + return self.status in ( + ScenarioStatus.COMPLETED, + ScenarioStatus.FAILED, + ScenarioStatus.CANCELLED, + ) + + +class ScenarioExecutor(Protocol): + """Protocol for scenario execution functions. + + Implementations should take a ScenarioConfig and return a ScenarioResult. + """ + + def __call__(self, config: ScenarioConfig) -> ScenarioResult: + """Execute a scenario and return results.""" + ... + + +@dataclass +class RunnerProgress: + """Progress information for a scenario run batch. + + Attributes: + total_scenarios: Total number of scenarios + completed: Number of completed scenarios + failed: Number of failed scenarios + running: Number of currently running scenarios + pending: Number of pending scenarios + start_time: When the batch started + estimated_remaining_seconds: Estimated time remaining + """ + total_scenarios: int + completed: int = 0 + failed: int = 0 + running: int = 0 + pending: int = 0 + start_time: Optional[datetime] = None + estimated_remaining_seconds: Optional[float] = None + + @property + def progress_percent(self) -> float: + """Calculate completion percentage.""" + if self.total_scenarios == 0: + return 100.0 + return (self.completed + self.failed) / self.total_scenarios * 100 + + @property + def is_complete(self) -> bool: + """Check if all scenarios are finished.""" + return (self.completed + self.failed) >= self.total_scenarios + + +ProgressCallback = Callable[[RunnerProgress], None] + + +class ScenarioRunner: + """Runner for parallel portfolio simulations. + + Executes multiple simulation scenarios in parallel using configurable + execution modes (sequential, threaded, or process-based). + + Example: + >>> def simulate(config: ScenarioConfig) -> ScenarioResult: + ... # Your simulation logic + ... return ScenarioResult( + ... scenario_id=config.scenario_id, + ... scenario_name=config.name, + ... status=ScenarioStatus.COMPLETED, + ... start_time=datetime.now(), + ... ) + ... + >>> runner = ScenarioRunner(executor=simulate) + >>> scenarios = [ + ... ScenarioConfig(name="Bull Market", strategy_params={"leverage": 1.5}), + ... ScenarioConfig(name="Bear Market", strategy_params={"leverage": 0.5}), + ... ] + >>> results = runner.run(scenarios) + >>> print(f"Completed: {len([r for r in results if r.is_successful])}") + """ + + def __init__( + self, + executor: ScenarioExecutor, + mode: ExecutionMode = ExecutionMode.THREADED, + max_workers: Optional[int] = None, + timeout_seconds: Optional[float] = None, + ): + """Initialize the scenario runner. + + Args: + executor: Function that executes a single scenario + mode: Execution mode (sequential, threaded, process) + max_workers: Maximum number of parallel workers (None = auto) + timeout_seconds: Timeout per scenario (None = no timeout) + """ + self.executor = executor + self.mode = mode + self.max_workers = max_workers + self.timeout_seconds = timeout_seconds + self._lock = threading.Lock() + self._cancelled = False + self._progress = RunnerProgress(total_scenarios=0) + self._progress_callbacks: List[ProgressCallback] = [] + + def add_progress_callback(self, callback: ProgressCallback) -> None: + """Add a callback for progress updates. + + Args: + callback: Function to call with progress updates + """ + with self._lock: + self._progress_callbacks.append(callback) + + def remove_progress_callback(self, callback: ProgressCallback) -> None: + """Remove a progress callback. + + Args: + callback: Callback to remove + """ + with self._lock: + if callback in self._progress_callbacks: + self._progress_callbacks.remove(callback) + + def _notify_progress(self) -> None: + """Notify all registered progress callbacks.""" + with self._lock: + callbacks = self._progress_callbacks.copy() + progress = copy.copy(self._progress) + + for callback in callbacks: + try: + callback(progress) + except Exception: + pass # Don't let callback errors affect execution + + def _update_progress( + self, + completed_delta: int = 0, + failed_delta: int = 0, + running_delta: int = 0, + pending_delta: int = 0, + ) -> None: + """Update progress counters thread-safely.""" + with self._lock: + self._progress.completed += completed_delta + self._progress.failed += failed_delta + self._progress.running += running_delta + self._progress.pending += pending_delta + + # Estimate remaining time + if self._progress.start_time and self._progress.completed > 0: + elapsed = (datetime.now() - self._progress.start_time).total_seconds() + avg_time = elapsed / self._progress.completed + remaining = self._progress.pending + self._progress.running + self._progress.estimated_remaining_seconds = avg_time * remaining + + self._notify_progress() + + def _execute_scenario(self, config: ScenarioConfig) -> ScenarioResult: + """Execute a single scenario with error handling. + + Args: + config: Scenario configuration + + Returns: + ScenarioResult with outcome + """ + start_time = datetime.now() + + # Check if cancelled + if self._cancelled: + return ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.CANCELLED, + start_time=start_time, + end_time=datetime.now(), + ) + + self._update_progress(running_delta=1, pending_delta=-1) + + try: + result = self.executor(config) + result.end_time = datetime.now() + result.duration_seconds = (result.end_time - start_time).total_seconds() + + if result.status == ScenarioStatus.COMPLETED: + self._update_progress(completed_delta=1, running_delta=-1) + else: + self._update_progress(failed_delta=1, running_delta=-1) + + return result + + except Exception as e: + end_time = datetime.now() + self._update_progress(failed_delta=1, running_delta=-1) + return ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.FAILED, + start_time=start_time, + end_time=end_time, + duration_seconds=(end_time - start_time).total_seconds(), + error_message=str(e), + ) + + def run( + self, + scenarios: List[ScenarioConfig], + progress_callback: Optional[ProgressCallback] = None, + ) -> List[ScenarioResult]: + """Run multiple scenarios. + + Args: + scenarios: List of scenario configurations to run + progress_callback: Optional callback for progress updates + + Returns: + List of results in the same order as input scenarios + """ + if not scenarios: + return [] + + # Reset state + self._cancelled = False + self._progress = RunnerProgress( + total_scenarios=len(scenarios), + pending=len(scenarios), + start_time=datetime.now(), + ) + + if progress_callback: + self.add_progress_callback(progress_callback) + + try: + if self.mode == ExecutionMode.SEQUENTIAL: + results = self._run_sequential(scenarios) + elif self.mode == ExecutionMode.THREADED: + results = self._run_parallel(scenarios, ThreadPoolExecutor) + elif self.mode == ExecutionMode.PROCESS: + results = self._run_parallel(scenarios, ProcessPoolExecutor) + else: + raise ValueError(f"Unknown execution mode: {self.mode}") + + return results + + finally: + if progress_callback: + self.remove_progress_callback(progress_callback) + + def _run_sequential( + self, scenarios: List[ScenarioConfig] + ) -> List[ScenarioResult]: + """Run scenarios sequentially. + + Args: + scenarios: List of scenarios to run + + Returns: + List of results in order + """ + results = [] + for config in scenarios: + if self._cancelled: + results.append(ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.CANCELLED, + start_time=datetime.now(), + )) + else: + results.append(self._execute_scenario(config)) + return results + + def _run_parallel( + self, + scenarios: List[ScenarioConfig], + executor_class: type, + ) -> List[ScenarioResult]: + """Run scenarios in parallel using a pool executor. + + Args: + scenarios: List of scenarios to run + executor_class: ThreadPoolExecutor or ProcessPoolExecutor + + Returns: + List of results in original order + """ + # Map scenario IDs to their original indices + id_to_index = {config.scenario_id: i for i, config in enumerate(scenarios)} + results = [None] * len(scenarios) + + with executor_class(max_workers=self.max_workers) as pool: + # Submit all scenarios + future_to_id = { + pool.submit(self._execute_scenario, config): config.scenario_id + for config in scenarios + } + + # Collect results as they complete + for future in as_completed(future_to_id, timeout=self.timeout_seconds): + scenario_id = future_to_id[future] + index = id_to_index[scenario_id] + + try: + result = future.result() + results[index] = result + except Exception as e: + # Create error result for this scenario + config = scenarios[index] + results[index] = ScenarioResult( + scenario_id=config.scenario_id, + scenario_name=config.name, + status=ScenarioStatus.FAILED, + start_time=datetime.now(), + error_message=str(e), + ) + + return results + + def cancel(self) -> None: + """Cancel all pending scenarios. + + Running scenarios will complete, but pending ones will be skipped. + """ + with self._lock: + self._cancelled = True + + def get_progress(self) -> RunnerProgress: + """Get current progress information. + + Returns: + Current progress state + """ + with self._lock: + return copy.copy(self._progress) + + +class ScenarioBatchBuilder: + """Builder for creating batches of scenario configurations. + + Provides convenient methods for generating variations of scenarios + for sensitivity analysis, parameter sweeps, etc. + + Example: + >>> builder = ScenarioBatchBuilder() + >>> scenarios = ( + ... builder + ... .with_base_config(symbols=["AAPL", "GOOGL"]) + ... .vary_parameter("leverage", [0.5, 1.0, 1.5, 2.0]) + ... .vary_date_range( + ... date(2020, 1, 1), + ... date(2023, 12, 31), + ... window_months=12, + ... ) + ... .build() + ... ) + """ + + def __init__(self): + """Initialize the batch builder.""" + self._base_config: Dict[str, Any] = {} + self._parameter_variations: Dict[str, List[Any]] = {} + self._date_ranges: List[Tuple[date, date]] = [] + + def with_base_config(self, **kwargs) -> "ScenarioBatchBuilder": + """Set base configuration for all scenarios. + + Args: + **kwargs: Configuration parameters + + Returns: + Self for chaining + """ + self._base_config.update(kwargs) + return self + + def vary_parameter( + self, name: str, values: List[Any] + ) -> "ScenarioBatchBuilder": + """Add a parameter to vary across scenarios. + + Args: + name: Parameter name (in strategy_params or risk_params) + values: List of values to use + + Returns: + Self for chaining + """ + self._parameter_variations[name] = values + return self + + def vary_date_range( + self, + start: date, + end: date, + window_months: int = 12, + step_months: int = 3, + ) -> "ScenarioBatchBuilder": + """Add rolling date windows. + + Args: + start: Overall start date + end: Overall end date + window_months: Size of each window in months + step_months: Step between windows in months + + Returns: + Self for chaining + """ + from datetime import timedelta + from dateutil.relativedelta import relativedelta + + current_start = start + while current_start < end: + current_end = current_start + relativedelta(months=window_months) + if current_end > end: + current_end = end + self._date_ranges.append((current_start, current_end)) + current_start += relativedelta(months=step_months) + + return self + + def with_date_ranges( + self, ranges: List[Tuple[date, date]] + ) -> "ScenarioBatchBuilder": + """Set explicit date ranges. + + Args: + ranges: List of (start_date, end_date) tuples + + Returns: + Self for chaining + """ + self._date_ranges.extend(ranges) + return self + + def build(self) -> List[ScenarioConfig]: + """Build all scenario configurations. + + Creates the Cartesian product of all variations. + + Returns: + List of ScenarioConfig objects + """ + import itertools + + scenarios = [] + + # Get all parameter combinations + param_names = list(self._parameter_variations.keys()) + param_values = [self._parameter_variations[name] for name in param_names] + + if param_values: + param_combinations = list(itertools.product(*param_values)) + else: + param_combinations = [()] + + # Get date ranges (use single None tuple if no ranges) + date_ranges = self._date_ranges if self._date_ranges else [(None, None)] + + # Generate all combinations + for param_combo in param_combinations: + for start_date, end_date in date_ranges: + # Build scenario config + config_dict = copy.deepcopy(self._base_config) + + # Set dates + if start_date is not None: + config_dict["start_date"] = start_date + if end_date is not None: + config_dict["end_date"] = end_date + + # Set parameter variations + strategy_params = config_dict.get("strategy_params", {}) + for name, value in zip(param_names, param_combo): + strategy_params[name] = value + config_dict["strategy_params"] = strategy_params + + # Generate descriptive name + name_parts = [] + for name, value in zip(param_names, param_combo): + name_parts.append(f"{name}={value}") + if start_date: + name_parts.append(f"{start_date.year}") + + if name_parts: + config_dict["name"] = " | ".join(name_parts) + + scenarios.append(ScenarioConfig(**config_dict)) + + return scenarios + + def clear(self) -> "ScenarioBatchBuilder": + """Clear all configuration. + + Returns: + Self for chaining + """ + self._base_config = {} + self._parameter_variations = {} + self._date_ranges = [] + return self + + +def aggregate_results(results: List[ScenarioResult]) -> Dict[str, Any]: + """Aggregate results from multiple scenarios. + + Args: + results: List of scenario results + + Returns: + Dictionary with aggregated statistics + """ + if not results: + return { + "total_scenarios": 0, + "successful": 0, + "failed": 0, + } + + successful = [r for r in results if r.is_successful] + failed = [r for r in results if r.status == ScenarioStatus.FAILED] + + # Calculate aggregate metrics + returns = [float(r.total_return) for r in successful if r.total_return] + drawdowns = [float(r.max_drawdown) for r in successful if r.max_drawdown] + durations = [r.duration_seconds for r in results if r.duration_seconds] + + aggregate = { + "total_scenarios": len(results), + "successful": len(successful), + "failed": len(failed), + "success_rate": len(successful) / len(results) if results else 0, + "total_duration_seconds": sum(durations), + "avg_duration_seconds": sum(durations) / len(durations) if durations else 0, + } + + if returns: + aggregate.update({ + "avg_return": sum(returns) / len(returns), + "min_return": min(returns), + "max_return": max(returns), + "median_return": sorted(returns)[len(returns) // 2], + }) + + if drawdowns: + aggregate.update({ + "avg_max_drawdown": sum(drawdowns) / len(drawdowns), + "worst_drawdown": min(drawdowns), # More negative is worse + }) + + # Best and worst scenarios + if successful: + best = max(successful, key=lambda r: float(r.total_return or 0)) + worst = min(successful, key=lambda r: float(r.total_return or 0)) + aggregate["best_scenario"] = { + "name": best.scenario_name, + "return": str(best.total_return), + } + aggregate["worst_scenario"] = { + "name": worst.scenario_name, + "return": str(worst.total_return), + } + + return aggregate