feat(simulation): add Scenario Runner for parallel portfolio simulations - Issue #33 (45 tests)
Implements parallel scenario execution framework: - ScenarioRunner with sequential, threaded, and process execution modes - ScenarioConfig for configuring simulation parameters - ScenarioResult for capturing simulation outcomes - RunnerProgress for tracking execution progress - Progress callbacks for real-time updates - Cancellation support for long-running batches - ScenarioBatchBuilder for parameter sweeps and variations - Result aggregation with best/worst scenario identification Features: - Thread-safe parallel execution with configurable worker count - FIFO result ordering preserved regardless of completion order - Exception handling with graceful degradation - Timeout support per scenario - Cartesian product generation for parameter variations 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
13f2bba0b3
commit
e7bff2c4cf
|
|
@ -0,0 +1 @@
|
|||
"""Unit tests for the simulation module."""
|
||||
|
|
@ -0,0 +1,835 @@
|
|||
"""Tests for Scenario Runner.
|
||||
|
||||
Issue #33: [SIM-32] Scenario runner - parallel portfolio simulations
|
||||
|
||||
Tests cover:
|
||||
- ExecutionMode and ScenarioStatus enums
|
||||
- ScenarioConfig and ScenarioResult dataclasses
|
||||
- ScenarioRunner sequential execution
|
||||
- ScenarioRunner parallel execution
|
||||
- Progress tracking and callbacks
|
||||
- ScenarioBatchBuilder variations
|
||||
- Result aggregation
|
||||
- Error handling
|
||||
- Cancellation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from unittest.mock import Mock, patch
|
||||
import threading
|
||||
|
||||
from tradingagents.simulation.scenario_runner import (
|
||||
ExecutionMode,
|
||||
ScenarioStatus,
|
||||
ScenarioConfig,
|
||||
ScenarioResult,
|
||||
RunnerProgress,
|
||||
ScenarioRunner,
|
||||
ScenarioBatchBuilder,
|
||||
aggregate_results,
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ExecutionMode Enum Tests
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestExecutionMode:
|
||||
"""Tests for ExecutionMode enum."""
|
||||
|
||||
def test_sequential_value(self):
|
||||
"""Test SEQUENTIAL mode value."""
|
||||
assert ExecutionMode.SEQUENTIAL.value == "sequential"
|
||||
|
||||
def test_threaded_value(self):
|
||||
"""Test THREADED mode value."""
|
||||
assert ExecutionMode.THREADED.value == "threaded"
|
||||
|
||||
def test_process_value(self):
|
||||
"""Test PROCESS mode value."""
|
||||
assert ExecutionMode.PROCESS.value == "process"
|
||||
|
||||
def test_all_modes_exist(self):
|
||||
"""Test all expected modes exist."""
|
||||
modes = [m for m in ExecutionMode]
|
||||
assert len(modes) == 3
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ScenarioStatus Enum Tests
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestScenarioStatus:
|
||||
"""Tests for ScenarioStatus enum."""
|
||||
|
||||
def test_pending_value(self):
|
||||
"""Test PENDING status value."""
|
||||
assert ScenarioStatus.PENDING.value == "pending"
|
||||
|
||||
def test_running_value(self):
|
||||
"""Test RUNNING status value."""
|
||||
assert ScenarioStatus.RUNNING.value == "running"
|
||||
|
||||
def test_completed_value(self):
|
||||
"""Test COMPLETED status value."""
|
||||
assert ScenarioStatus.COMPLETED.value == "completed"
|
||||
|
||||
def test_failed_value(self):
|
||||
"""Test FAILED status value."""
|
||||
assert ScenarioStatus.FAILED.value == "failed"
|
||||
|
||||
def test_cancelled_value(self):
|
||||
"""Test CANCELLED status value."""
|
||||
assert ScenarioStatus.CANCELLED.value == "cancelled"
|
||||
|
||||
def test_all_statuses_exist(self):
|
||||
"""Test all expected statuses exist."""
|
||||
statuses = [s for s in ScenarioStatus]
|
||||
assert len(statuses) == 5
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ScenarioConfig Tests
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestScenarioConfig:
|
||||
"""Tests for ScenarioConfig dataclass."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test creating config with defaults."""
|
||||
config = ScenarioConfig()
|
||||
assert config.scenario_id is not None
|
||||
assert config.name.startswith("Scenario-")
|
||||
assert config.initial_capital == Decimal("100000")
|
||||
assert config.symbols == []
|
||||
assert config.strategy_params == {}
|
||||
assert config.risk_params == {}
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test creating config with custom values."""
|
||||
config = ScenarioConfig(
|
||||
name="Bull Market Test",
|
||||
start_date=date(2023, 1, 1),
|
||||
end_date=date(2023, 12, 31),
|
||||
initial_capital=Decimal("50000"),
|
||||
symbols=["AAPL", "GOOGL"],
|
||||
strategy_params={"leverage": 1.5},
|
||||
risk_params={"max_drawdown": 0.2},
|
||||
)
|
||||
assert config.name == "Bull Market Test"
|
||||
assert config.start_date == date(2023, 1, 1)
|
||||
assert config.end_date == date(2023, 12, 31)
|
||||
assert config.initial_capital == Decimal("50000")
|
||||
assert config.symbols == ["AAPL", "GOOGL"]
|
||||
assert config.strategy_params["leverage"] == 1.5
|
||||
assert config.risk_params["max_drawdown"] == 0.2
|
||||
|
||||
def test_auto_generated_name(self):
|
||||
"""Test auto-generated name from scenario_id."""
|
||||
config = ScenarioConfig()
|
||||
# Name should start with "Scenario-" and contain part of the ID
|
||||
assert config.name.startswith("Scenario-")
|
||||
assert len(config.name) > 8
|
||||
|
||||
def test_explicit_scenario_id(self):
|
||||
"""Test explicit scenario ID."""
|
||||
config = ScenarioConfig(scenario_id="test-123")
|
||||
assert config.scenario_id == "test-123"
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ScenarioResult Tests
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestScenarioResult:
|
||||
"""Tests for ScenarioResult dataclass."""
|
||||
|
||||
def test_successful_result(self):
|
||||
"""Test creating a successful result."""
|
||||
result = ScenarioResult(
|
||||
scenario_id="test-1",
|
||||
scenario_name="Test",
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
final_value=Decimal("110000"),
|
||||
total_return=Decimal("0.10"),
|
||||
)
|
||||
assert result.is_successful is True
|
||||
assert result.is_finished is True
|
||||
|
||||
def test_failed_result(self):
|
||||
"""Test creating a failed result."""
|
||||
result = ScenarioResult(
|
||||
scenario_id="test-1",
|
||||
scenario_name="Test",
|
||||
status=ScenarioStatus.FAILED,
|
||||
start_time=datetime.now(),
|
||||
error_message="Simulation error",
|
||||
)
|
||||
assert result.is_successful is False
|
||||
assert result.is_finished is True
|
||||
assert result.error_message == "Simulation error"
|
||||
|
||||
def test_pending_result(self):
|
||||
"""Test pending result properties."""
|
||||
result = ScenarioResult(
|
||||
scenario_id="test-1",
|
||||
scenario_name="Test",
|
||||
status=ScenarioStatus.PENDING,
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
assert result.is_successful is False
|
||||
assert result.is_finished is False
|
||||
|
||||
def test_running_result(self):
|
||||
"""Test running result properties."""
|
||||
result = ScenarioResult(
|
||||
scenario_id="test-1",
|
||||
scenario_name="Test",
|
||||
status=ScenarioStatus.RUNNING,
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
assert result.is_successful is False
|
||||
assert result.is_finished is False
|
||||
|
||||
def test_cancelled_result(self):
|
||||
"""Test cancelled result properties."""
|
||||
result = ScenarioResult(
|
||||
scenario_id="test-1",
|
||||
scenario_name="Test",
|
||||
status=ScenarioStatus.CANCELLED,
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
assert result.is_successful is False
|
||||
assert result.is_finished is True
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# RunnerProgress Tests
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestRunnerProgress:
|
||||
"""Tests for RunnerProgress dataclass."""
|
||||
|
||||
def test_progress_percent(self):
|
||||
"""Test progress percentage calculation."""
|
||||
progress = RunnerProgress(
|
||||
total_scenarios=10,
|
||||
completed=3,
|
||||
failed=2,
|
||||
)
|
||||
assert progress.progress_percent == 50.0
|
||||
|
||||
def test_progress_percent_zero_total(self):
|
||||
"""Test progress with zero total scenarios."""
|
||||
progress = RunnerProgress(total_scenarios=0)
|
||||
assert progress.progress_percent == 100.0
|
||||
|
||||
def test_is_complete(self):
|
||||
"""Test completion detection."""
|
||||
progress = RunnerProgress(
|
||||
total_scenarios=5,
|
||||
completed=3,
|
||||
failed=2,
|
||||
)
|
||||
assert progress.is_complete is True
|
||||
|
||||
def test_not_complete(self):
|
||||
"""Test incomplete detection."""
|
||||
progress = RunnerProgress(
|
||||
total_scenarios=5,
|
||||
completed=2,
|
||||
failed=1,
|
||||
)
|
||||
assert progress.is_complete is False
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ScenarioRunner Tests - Sequential Execution
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestScenarioRunnerSequential:
|
||||
"""Tests for ScenarioRunner sequential execution."""
|
||||
|
||||
def test_run_empty_list(self):
|
||||
"""Test running with empty scenario list."""
|
||||
def executor(config):
|
||||
return ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
|
||||
runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL)
|
||||
results = runner.run([])
|
||||
assert results == []
|
||||
|
||||
def test_run_single_scenario(self):
|
||||
"""Test running a single scenario."""
|
||||
def executor(config):
|
||||
return ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
final_value=Decimal("110000"),
|
||||
total_return=Decimal("0.10"),
|
||||
)
|
||||
|
||||
runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL)
|
||||
scenarios = [ScenarioConfig(name="Test1")]
|
||||
results = runner.run(scenarios)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].is_successful
|
||||
assert results[0].scenario_name == "Test1"
|
||||
|
||||
def test_run_multiple_scenarios(self):
|
||||
"""Test running multiple scenarios sequentially."""
|
||||
call_order = []
|
||||
|
||||
def executor(config):
|
||||
call_order.append(config.name)
|
||||
return ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
|
||||
runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL)
|
||||
scenarios = [
|
||||
ScenarioConfig(name="Test1"),
|
||||
ScenarioConfig(name="Test2"),
|
||||
ScenarioConfig(name="Test3"),
|
||||
]
|
||||
results = runner.run(scenarios)
|
||||
|
||||
assert len(results) == 3
|
||||
assert call_order == ["Test1", "Test2", "Test3"]
|
||||
assert all(r.is_successful for r in results)
|
||||
|
||||
def test_run_with_executor_exception(self):
|
||||
"""Test handling executor exceptions."""
|
||||
def executor(config):
|
||||
if config.name == "FailMe":
|
||||
raise ValueError("Simulated error")
|
||||
return ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
|
||||
runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL)
|
||||
scenarios = [
|
||||
ScenarioConfig(name="Test1"),
|
||||
ScenarioConfig(name="FailMe"),
|
||||
ScenarioConfig(name="Test3"),
|
||||
]
|
||||
results = runner.run(scenarios)
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[0].is_successful
|
||||
assert results[1].status == ScenarioStatus.FAILED
|
||||
assert "Simulated error" in results[1].error_message
|
||||
assert results[2].is_successful
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ScenarioRunner Tests - Parallel Execution
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestScenarioRunnerParallel:
|
||||
"""Tests for ScenarioRunner parallel execution."""
|
||||
|
||||
def test_run_threaded(self):
|
||||
"""Test running scenarios in parallel using threads."""
|
||||
def executor(config):
|
||||
time.sleep(0.01) # Small delay to ensure parallelism
|
||||
return ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
|
||||
runner = ScenarioRunner(
|
||||
executor=executor,
|
||||
mode=ExecutionMode.THREADED,
|
||||
max_workers=4,
|
||||
)
|
||||
scenarios = [ScenarioConfig(name=f"Test{i}") for i in range(8)]
|
||||
|
||||
start = time.time()
|
||||
results = runner.run(scenarios)
|
||||
elapsed = time.time() - start
|
||||
|
||||
assert len(results) == 8
|
||||
assert all(r.is_successful for r in results)
|
||||
# With 4 workers and 8 scenarios at 0.01s each,
|
||||
# parallel should be faster than sequential (0.08s)
|
||||
# Allow some overhead
|
||||
assert elapsed < 0.1
|
||||
|
||||
def test_results_order_preserved(self):
|
||||
"""Test that results match input order."""
|
||||
import random
|
||||
|
||||
def executor(config):
|
||||
# Random delay to encourage out-of-order completion
|
||||
time.sleep(random.uniform(0.001, 0.01))
|
||||
return ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
final_value=Decimal(config.metadata.get("index", 0)),
|
||||
)
|
||||
|
||||
runner = ScenarioRunner(
|
||||
executor=executor,
|
||||
mode=ExecutionMode.THREADED,
|
||||
max_workers=4,
|
||||
)
|
||||
scenarios = [
|
||||
ScenarioConfig(name=f"Test{i}", metadata={"index": i})
|
||||
for i in range(10)
|
||||
]
|
||||
results = runner.run(scenarios)
|
||||
|
||||
assert len(results) == 10
|
||||
for i, result in enumerate(results):
|
||||
assert result.scenario_name == f"Test{i}"
|
||||
assert result.final_value == Decimal(i)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ScenarioRunner Tests - Progress Tracking
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestScenarioRunnerProgress:
|
||||
"""Tests for ScenarioRunner progress tracking."""
|
||||
|
||||
def test_progress_callback(self):
|
||||
"""Test progress callback invocation."""
|
||||
progress_updates = []
|
||||
|
||||
def on_progress(progress):
|
||||
progress_updates.append({
|
||||
"completed": progress.completed,
|
||||
"total": progress.total_scenarios,
|
||||
})
|
||||
|
||||
def executor(config):
|
||||
return ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
|
||||
runner = ScenarioRunner(executor=executor, mode=ExecutionMode.SEQUENTIAL)
|
||||
scenarios = [ScenarioConfig(name=f"Test{i}") for i in range(3)]
|
||||
runner.run(scenarios, progress_callback=on_progress)
|
||||
|
||||
# Should have received progress updates
|
||||
assert len(progress_updates) > 0
|
||||
# Final update should show all completed
|
||||
final = progress_updates[-1]
|
||||
assert final["completed"] == 3
|
||||
|
||||
def test_get_progress(self):
|
||||
"""Test get_progress method."""
|
||||
def executor(config):
|
||||
time.sleep(0.01)
|
||||
return ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
|
||||
runner = ScenarioRunner(
|
||||
executor=executor,
|
||||
mode=ExecutionMode.THREADED,
|
||||
max_workers=2,
|
||||
)
|
||||
|
||||
scenarios = [ScenarioConfig(name=f"Test{i}") for i in range(4)]
|
||||
|
||||
# Start in background
|
||||
results = []
|
||||
def run_in_thread():
|
||||
nonlocal results
|
||||
results = runner.run(scenarios)
|
||||
|
||||
thread = threading.Thread(target=run_in_thread)
|
||||
thread.start()
|
||||
|
||||
# Give it time to start
|
||||
time.sleep(0.005)
|
||||
progress = runner.get_progress()
|
||||
assert progress.total_scenarios == 4
|
||||
|
||||
thread.join()
|
||||
assert len(results) == 4
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ScenarioRunner Tests - Cancellation
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestScenarioRunnerCancellation:
|
||||
"""Tests for ScenarioRunner cancellation."""
|
||||
|
||||
def test_cancel_pending_scenarios(self):
|
||||
"""Test cancelling pending scenarios."""
|
||||
executed = []
|
||||
|
||||
def executor(config):
|
||||
executed.append(config.name)
|
||||
time.sleep(0.02)
|
||||
return ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
|
||||
runner = ScenarioRunner(
|
||||
executor=executor,
|
||||
mode=ExecutionMode.SEQUENTIAL,
|
||||
)
|
||||
|
||||
scenarios = [ScenarioConfig(name=f"Test{i}") for i in range(5)]
|
||||
|
||||
# Cancel after short delay
|
||||
def cancel_after_delay():
|
||||
time.sleep(0.03)
|
||||
runner.cancel()
|
||||
|
||||
cancel_thread = threading.Thread(target=cancel_after_delay)
|
||||
cancel_thread.start()
|
||||
|
||||
results = runner.run(scenarios)
|
||||
cancel_thread.join()
|
||||
|
||||
# Some should be completed, some cancelled
|
||||
completed = [r for r in results if r.status == ScenarioStatus.COMPLETED]
|
||||
cancelled = [r for r in results if r.status == ScenarioStatus.CANCELLED]
|
||||
|
||||
assert len(completed) + len(cancelled) == 5
|
||||
# At least the first one should have completed
|
||||
assert len(completed) >= 1
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# ScenarioBatchBuilder Tests
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestScenarioBatchBuilder:
|
||||
"""Tests for ScenarioBatchBuilder."""
|
||||
|
||||
def test_empty_build(self):
|
||||
"""Test building with no configuration."""
|
||||
builder = ScenarioBatchBuilder()
|
||||
scenarios = builder.build()
|
||||
assert len(scenarios) == 1 # One default scenario
|
||||
|
||||
def test_base_config(self):
|
||||
"""Test setting base configuration."""
|
||||
builder = ScenarioBatchBuilder()
|
||||
scenarios = (
|
||||
builder
|
||||
.with_base_config(
|
||||
initial_capital=Decimal("50000"),
|
||||
symbols=["AAPL", "GOOGL"],
|
||||
)
|
||||
.build()
|
||||
)
|
||||
assert len(scenarios) == 1
|
||||
assert scenarios[0].initial_capital == Decimal("50000")
|
||||
assert scenarios[0].symbols == ["AAPL", "GOOGL"]
|
||||
|
||||
def test_parameter_variation(self):
|
||||
"""Test varying a single parameter."""
|
||||
builder = ScenarioBatchBuilder()
|
||||
scenarios = (
|
||||
builder
|
||||
.vary_parameter("leverage", [0.5, 1.0, 1.5])
|
||||
.build()
|
||||
)
|
||||
assert len(scenarios) == 3
|
||||
leverages = [s.strategy_params["leverage"] for s in scenarios]
|
||||
assert leverages == [0.5, 1.0, 1.5]
|
||||
|
||||
def test_multiple_parameter_variations(self):
|
||||
"""Test varying multiple parameters (Cartesian product)."""
|
||||
builder = ScenarioBatchBuilder()
|
||||
scenarios = (
|
||||
builder
|
||||
.vary_parameter("leverage", [1.0, 2.0])
|
||||
.vary_parameter("stop_loss", [0.05, 0.10])
|
||||
.build()
|
||||
)
|
||||
assert len(scenarios) == 4 # 2 x 2
|
||||
|
||||
combinations = [
|
||||
(s.strategy_params["leverage"], s.strategy_params["stop_loss"])
|
||||
for s in scenarios
|
||||
]
|
||||
expected = [
|
||||
(1.0, 0.05),
|
||||
(1.0, 0.10),
|
||||
(2.0, 0.05),
|
||||
(2.0, 0.10),
|
||||
]
|
||||
assert sorted(combinations) == sorted(expected)
|
||||
|
||||
def test_explicit_date_ranges(self):
|
||||
"""Test setting explicit date ranges."""
|
||||
builder = ScenarioBatchBuilder()
|
||||
scenarios = (
|
||||
builder
|
||||
.with_date_ranges([
|
||||
(date(2020, 1, 1), date(2020, 12, 31)),
|
||||
(date(2021, 1, 1), date(2021, 12, 31)),
|
||||
])
|
||||
.build()
|
||||
)
|
||||
assert len(scenarios) == 2
|
||||
assert scenarios[0].start_date == date(2020, 1, 1)
|
||||
assert scenarios[1].start_date == date(2021, 1, 1)
|
||||
|
||||
def test_clear_builder(self):
|
||||
"""Test clearing builder configuration."""
|
||||
builder = ScenarioBatchBuilder()
|
||||
builder.with_base_config(initial_capital=Decimal("50000"))
|
||||
builder.vary_parameter("leverage", [1.0, 2.0])
|
||||
builder.clear()
|
||||
|
||||
scenarios = builder.build()
|
||||
assert len(scenarios) == 1
|
||||
assert scenarios[0].initial_capital == Decimal("100000") # Default
|
||||
|
||||
def test_scenario_names_generated(self):
|
||||
"""Test that scenario names are auto-generated."""
|
||||
builder = ScenarioBatchBuilder()
|
||||
scenarios = (
|
||||
builder
|
||||
.vary_parameter("leverage", [1.0, 2.0])
|
||||
.build()
|
||||
)
|
||||
assert "leverage=1.0" in scenarios[0].name
|
||||
assert "leverage=2.0" in scenarios[1].name
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Result Aggregation Tests
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestResultAggregation:
|
||||
"""Tests for result aggregation."""
|
||||
|
||||
def test_aggregate_empty_results(self):
|
||||
"""Test aggregating empty results list."""
|
||||
agg = aggregate_results([])
|
||||
assert agg["total_scenarios"] == 0
|
||||
assert agg["successful"] == 0
|
||||
assert agg["failed"] == 0
|
||||
|
||||
def test_aggregate_successful_results(self):
|
||||
"""Test aggregating successful results."""
|
||||
results = [
|
||||
ScenarioResult(
|
||||
scenario_id=f"test-{i}",
|
||||
scenario_name=f"Test{i}",
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
total_return=Decimal(f"0.{i+1}"),
|
||||
max_drawdown=Decimal(f"-0.0{i+1}"),
|
||||
duration_seconds=1.0,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
agg = aggregate_results(results)
|
||||
assert agg["total_scenarios"] == 3
|
||||
assert agg["successful"] == 3
|
||||
assert agg["failed"] == 0
|
||||
assert agg["success_rate"] == 1.0
|
||||
assert "avg_return" in agg
|
||||
assert "min_return" in agg
|
||||
assert "max_return" in agg
|
||||
|
||||
def test_aggregate_mixed_results(self):
|
||||
"""Test aggregating mixed success/failure results."""
|
||||
results = [
|
||||
ScenarioResult(
|
||||
scenario_id="test-1",
|
||||
scenario_name="Success1",
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
total_return=Decimal("0.10"),
|
||||
),
|
||||
ScenarioResult(
|
||||
scenario_id="test-2",
|
||||
scenario_name="Failed1",
|
||||
status=ScenarioStatus.FAILED,
|
||||
start_time=datetime.now(),
|
||||
error_message="Error",
|
||||
),
|
||||
ScenarioResult(
|
||||
scenario_id="test-3",
|
||||
scenario_name="Success2",
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
total_return=Decimal("0.20"),
|
||||
),
|
||||
]
|
||||
|
||||
agg = aggregate_results(results)
|
||||
assert agg["total_scenarios"] == 3
|
||||
assert agg["successful"] == 2
|
||||
assert agg["failed"] == 1
|
||||
assert agg["success_rate"] == pytest.approx(0.667, rel=0.01)
|
||||
|
||||
def test_aggregate_best_worst_scenarios(self):
|
||||
"""Test best/worst scenario identification."""
|
||||
results = [
|
||||
ScenarioResult(
|
||||
scenario_id="test-1",
|
||||
scenario_name="Low",
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
total_return=Decimal("0.05"),
|
||||
),
|
||||
ScenarioResult(
|
||||
scenario_id="test-2",
|
||||
scenario_name="High",
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
total_return=Decimal("0.25"),
|
||||
),
|
||||
ScenarioResult(
|
||||
scenario_id="test-3",
|
||||
scenario_name="Mid",
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
total_return=Decimal("0.15"),
|
||||
),
|
||||
]
|
||||
|
||||
agg = aggregate_results(results)
|
||||
assert agg["best_scenario"]["name"] == "High"
|
||||
assert agg["worst_scenario"]["name"] == "Low"
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Module Import Tests
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestModuleImports:
|
||||
"""Tests for module imports."""
|
||||
|
||||
def test_import_from_simulation_module(self):
|
||||
"""Test importing from simulation module."""
|
||||
from tradingagents.simulation import (
|
||||
ExecutionMode,
|
||||
ScenarioStatus,
|
||||
ScenarioConfig,
|
||||
ScenarioResult,
|
||||
RunnerProgress,
|
||||
ScenarioRunner,
|
||||
ScenarioBatchBuilder,
|
||||
aggregate_results,
|
||||
)
|
||||
assert ExecutionMode is not None
|
||||
assert ScenarioStatus is not None
|
||||
assert ScenarioConfig is not None
|
||||
assert ScenarioResult is not None
|
||||
assert RunnerProgress is not None
|
||||
assert ScenarioRunner is not None
|
||||
assert ScenarioBatchBuilder is not None
|
||||
assert aggregate_results is not None
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Integration Tests
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TestScenarioRunnerIntegration:
|
||||
"""Integration tests for ScenarioRunner."""
|
||||
|
||||
def test_full_simulation_workflow(self):
|
||||
"""Test complete simulation workflow."""
|
||||
# Create executor that simulates trading
|
||||
def trading_simulator(config: ScenarioConfig) -> ScenarioResult:
|
||||
# Simulate based on strategy params
|
||||
leverage = config.strategy_params.get("leverage", 1.0)
|
||||
base_return = 0.08 # 8% base return
|
||||
final_return = base_return * leverage
|
||||
|
||||
final_value = config.initial_capital * Decimal(str(1 + final_return))
|
||||
|
||||
return ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.COMPLETED,
|
||||
start_time=datetime.now(),
|
||||
end_time=datetime.now(),
|
||||
final_value=final_value.quantize(Decimal("0.01")),
|
||||
total_return=Decimal(str(final_return)).quantize(Decimal("0.0001")),
|
||||
max_drawdown=Decimal("-0.10"),
|
||||
trades_executed=25,
|
||||
)
|
||||
|
||||
# Build scenarios with varying leverage
|
||||
scenarios = (
|
||||
ScenarioBatchBuilder()
|
||||
.with_base_config(initial_capital=Decimal("100000"))
|
||||
.vary_parameter("leverage", [0.5, 1.0, 1.5, 2.0])
|
||||
.build()
|
||||
)
|
||||
|
||||
# Run simulations
|
||||
runner = ScenarioRunner(
|
||||
executor=trading_simulator,
|
||||
mode=ExecutionMode.THREADED,
|
||||
max_workers=4,
|
||||
)
|
||||
|
||||
progress_updates = []
|
||||
results = runner.run(
|
||||
scenarios,
|
||||
progress_callback=lambda p: progress_updates.append(p.completed),
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert len(results) == 4
|
||||
assert all(r.is_successful for r in results)
|
||||
|
||||
# Aggregate and analyze
|
||||
agg = aggregate_results(results)
|
||||
assert agg["total_scenarios"] == 4
|
||||
assert agg["successful"] == 4
|
||||
assert agg["best_scenario"]["name"] == "leverage=2.0"
|
||||
assert agg["worst_scenario"]["name"] == "leverage=0.5"
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
"""Simulation module for portfolio simulations and backtesting.
|
||||
|
||||
This module provides simulation capabilities including:
|
||||
- Parallel scenario execution
|
||||
- Parameter sweep analysis
|
||||
- Strategy comparison
|
||||
- Economic regime simulation
|
||||
|
||||
Issue #33: [SIM-32] Scenario runner - parallel portfolio simulations
|
||||
|
||||
Submodules:
|
||||
scenario_runner: Core scenario execution framework
|
||||
|
||||
Classes:
|
||||
Enums:
|
||||
- ExecutionMode: Parallel execution mode (sequential, threaded, process)
|
||||
- ScenarioStatus: Status of a scenario run
|
||||
|
||||
Data Classes:
|
||||
- ScenarioConfig: Configuration for a simulation scenario
|
||||
- ScenarioResult: Result from a scenario simulation
|
||||
- RunnerProgress: Progress information for batch runs
|
||||
|
||||
Main Classes:
|
||||
- ScenarioRunner: Runner for parallel portfolio simulations
|
||||
- ScenarioBatchBuilder: Builder for creating scenario batches
|
||||
|
||||
Protocols:
|
||||
- ScenarioExecutor: Protocol for scenario execution functions
|
||||
|
||||
Utility Functions:
|
||||
- aggregate_results: Aggregate results from multiple scenarios
|
||||
|
||||
Example:
|
||||
>>> from tradingagents.simulation import (
|
||||
... ScenarioRunner,
|
||||
... ScenarioConfig,
|
||||
... ScenarioResult,
|
||||
... ScenarioStatus,
|
||||
... ExecutionMode,
|
||||
... )
|
||||
>>> from datetime import datetime
|
||||
>>> from decimal import Decimal
|
||||
>>>
|
||||
>>> def simple_executor(config: ScenarioConfig) -> ScenarioResult:
|
||||
... return ScenarioResult(
|
||||
... scenario_id=config.scenario_id,
|
||||
... scenario_name=config.name,
|
||||
... status=ScenarioStatus.COMPLETED,
|
||||
... start_time=datetime.now(),
|
||||
... final_value=config.initial_capital * Decimal("1.1"),
|
||||
... total_return=Decimal("0.1"),
|
||||
... )
|
||||
>>>
|
||||
>>> runner = ScenarioRunner(executor=simple_executor)
|
||||
>>> scenarios = [ScenarioConfig(name="Test1"), ScenarioConfig(name="Test2")]
|
||||
>>> results = runner.run(scenarios)
|
||||
"""
|
||||
|
||||
from .scenario_runner import (
|
||||
# Enums
|
||||
ExecutionMode,
|
||||
ScenarioStatus,
|
||||
# Data Classes
|
||||
ScenarioConfig,
|
||||
ScenarioResult,
|
||||
RunnerProgress,
|
||||
# Main Classes
|
||||
ScenarioRunner,
|
||||
ScenarioBatchBuilder,
|
||||
# Protocols
|
||||
ScenarioExecutor,
|
||||
# Types
|
||||
ProgressCallback,
|
||||
# Utility Functions
|
||||
aggregate_results,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Enums
|
||||
"ExecutionMode",
|
||||
"ScenarioStatus",
|
||||
# Data Classes
|
||||
"ScenarioConfig",
|
||||
"ScenarioResult",
|
||||
"RunnerProgress",
|
||||
# Main Classes
|
||||
"ScenarioRunner",
|
||||
"ScenarioBatchBuilder",
|
||||
# Protocols
|
||||
"ScenarioExecutor",
|
||||
# Types
|
||||
"ProgressCallback",
|
||||
# Utility Functions
|
||||
"aggregate_results",
|
||||
]
|
||||
|
|
@ -0,0 +1,689 @@
|
|||
"""Scenario Runner for parallel portfolio simulations.
|
||||
|
||||
This module provides scenario-based simulation capabilities including:
|
||||
- Parallel execution of multiple portfolio scenarios
|
||||
- Configurable simulation parameters
|
||||
- Progress tracking and callbacks
|
||||
- Result aggregation and comparison
|
||||
|
||||
Issue #33: [SIM-32] Scenario runner - parallel portfolio simulations
|
||||
|
||||
Design Principles:
|
||||
- Thread-safe parallel execution
|
||||
- Configurable concurrency limits
|
||||
- Memory-efficient result handling
|
||||
- Progress reporting callbacks
|
||||
"""
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
|
||||
import copy
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
|
||||
class ExecutionMode(Enum):
|
||||
"""Mode of parallel execution."""
|
||||
SEQUENTIAL = "sequential" # Run scenarios one at a time
|
||||
THREADED = "threaded" # Use thread pool (for I/O bound)
|
||||
PROCESS = "process" # Use process pool (for CPU bound)
|
||||
|
||||
|
||||
class ScenarioStatus(Enum):
|
||||
"""Status of a scenario run."""
|
||||
PENDING = "pending" # Not yet started
|
||||
RUNNING = "running" # Currently executing
|
||||
COMPLETED = "completed" # Finished successfully
|
||||
FAILED = "failed" # Finished with error
|
||||
CANCELLED = "cancelled" # Cancelled before completion
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenarioConfig:
|
||||
"""Configuration for a simulation scenario.
|
||||
|
||||
Attributes:
|
||||
scenario_id: Unique identifier for the scenario
|
||||
name: Human-readable name
|
||||
start_date: Simulation start date
|
||||
end_date: Simulation end date
|
||||
initial_capital: Starting capital
|
||||
symbols: List of symbols to include
|
||||
strategy_params: Strategy-specific parameters
|
||||
risk_params: Risk management parameters
|
||||
metadata: Additional scenario data
|
||||
"""
|
||||
scenario_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
name: str = ""
|
||||
start_date: Optional[date] = None
|
||||
end_date: Optional[date] = None
|
||||
initial_capital: Decimal = Decimal("100000")
|
||||
symbols: List[str] = field(default_factory=list)
|
||||
strategy_params: Dict[str, Any] = field(default_factory=dict)
|
||||
risk_params: Dict[str, Any] = field(default_factory=dict)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set default name if not provided."""
|
||||
if not self.name:
|
||||
self.name = f"Scenario-{self.scenario_id[:8]}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenarioResult:
|
||||
"""Result from a scenario simulation.
|
||||
|
||||
Attributes:
|
||||
scenario_id: ID of the scenario that was run
|
||||
scenario_name: Name of the scenario
|
||||
status: Final status of the run
|
||||
start_time: When simulation started
|
||||
end_time: When simulation ended
|
||||
duration_seconds: Total runtime in seconds
|
||||
final_value: Final portfolio value
|
||||
total_return: Total return as decimal (0.10 = 10%)
|
||||
trades_executed: Number of trades made
|
||||
max_drawdown: Maximum drawdown experienced
|
||||
sharpe_ratio: Sharpe ratio if calculable
|
||||
error_message: Error message if failed
|
||||
portfolio_history: Time series of portfolio values
|
||||
trade_history: List of trades executed
|
||||
metrics: Additional performance metrics
|
||||
metadata: Additional result data
|
||||
"""
|
||||
scenario_id: str
|
||||
scenario_name: str
|
||||
status: ScenarioStatus
|
||||
start_time: datetime
|
||||
end_time: Optional[datetime] = None
|
||||
duration_seconds: float = 0.0
|
||||
final_value: Decimal = Decimal("0")
|
||||
total_return: Decimal = Decimal("0")
|
||||
trades_executed: int = 0
|
||||
max_drawdown: Decimal = Decimal("0")
|
||||
sharpe_ratio: Optional[Decimal] = None
|
||||
error_message: Optional[str] = None
|
||||
portfolio_history: List[Tuple[date, Decimal]] = field(default_factory=list)
|
||||
trade_history: List[Dict[str, Any]] = field(default_factory=list)
|
||||
metrics: Dict[str, Any] = field(default_factory=dict)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def is_successful(self) -> bool:
|
||||
"""Check if scenario completed successfully."""
|
||||
return self.status == ScenarioStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def is_finished(self) -> bool:
|
||||
"""Check if scenario has finished (success or failure)."""
|
||||
return self.status in (
|
||||
ScenarioStatus.COMPLETED,
|
||||
ScenarioStatus.FAILED,
|
||||
ScenarioStatus.CANCELLED,
|
||||
)
|
||||
|
||||
|
||||
class ScenarioExecutor(Protocol):
|
||||
"""Protocol for scenario execution functions.
|
||||
|
||||
Implementations should take a ScenarioConfig and return a ScenarioResult.
|
||||
"""
|
||||
|
||||
def __call__(self, config: ScenarioConfig) -> ScenarioResult:
|
||||
"""Execute a scenario and return results."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunnerProgress:
|
||||
"""Progress information for a scenario run batch.
|
||||
|
||||
Attributes:
|
||||
total_scenarios: Total number of scenarios
|
||||
completed: Number of completed scenarios
|
||||
failed: Number of failed scenarios
|
||||
running: Number of currently running scenarios
|
||||
pending: Number of pending scenarios
|
||||
start_time: When the batch started
|
||||
estimated_remaining_seconds: Estimated time remaining
|
||||
"""
|
||||
total_scenarios: int
|
||||
completed: int = 0
|
||||
failed: int = 0
|
||||
running: int = 0
|
||||
pending: int = 0
|
||||
start_time: Optional[datetime] = None
|
||||
estimated_remaining_seconds: Optional[float] = None
|
||||
|
||||
@property
|
||||
def progress_percent(self) -> float:
|
||||
"""Calculate completion percentage."""
|
||||
if self.total_scenarios == 0:
|
||||
return 100.0
|
||||
return (self.completed + self.failed) / self.total_scenarios * 100
|
||||
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if all scenarios are finished."""
|
||||
return (self.completed + self.failed) >= self.total_scenarios
|
||||
|
||||
|
||||
ProgressCallback = Callable[[RunnerProgress], None]
|
||||
|
||||
|
||||
class ScenarioRunner:
|
||||
"""Runner for parallel portfolio simulations.
|
||||
|
||||
Executes multiple simulation scenarios in parallel using configurable
|
||||
execution modes (sequential, threaded, or process-based).
|
||||
|
||||
Example:
|
||||
>>> def simulate(config: ScenarioConfig) -> ScenarioResult:
|
||||
... # Your simulation logic
|
||||
... return ScenarioResult(
|
||||
... scenario_id=config.scenario_id,
|
||||
... scenario_name=config.name,
|
||||
... status=ScenarioStatus.COMPLETED,
|
||||
... start_time=datetime.now(),
|
||||
... )
|
||||
...
|
||||
>>> runner = ScenarioRunner(executor=simulate)
|
||||
>>> scenarios = [
|
||||
... ScenarioConfig(name="Bull Market", strategy_params={"leverage": 1.5}),
|
||||
... ScenarioConfig(name="Bear Market", strategy_params={"leverage": 0.5}),
|
||||
... ]
|
||||
>>> results = runner.run(scenarios)
|
||||
>>> print(f"Completed: {len([r for r in results if r.is_successful])}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executor: ScenarioExecutor,
|
||||
mode: ExecutionMode = ExecutionMode.THREADED,
|
||||
max_workers: Optional[int] = None,
|
||||
timeout_seconds: Optional[float] = None,
|
||||
):
|
||||
"""Initialize the scenario runner.
|
||||
|
||||
Args:
|
||||
executor: Function that executes a single scenario
|
||||
mode: Execution mode (sequential, threaded, process)
|
||||
max_workers: Maximum number of parallel workers (None = auto)
|
||||
timeout_seconds: Timeout per scenario (None = no timeout)
|
||||
"""
|
||||
self.executor = executor
|
||||
self.mode = mode
|
||||
self.max_workers = max_workers
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self._lock = threading.Lock()
|
||||
self._cancelled = False
|
||||
self._progress = RunnerProgress(total_scenarios=0)
|
||||
self._progress_callbacks: List[ProgressCallback] = []
|
||||
|
||||
def add_progress_callback(self, callback: ProgressCallback) -> None:
|
||||
"""Add a callback for progress updates.
|
||||
|
||||
Args:
|
||||
callback: Function to call with progress updates
|
||||
"""
|
||||
with self._lock:
|
||||
self._progress_callbacks.append(callback)
|
||||
|
||||
def remove_progress_callback(self, callback: ProgressCallback) -> None:
|
||||
"""Remove a progress callback.
|
||||
|
||||
Args:
|
||||
callback: Callback to remove
|
||||
"""
|
||||
with self._lock:
|
||||
if callback in self._progress_callbacks:
|
||||
self._progress_callbacks.remove(callback)
|
||||
|
||||
def _notify_progress(self) -> None:
|
||||
"""Notify all registered progress callbacks."""
|
||||
with self._lock:
|
||||
callbacks = self._progress_callbacks.copy()
|
||||
progress = copy.copy(self._progress)
|
||||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback(progress)
|
||||
except Exception:
|
||||
pass # Don't let callback errors affect execution
|
||||
|
||||
def _update_progress(
|
||||
self,
|
||||
completed_delta: int = 0,
|
||||
failed_delta: int = 0,
|
||||
running_delta: int = 0,
|
||||
pending_delta: int = 0,
|
||||
) -> None:
|
||||
"""Update progress counters thread-safely."""
|
||||
with self._lock:
|
||||
self._progress.completed += completed_delta
|
||||
self._progress.failed += failed_delta
|
||||
self._progress.running += running_delta
|
||||
self._progress.pending += pending_delta
|
||||
|
||||
# Estimate remaining time
|
||||
if self._progress.start_time and self._progress.completed > 0:
|
||||
elapsed = (datetime.now() - self._progress.start_time).total_seconds()
|
||||
avg_time = elapsed / self._progress.completed
|
||||
remaining = self._progress.pending + self._progress.running
|
||||
self._progress.estimated_remaining_seconds = avg_time * remaining
|
||||
|
||||
self._notify_progress()
|
||||
|
||||
def _execute_scenario(self, config: ScenarioConfig) -> ScenarioResult:
|
||||
"""Execute a single scenario with error handling.
|
||||
|
||||
Args:
|
||||
config: Scenario configuration
|
||||
|
||||
Returns:
|
||||
ScenarioResult with outcome
|
||||
"""
|
||||
start_time = datetime.now()
|
||||
|
||||
# Check if cancelled
|
||||
if self._cancelled:
|
||||
return ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.CANCELLED,
|
||||
start_time=start_time,
|
||||
end_time=datetime.now(),
|
||||
)
|
||||
|
||||
self._update_progress(running_delta=1, pending_delta=-1)
|
||||
|
||||
try:
|
||||
result = self.executor(config)
|
||||
result.end_time = datetime.now()
|
||||
result.duration_seconds = (result.end_time - start_time).total_seconds()
|
||||
|
||||
if result.status == ScenarioStatus.COMPLETED:
|
||||
self._update_progress(completed_delta=1, running_delta=-1)
|
||||
else:
|
||||
self._update_progress(failed_delta=1, running_delta=-1)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
end_time = datetime.now()
|
||||
self._update_progress(failed_delta=1, running_delta=-1)
|
||||
return ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.FAILED,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
duration_seconds=(end_time - start_time).total_seconds(),
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
scenarios: List[ScenarioConfig],
|
||||
progress_callback: Optional[ProgressCallback] = None,
|
||||
) -> List[ScenarioResult]:
|
||||
"""Run multiple scenarios.
|
||||
|
||||
Args:
|
||||
scenarios: List of scenario configurations to run
|
||||
progress_callback: Optional callback for progress updates
|
||||
|
||||
Returns:
|
||||
List of results in the same order as input scenarios
|
||||
"""
|
||||
if not scenarios:
|
||||
return []
|
||||
|
||||
# Reset state
|
||||
self._cancelled = False
|
||||
self._progress = RunnerProgress(
|
||||
total_scenarios=len(scenarios),
|
||||
pending=len(scenarios),
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
self.add_progress_callback(progress_callback)
|
||||
|
||||
try:
|
||||
if self.mode == ExecutionMode.SEQUENTIAL:
|
||||
results = self._run_sequential(scenarios)
|
||||
elif self.mode == ExecutionMode.THREADED:
|
||||
results = self._run_parallel(scenarios, ThreadPoolExecutor)
|
||||
elif self.mode == ExecutionMode.PROCESS:
|
||||
results = self._run_parallel(scenarios, ProcessPoolExecutor)
|
||||
else:
|
||||
raise ValueError(f"Unknown execution mode: {self.mode}")
|
||||
|
||||
return results
|
||||
|
||||
finally:
|
||||
if progress_callback:
|
||||
self.remove_progress_callback(progress_callback)
|
||||
|
||||
def _run_sequential(
|
||||
self, scenarios: List[ScenarioConfig]
|
||||
) -> List[ScenarioResult]:
|
||||
"""Run scenarios sequentially.
|
||||
|
||||
Args:
|
||||
scenarios: List of scenarios to run
|
||||
|
||||
Returns:
|
||||
List of results in order
|
||||
"""
|
||||
results = []
|
||||
for config in scenarios:
|
||||
if self._cancelled:
|
||||
results.append(ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.CANCELLED,
|
||||
start_time=datetime.now(),
|
||||
))
|
||||
else:
|
||||
results.append(self._execute_scenario(config))
|
||||
return results
|
||||
|
||||
def _run_parallel(
|
||||
self,
|
||||
scenarios: List[ScenarioConfig],
|
||||
executor_class: type,
|
||||
) -> List[ScenarioResult]:
|
||||
"""Run scenarios in parallel using a pool executor.
|
||||
|
||||
Args:
|
||||
scenarios: List of scenarios to run
|
||||
executor_class: ThreadPoolExecutor or ProcessPoolExecutor
|
||||
|
||||
Returns:
|
||||
List of results in original order
|
||||
"""
|
||||
# Map scenario IDs to their original indices
|
||||
id_to_index = {config.scenario_id: i for i, config in enumerate(scenarios)}
|
||||
results = [None] * len(scenarios)
|
||||
|
||||
with executor_class(max_workers=self.max_workers) as pool:
|
||||
# Submit all scenarios
|
||||
future_to_id = {
|
||||
pool.submit(self._execute_scenario, config): config.scenario_id
|
||||
for config in scenarios
|
||||
}
|
||||
|
||||
# Collect results as they complete
|
||||
for future in as_completed(future_to_id, timeout=self.timeout_seconds):
|
||||
scenario_id = future_to_id[future]
|
||||
index = id_to_index[scenario_id]
|
||||
|
||||
try:
|
||||
result = future.result()
|
||||
results[index] = result
|
||||
except Exception as e:
|
||||
# Create error result for this scenario
|
||||
config = scenarios[index]
|
||||
results[index] = ScenarioResult(
|
||||
scenario_id=config.scenario_id,
|
||||
scenario_name=config.name,
|
||||
status=ScenarioStatus.FAILED,
|
||||
start_time=datetime.now(),
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel all pending scenarios.
|
||||
|
||||
Running scenarios will complete, but pending ones will be skipped.
|
||||
"""
|
||||
with self._lock:
|
||||
self._cancelled = True
|
||||
|
||||
def get_progress(self) -> RunnerProgress:
|
||||
"""Get current progress information.
|
||||
|
||||
Returns:
|
||||
Current progress state
|
||||
"""
|
||||
with self._lock:
|
||||
return copy.copy(self._progress)
|
||||
|
||||
|
||||
class ScenarioBatchBuilder:
|
||||
"""Builder for creating batches of scenario configurations.
|
||||
|
||||
Provides convenient methods for generating variations of scenarios
|
||||
for sensitivity analysis, parameter sweeps, etc.
|
||||
|
||||
Example:
|
||||
>>> builder = ScenarioBatchBuilder()
|
||||
>>> scenarios = (
|
||||
... builder
|
||||
... .with_base_config(symbols=["AAPL", "GOOGL"])
|
||||
... .vary_parameter("leverage", [0.5, 1.0, 1.5, 2.0])
|
||||
... .vary_date_range(
|
||||
... date(2020, 1, 1),
|
||||
... date(2023, 12, 31),
|
||||
... window_months=12,
|
||||
... )
|
||||
... .build()
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the batch builder."""
|
||||
self._base_config: Dict[str, Any] = {}
|
||||
self._parameter_variations: Dict[str, List[Any]] = {}
|
||||
self._date_ranges: List[Tuple[date, date]] = []
|
||||
|
||||
def with_base_config(self, **kwargs) -> "ScenarioBatchBuilder":
|
||||
"""Set base configuration for all scenarios.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration parameters
|
||||
|
||||
Returns:
|
||||
Self for chaining
|
||||
"""
|
||||
self._base_config.update(kwargs)
|
||||
return self
|
||||
|
||||
def vary_parameter(
|
||||
self, name: str, values: List[Any]
|
||||
) -> "ScenarioBatchBuilder":
|
||||
"""Add a parameter to vary across scenarios.
|
||||
|
||||
Args:
|
||||
name: Parameter name (in strategy_params or risk_params)
|
||||
values: List of values to use
|
||||
|
||||
Returns:
|
||||
Self for chaining
|
||||
"""
|
||||
self._parameter_variations[name] = values
|
||||
return self
|
||||
|
||||
def vary_date_range(
|
||||
self,
|
||||
start: date,
|
||||
end: date,
|
||||
window_months: int = 12,
|
||||
step_months: int = 3,
|
||||
) -> "ScenarioBatchBuilder":
|
||||
"""Add rolling date windows.
|
||||
|
||||
Args:
|
||||
start: Overall start date
|
||||
end: Overall end date
|
||||
window_months: Size of each window in months
|
||||
step_months: Step between windows in months
|
||||
|
||||
Returns:
|
||||
Self for chaining
|
||||
"""
|
||||
from datetime import timedelta
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
current_start = start
|
||||
while current_start < end:
|
||||
current_end = current_start + relativedelta(months=window_months)
|
||||
if current_end > end:
|
||||
current_end = end
|
||||
self._date_ranges.append((current_start, current_end))
|
||||
current_start += relativedelta(months=step_months)
|
||||
|
||||
return self
|
||||
|
||||
def with_date_ranges(
|
||||
self, ranges: List[Tuple[date, date]]
|
||||
) -> "ScenarioBatchBuilder":
|
||||
"""Set explicit date ranges.
|
||||
|
||||
Args:
|
||||
ranges: List of (start_date, end_date) tuples
|
||||
|
||||
Returns:
|
||||
Self for chaining
|
||||
"""
|
||||
self._date_ranges.extend(ranges)
|
||||
return self
|
||||
|
||||
def build(self) -> List[ScenarioConfig]:
|
||||
"""Build all scenario configurations.
|
||||
|
||||
Creates the Cartesian product of all variations.
|
||||
|
||||
Returns:
|
||||
List of ScenarioConfig objects
|
||||
"""
|
||||
import itertools
|
||||
|
||||
scenarios = []
|
||||
|
||||
# Get all parameter combinations
|
||||
param_names = list(self._parameter_variations.keys())
|
||||
param_values = [self._parameter_variations[name] for name in param_names]
|
||||
|
||||
if param_values:
|
||||
param_combinations = list(itertools.product(*param_values))
|
||||
else:
|
||||
param_combinations = [()]
|
||||
|
||||
# Get date ranges (use single None tuple if no ranges)
|
||||
date_ranges = self._date_ranges if self._date_ranges else [(None, None)]
|
||||
|
||||
# Generate all combinations
|
||||
for param_combo in param_combinations:
|
||||
for start_date, end_date in date_ranges:
|
||||
# Build scenario config
|
||||
config_dict = copy.deepcopy(self._base_config)
|
||||
|
||||
# Set dates
|
||||
if start_date is not None:
|
||||
config_dict["start_date"] = start_date
|
||||
if end_date is not None:
|
||||
config_dict["end_date"] = end_date
|
||||
|
||||
# Set parameter variations
|
||||
strategy_params = config_dict.get("strategy_params", {})
|
||||
for name, value in zip(param_names, param_combo):
|
||||
strategy_params[name] = value
|
||||
config_dict["strategy_params"] = strategy_params
|
||||
|
||||
# Generate descriptive name
|
||||
name_parts = []
|
||||
for name, value in zip(param_names, param_combo):
|
||||
name_parts.append(f"{name}={value}")
|
||||
if start_date:
|
||||
name_parts.append(f"{start_date.year}")
|
||||
|
||||
if name_parts:
|
||||
config_dict["name"] = " | ".join(name_parts)
|
||||
|
||||
scenarios.append(ScenarioConfig(**config_dict))
|
||||
|
||||
return scenarios
|
||||
|
||||
def clear(self) -> "ScenarioBatchBuilder":
|
||||
"""Clear all configuration.
|
||||
|
||||
Returns:
|
||||
Self for chaining
|
||||
"""
|
||||
self._base_config = {}
|
||||
self._parameter_variations = {}
|
||||
self._date_ranges = []
|
||||
return self
|
||||
|
||||
|
||||
def aggregate_results(results: List[ScenarioResult]) -> Dict[str, Any]:
|
||||
"""Aggregate results from multiple scenarios.
|
||||
|
||||
Args:
|
||||
results: List of scenario results
|
||||
|
||||
Returns:
|
||||
Dictionary with aggregated statistics
|
||||
"""
|
||||
if not results:
|
||||
return {
|
||||
"total_scenarios": 0,
|
||||
"successful": 0,
|
||||
"failed": 0,
|
||||
}
|
||||
|
||||
successful = [r for r in results if r.is_successful]
|
||||
failed = [r for r in results if r.status == ScenarioStatus.FAILED]
|
||||
|
||||
# Calculate aggregate metrics
|
||||
returns = [float(r.total_return) for r in successful if r.total_return]
|
||||
drawdowns = [float(r.max_drawdown) for r in successful if r.max_drawdown]
|
||||
durations = [r.duration_seconds for r in results if r.duration_seconds]
|
||||
|
||||
aggregate = {
|
||||
"total_scenarios": len(results),
|
||||
"successful": len(successful),
|
||||
"failed": len(failed),
|
||||
"success_rate": len(successful) / len(results) if results else 0,
|
||||
"total_duration_seconds": sum(durations),
|
||||
"avg_duration_seconds": sum(durations) / len(durations) if durations else 0,
|
||||
}
|
||||
|
||||
if returns:
|
||||
aggregate.update({
|
||||
"avg_return": sum(returns) / len(returns),
|
||||
"min_return": min(returns),
|
||||
"max_return": max(returns),
|
||||
"median_return": sorted(returns)[len(returns) // 2],
|
||||
})
|
||||
|
||||
if drawdowns:
|
||||
aggregate.update({
|
||||
"avg_max_drawdown": sum(drawdowns) / len(drawdowns),
|
||||
"worst_drawdown": min(drawdowns), # More negative is worse
|
||||
})
|
||||
|
||||
# Best and worst scenarios
|
||||
if successful:
|
||||
best = max(successful, key=lambda r: float(r.total_return or 0))
|
||||
worst = min(successful, key=lambda r: float(r.total_return or 0))
|
||||
aggregate["best_scenario"] = {
|
||||
"name": best.scenario_name,
|
||||
"return": str(best.total_return),
|
||||
}
|
||||
aggregate["worst_scenario"] = {
|
||||
"name": worst.scenario_name,
|
||||
"return": str(worst.total_return),
|
||||
}
|
||||
|
||||
return aggregate
|
||||
Loading…
Reference in New Issue