From 9aee43312dfa1924f55ed1e465d7cb3fb6f2df6c Mon Sep 17 00:00:00 2001 From: Andrew Kaszubski Date: Fri, 26 Dec 2025 21:29:14 +1100 Subject: [PATCH] feat(execution): add Risk Controls for pre-trade validation - Issue #28 (45 tests) --- tests/unit/execution/test_risk_controls.py | 678 +++++++++++++++++++ tradingagents/execution/__init__.py | 20 + tradingagents/execution/risk_controls.py | 735 +++++++++++++++++++++ 3 files changed, 1433 insertions(+) create mode 100644 tests/unit/execution/test_risk_controls.py create mode 100644 tradingagents/execution/risk_controls.py diff --git a/tests/unit/execution/test_risk_controls.py b/tests/unit/execution/test_risk_controls.py new file mode 100644 index 00000000..13aafa44 --- /dev/null +++ b/tests/unit/execution/test_risk_controls.py @@ -0,0 +1,678 @@ +"""Tests for Risk Controls implementation. + +Issue #28: [EXEC-27] Risk controls - position limits, loss limits +""" + +from decimal import Decimal +from datetime import datetime, timezone, timedelta, date +import pytest + +from tradingagents.execution import ( + RiskManager, + RiskCheckResult, + RiskRuleType, + RiskViolation, + RiskCheckResponse, + PositionLimits, + LossLimits, + PortfolioState, + OrderRequest, + OrderSide, + Position, + PositionSide, +) + + +class TestRiskViolation: + """Test RiskViolation dataclass.""" + + def test_violation_creation(self): + """Test creating a violation.""" + violation = RiskViolation( + rule_type=RiskRuleType.POSITION_SIZE, + rule_name="max_position_size", + message="Position too large", + current_value=Decimal("1500"), + limit_value=Decimal("1000"), + ) + assert violation.rule_type == RiskRuleType.POSITION_SIZE + assert violation.current_value == Decimal("1500") + assert violation.severity == "error" + + def test_violation_with_warning_severity(self): + """Test violation with warning severity.""" + violation = RiskViolation( + rule_type=RiskRuleType.DAILY_LOSS, + rule_name="potential_loss", + message="Potential loss warning", + current_value=Decimal("100"), + limit_value=Decimal("200"), + severity="warning", + ) + assert violation.severity == "warning" + + +class TestRiskCheckResponse: + """Test RiskCheckResponse dataclass.""" + + def test_default_passed(self): + """Test default response is passed.""" + response = RiskCheckResponse() + assert response.passed is True + assert response.violations == [] + assert response.warnings == [] + + def test_add_violation(self): + """Test adding violation fails response.""" + response = RiskCheckResponse() + violation = RiskViolation( + rule_type=RiskRuleType.POSITION_SIZE, + rule_name="test", + message="Test violation", + current_value=Decimal("100"), + limit_value=Decimal("50"), + ) + response.add_violation(violation) + assert response.passed is False + assert len(response.violations) == 1 + + def test_add_warning(self): + """Test adding warning keeps passed.""" + response = RiskCheckResponse() + violation = RiskViolation( + rule_type=RiskRuleType.DAILY_LOSS, + rule_name="test", + message="Test warning", + current_value=Decimal("100"), + limit_value=Decimal("200"), + severity="warning", + ) + response.add_violation(violation) + assert response.passed is True + assert len(response.warnings) == 1 + + def test_rejection_message(self): + """Test rejection message formatting.""" + response = RiskCheckResponse() + violation = RiskViolation( + rule_type=RiskRuleType.POSITION_SIZE, + rule_name="test", + message="Position too large", + current_value=Decimal("100"), + limit_value=Decimal("50"), + ) + response.add_violation(violation) + assert "Position too large" in response.rejection_message + + def test_rejection_message_none_when_passed(self): + """Test rejection message is None when passed.""" + response = RiskCheckResponse() + assert response.rejection_message is None + + +class TestPositionLimits: + """Test PositionLimits dataclass.""" + + def test_default_limits(self): + """Test default limits are None.""" + limits = PositionLimits() + assert limits.max_position_size is None + assert limits.max_position_value is None + assert limits.max_concentration_percent is None + + def test_custom_limits(self): + """Test setting custom limits.""" + limits = PositionLimits( + max_position_size=Decimal("1000"), + max_position_value=Decimal("50000"), + max_concentration_percent=Decimal("20"), + ) + assert limits.max_position_size == Decimal("1000") + assert limits.max_position_value == Decimal("50000") + assert limits.max_concentration_percent == Decimal("20") + + def test_per_symbol_limits(self): + """Test per-symbol limits.""" + limits = PositionLimits( + max_position_size=Decimal("1000"), + per_symbol_limits={ + "AAPL": {"max_position_size": Decimal("500")}, + }, + ) + assert limits.get_limit_for_symbol("AAPL", "max_position_size") == Decimal("500") + assert limits.get_limit_for_symbol("MSFT", "max_position_size") == Decimal("1000") + + +class TestLossLimits: + """Test LossLimits dataclass.""" + + def test_default_limits(self): + """Test default limits are None.""" + limits = LossLimits() + assert limits.max_daily_loss is None + assert limits.max_drawdown is None + assert limits.cooling_off_period_minutes == 0 + + def test_custom_limits(self): + """Test setting custom limits.""" + limits = LossLimits( + max_daily_loss=Decimal("500"), + max_daily_loss_percent=Decimal("5"), + max_drawdown_percent=Decimal("20"), + cooling_off_period_minutes=30, + ) + assert limits.max_daily_loss == Decimal("500") + assert limits.max_daily_loss_percent == Decimal("5") + assert limits.cooling_off_period_minutes == 30 + + +class TestPortfolioState: + """Test PortfolioState dataclass.""" + + def test_default_state(self): + """Test default state.""" + state = PortfolioState() + assert state.cash == Decimal("0") + assert state.equity == Decimal("0") + assert state.daily_pnl == Decimal("0") + + def test_drawdown_calculation(self): + """Test drawdown calculation.""" + state = PortfolioState( + equity=Decimal("90000"), + peak_equity=Decimal("100000"), + ) + assert state.current_drawdown == Decimal("10000") + assert state.current_drawdown_percent == Decimal("10") + + def test_no_drawdown_without_peak(self): + """Test no drawdown without peak.""" + state = PortfolioState(equity=Decimal("100000")) + assert state.current_drawdown == Decimal("0") + assert state.current_drawdown_percent == Decimal("0") + + +class TestRiskManagerInit: + """Test RiskManager initialization.""" + + def test_default_initialization(self): + """Test default initialization.""" + manager = RiskManager() + assert manager.enabled is True + assert manager.position_limits is not None + assert manager.loss_limits is not None + + def test_disabled_initialization(self): + """Test disabled initialization.""" + manager = RiskManager(enabled=False) + assert manager.enabled is False + + def test_with_limits(self): + """Test initialization with limits.""" + pos_limits = PositionLimits(max_position_size=Decimal("1000")) + loss_limits = LossLimits(max_daily_loss=Decimal("500")) + manager = RiskManager( + position_limits=pos_limits, + loss_limits=loss_limits, + ) + assert manager.position_limits.max_position_size == Decimal("1000") + assert manager.loss_limits.max_daily_loss == Decimal("500") + + +class TestRiskManagerPositionLimits: + """Test RiskManager position limit checks.""" + + def test_position_size_within_limit(self): + """Test position size within limit passes.""" + manager = RiskManager( + position_limits=PositionLimits(max_position_size=Decimal("1000")) + ) + portfolio = PortfolioState() + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("500")) + + result = manager.validate_order(order, portfolio) + assert result.passed is True + + def test_position_size_exceeds_limit(self): + """Test position size exceeding limit fails.""" + manager = RiskManager( + position_limits=PositionLimits(max_position_size=Decimal("1000")) + ) + portfolio = PortfolioState() + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1500")) + + result = manager.validate_order(order, portfolio) + assert result.passed is False + assert any(v.rule_name == "max_position_size" for v in result.violations) + + def test_position_size_with_existing_position(self): + """Test position size check with existing position.""" + manager = RiskManager( + position_limits=PositionLimits(max_position_size=Decimal("1000")) + ) + portfolio = PortfolioState( + positions={ + "AAPL": Position( + symbol="AAPL", + quantity=Decimal("600"), + side=PositionSide.LONG, + avg_entry_price=Decimal("100"), + current_price=Decimal("100"), + market_value=Decimal("60000"), + cost_basis=Decimal("60000"), + unrealized_pnl=Decimal("0"), + unrealized_pnl_percent=Decimal("0"), + ) + } + ) + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("500")) + + result = manager.validate_order(order, portfolio) + assert result.passed is False # 600 + 500 = 1100 > 1000 + + def test_position_value_limit(self): + """Test position value limit.""" + manager = RiskManager( + position_limits=PositionLimits(max_position_value=Decimal("50000")) + ) + portfolio = PortfolioState() + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1000")) + + # At $100, value is $100,000 which exceeds $50,000 limit + result = manager.validate_order(order, portfolio, estimated_fill_price=Decimal("100")) + assert result.passed is False + assert any(v.rule_name == "max_position_value" for v in result.violations) + + def test_concentration_limit(self): + """Test concentration limit.""" + manager = RiskManager( + position_limits=PositionLimits(max_concentration_percent=Decimal("20")) + ) + portfolio = PortfolioState(equity=Decimal("100000")) + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("300")) + + # At $100, value is $30,000 = 30% of $100,000 equity + result = manager.validate_order(order, portfolio, estimated_fill_price=Decimal("100")) + assert result.passed is False + assert any(v.rule_name == "max_concentration" for v in result.violations) + + def test_total_positions_limit(self): + """Test total positions limit.""" + manager = RiskManager( + position_limits=PositionLimits(max_total_positions=2) + ) + portfolio = PortfolioState( + positions={ + "AAPL": Position( + symbol="AAPL", + quantity=Decimal("100"), + side=PositionSide.LONG, + avg_entry_price=Decimal("100"), + current_price=Decimal("100"), + market_value=Decimal("10000"), + cost_basis=Decimal("10000"), + unrealized_pnl=Decimal("0"), + unrealized_pnl_percent=Decimal("0"), + ), + "MSFT": Position( + symbol="MSFT", + quantity=Decimal("100"), + side=PositionSide.LONG, + avg_entry_price=Decimal("100"), + current_price=Decimal("100"), + market_value=Decimal("10000"), + cost_basis=Decimal("10000"), + unrealized_pnl=Decimal("0"), + unrealized_pnl_percent=Decimal("0"), + ), + } + ) + order = OrderRequest.market("GOOGL", OrderSide.BUY, Decimal("100")) + + result = manager.validate_order(order, portfolio) + assert result.passed is False + assert any(v.rule_name == "max_total_positions" for v in result.violations) + + +class TestRiskManagerLossLimits: + """Test RiskManager loss limit checks.""" + + def test_daily_loss_within_limit(self): + """Test daily loss within limit passes.""" + manager = RiskManager( + loss_limits=LossLimits(max_daily_loss=Decimal("1000")) + ) + portfolio = PortfolioState(daily_pnl=Decimal("-500")) + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")) + + result = manager.validate_order(order, portfolio) + assert result.passed is True + + def test_daily_loss_exceeds_limit(self): + """Test daily loss exceeding limit fails.""" + manager = RiskManager( + loss_limits=LossLimits(max_daily_loss=Decimal("1000")) + ) + portfolio = PortfolioState(daily_pnl=Decimal("-1500")) + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")) + + result = manager.validate_order(order, portfolio) + assert result.passed is False + assert any(v.rule_name == "max_daily_loss" for v in result.violations) + + def test_daily_loss_percent_limit(self): + """Test daily loss percentage limit.""" + manager = RiskManager( + loss_limits=LossLimits(max_daily_loss_percent=Decimal("5")) + ) + portfolio = PortfolioState( + equity=Decimal("100000"), + daily_pnl=Decimal("-6000"), # 6% loss + ) + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")) + + result = manager.validate_order(order, portfolio) + assert result.passed is False + assert any(v.rule_name == "max_daily_loss_percent" for v in result.violations) + + def test_drawdown_limit(self): + """Test drawdown limit.""" + manager = RiskManager( + loss_limits=LossLimits(max_drawdown=Decimal("15000")) + ) + portfolio = PortfolioState( + equity=Decimal("85000"), + peak_equity=Decimal("100000"), # 15k drawdown + ) + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")) + + # Exactly at limit should pass + result = manager.validate_order(order, portfolio) + assert result.passed is True + + # Beyond limit should fail + portfolio.equity = Decimal("80000") # 20k drawdown + result = manager.validate_order(order, portfolio) + assert result.passed is False + assert any(v.rule_name == "max_drawdown" for v in result.violations) + + def test_drawdown_percent_limit(self): + """Test drawdown percentage limit.""" + manager = RiskManager( + loss_limits=LossLimits(max_drawdown_percent=Decimal("20")) + ) + portfolio = PortfolioState( + equity=Decimal("75000"), + peak_equity=Decimal("100000"), # 25% drawdown + ) + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")) + + result = manager.validate_order(order, portfolio) + assert result.passed is False + assert any(v.rule_name == "max_drawdown_percent" for v in result.violations) + + def test_consecutive_losses_limit(self): + """Test consecutive losses limit.""" + manager = RiskManager( + loss_limits=LossLimits(max_consecutive_losses=5) + ) + portfolio = PortfolioState(consecutive_losses=5) + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")) + + result = manager.validate_order(order, portfolio) + assert result.passed is False + assert any(v.rule_name == "max_consecutive_losses" for v in result.violations) + + +class TestRiskManagerCoolingOff: + """Test RiskManager cooling off period.""" + + def test_cooling_off_triggered(self): + """Test cooling off is triggered on loss limit.""" + manager = RiskManager( + loss_limits=LossLimits( + max_daily_loss=Decimal("1000"), + cooling_off_period_minutes=30, + ) + ) + portfolio = PortfolioState(daily_pnl=Decimal("-1500")) + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")) + + result = manager.validate_order(order, portfolio) + assert result.passed is False + assert manager._in_cooling_off is True + + def test_order_blocked_during_cooling_off(self): + """Test orders blocked during cooling off.""" + manager = RiskManager( + loss_limits=LossLimits( + max_daily_loss=Decimal("1000"), + cooling_off_period_minutes=30, + ) + ) + # Trigger cooling off + portfolio_loss = PortfolioState(daily_pnl=Decimal("-1500")) + manager.validate_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")), + portfolio_loss, + ) + + # Try another order + portfolio_ok = PortfolioState(daily_pnl=Decimal("0")) + result = manager.validate_order( + OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")), + portfolio_ok, + ) + + assert result.passed is False + assert any(v.rule_name == "cooling_off_period" for v in result.violations) + + def test_cooling_off_reset(self): + """Test cooling off can be reset.""" + manager = RiskManager( + loss_limits=LossLimits( + max_daily_loss=Decimal("1000"), + cooling_off_period_minutes=30, + ) + ) + manager._in_cooling_off = True + manager._cooling_off_until = datetime.now(timezone.utc) + + manager.reset_daily_limits() + + assert manager._in_cooling_off is False + assert manager._cooling_off_until is None + + +class TestRiskManagerDisabled: + """Test RiskManager when disabled.""" + + def test_disabled_passes_all(self): + """Test disabled manager passes all orders.""" + manager = RiskManager( + position_limits=PositionLimits(max_position_size=Decimal("100")), + enabled=False, + ) + portfolio = PortfolioState() + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10000")) + + result = manager.validate_order(order, portfolio) + assert result.passed is True + + def test_enable_disable(self): + """Test enabling and disabling.""" + manager = RiskManager() + assert manager.enabled is True + + manager.enabled = False + assert manager.enabled is False + + manager.enabled = True + assert manager.enabled is True + + +class TestRiskManagerCustomRules: + """Test RiskManager custom rules.""" + + def test_custom_rule_passing(self): + """Test custom rule that passes.""" + def custom_rule(order, portfolio): + return None # Pass + + manager = RiskManager(custom_rules=[custom_rule]) + portfolio = PortfolioState() + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")) + + result = manager.validate_order(order, portfolio) + assert result.passed is True + assert "custom_rule_0" in result.checked_rules + + def test_custom_rule_failing(self): + """Test custom rule that fails.""" + def custom_rule(order, portfolio): + return RiskViolation( + rule_type=RiskRuleType.CUSTOM, + rule_name="custom_test", + message="Custom rule failed", + current_value=Decimal("0"), + limit_value=Decimal("0"), + ) + + manager = RiskManager(custom_rules=[custom_rule]) + portfolio = PortfolioState() + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")) + + result = manager.validate_order(order, portfolio) + assert result.passed is False + assert any(v.rule_name == "custom_test" for v in result.violations) + + def test_custom_rule_error_handled(self): + """Test custom rule error doesn't break validation.""" + def bad_rule(order, portfolio): + raise Exception("Rule error") + + manager = RiskManager(custom_rules=[bad_rule]) + portfolio = PortfolioState() + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")) + + # Should not raise + result = manager.validate_order(order, portfolio) + assert result.passed is True + + def test_add_custom_rule(self): + """Test adding custom rule after init.""" + manager = RiskManager() + + def custom_rule(order, portfolio): + return None + + manager.add_custom_rule(custom_rule) + assert len(manager._custom_rules) == 1 + + +class TestRiskManagerTracking: + """Test RiskManager tracking methods.""" + + def test_update_daily_pnl(self): + """Test updating daily P&L.""" + manager = RiskManager() + today = date.today() + + manager.update_daily_pnl(Decimal("100"), today) + assert manager.get_daily_pnl(today) == Decimal("100") + + manager.update_daily_pnl(Decimal("50"), today) + assert manager.get_daily_pnl(today) == Decimal("150") + + def test_update_peak_equity(self): + """Test updating peak equity.""" + manager = RiskManager() + + manager.update_peak_equity(Decimal("100000")) + assert manager._peak_equity == Decimal("100000") + + # Higher should update + manager.update_peak_equity(Decimal("110000")) + assert manager._peak_equity == Decimal("110000") + + # Lower should not update + manager.update_peak_equity(Decimal("105000")) + assert manager._peak_equity == Decimal("110000") + + def test_reset_all(self): + """Test resetting all state.""" + manager = RiskManager() + manager.update_daily_pnl(Decimal("100"), date.today()) + manager.update_peak_equity(Decimal("100000")) + manager._in_cooling_off = True + + manager.reset_all() + + assert manager.get_daily_pnl(date.today()) == Decimal("0") + assert manager._peak_equity is None + assert manager._in_cooling_off is False + + +class TestRiskManagerRuleChecking: + """Test that all rules are checked.""" + + def test_all_rules_checked(self): + """Test all rules appear in checked list.""" + manager = RiskManager( + position_limits=PositionLimits( + max_position_size=Decimal("1000"), + max_position_value=Decimal("50000"), + max_concentration_percent=Decimal("20"), + max_total_positions=10, + ), + loss_limits=LossLimits( + max_daily_loss=Decimal("1000"), + max_drawdown=Decimal("10000"), + max_single_trade_loss=Decimal("500"), + max_consecutive_losses=5, + ), + ) + portfolio = PortfolioState(equity=Decimal("100000")) + order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10")) + + result = manager.validate_order(order, portfolio, estimated_fill_price=Decimal("100")) + + expected_rules = [ + "position_size", + "position_value", + "concentration", + "total_positions", + "daily_loss", + "drawdown", + "single_trade_loss", + "consecutive_losses", + ] + for rule in expected_rules: + assert rule in result.checked_rules + + +class TestRiskRuleTypeEnum: + """Test RiskRuleType enum.""" + + def test_all_types_defined(self): + """Test all expected types are defined.""" + expected = [ + "POSITION_SIZE", + "POSITION_VALUE", + "CONCENTRATION", + "DAILY_LOSS", + "DRAWDOWN", + "CUSTOM", + ] + for type_name in expected: + assert hasattr(RiskRuleType, type_name) + + +class TestRiskCheckResultEnum: + """Test RiskCheckResult enum.""" + + def test_all_results_defined(self): + """Test all expected results are defined.""" + expected = ["PASSED", "FAILED", "WARNING", "SKIPPED"] + for result_name in expected: + assert hasattr(RiskCheckResult, result_name) diff --git a/tradingagents/execution/__init__.py b/tradingagents/execution/__init__.py index d26154c6..fa1cbdb7 100644 --- a/tradingagents/execution/__init__.py +++ b/tradingagents/execution/__init__.py @@ -124,6 +124,17 @@ from .order_manager import ( OPEN_STATES, ) +from .risk_controls import ( + RiskManager, + RiskCheckResult, + RiskRuleType, + RiskViolation, + RiskCheckResponse, + PositionLimits, + LossLimits, + PortfolioState, +) + __all__ = [ # Enums "AssetClass", @@ -176,4 +187,13 @@ __all__ = [ "VALID_TRANSITIONS", "TERMINAL_STATES", "OPEN_STATES", + # Risk Controls + "RiskManager", + "RiskCheckResult", + "RiskRuleType", + "RiskViolation", + "RiskCheckResponse", + "PositionLimits", + "LossLimits", + "PortfolioState", ] diff --git a/tradingagents/execution/risk_controls.py b/tradingagents/execution/risk_controls.py new file mode 100644 index 00000000..aa994364 --- /dev/null +++ b/tradingagents/execution/risk_controls.py @@ -0,0 +1,735 @@ +"""Risk Controls for order execution. + +Issue #28: [EXEC-27] Risk controls - position limits, loss limits + +This module provides pre-trade risk validation including: +- Position size limits (max shares, max notional value) +- Concentration limits (max % of portfolio in single position) +- Daily loss limits +- Drawdown limits +- Pre-trade validation framework + +Example: + >>> from tradingagents.execution import RiskManager, PositionLimits, LossLimits + >>> + >>> limits = PositionLimits( + ... max_position_size=Decimal("10000"), + ... max_position_value=Decimal("50000"), + ... max_concentration_percent=Decimal("20"), + ... ) + >>> risk_manager = RiskManager(position_limits=limits) + >>> result = risk_manager.check_order(order_request, portfolio) +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timezone, date +from decimal import Decimal +from enum import Enum +from typing import Any, Dict, List, Optional, Callable + +from .broker_base import ( + OrderRequest, + OrderSide, + Position, +) + + +class RiskCheckResult(Enum): + """Result of a risk check.""" + + PASSED = "passed" + FAILED = "failed" + WARNING = "warning" + SKIPPED = "skipped" + + +class RiskRuleType(Enum): + """Type of risk rule.""" + + POSITION_SIZE = "position_size" + POSITION_VALUE = "position_value" + CONCENTRATION = "concentration" + DAILY_LOSS = "daily_loss" + DRAWDOWN = "drawdown" + CUSTOM = "custom" + + +@dataclass +class RiskViolation: + """Details of a risk limit violation. + + Attributes: + rule_type: Type of rule violated + rule_name: Name of the specific rule + message: Human-readable violation message + current_value: Current value that violated the limit + limit_value: The limit that was exceeded + severity: Violation severity (error, warning) + metadata: Additional context + """ + + rule_type: RiskRuleType + rule_name: str + message: str + current_value: Decimal + limit_value: Decimal + severity: str = "error" + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class RiskCheckResponse: + """Response from risk validation. + + Attributes: + passed: Whether all risk checks passed + violations: List of rule violations + warnings: List of warnings (non-blocking) + checked_rules: List of rules that were checked + timestamp: When the check was performed + """ + + passed: bool = True + violations: List[RiskViolation] = field(default_factory=list) + warnings: List[RiskViolation] = field(default_factory=list) + checked_rules: List[str] = field(default_factory=list) + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + def add_violation(self, violation: RiskViolation) -> None: + """Add a violation to the response.""" + if violation.severity == "warning": + self.warnings.append(violation) + else: + self.violations.append(violation) + self.passed = False + + @property + def rejection_message(self) -> Optional[str]: + """Get formatted rejection message if failed.""" + if self.passed: + return None + messages = [v.message for v in self.violations] + return "; ".join(messages) + + +@dataclass +class PositionLimits: + """Position-related risk limits. + + Attributes: + max_position_size: Maximum shares/units in single position + max_position_value: Maximum notional value of single position + max_concentration_percent: Maximum % of portfolio in single position + max_total_positions: Maximum number of open positions + max_sector_concentration: Maximum % in single sector + per_symbol_limits: Custom limits per symbol + """ + + max_position_size: Optional[Decimal] = None + max_position_value: Optional[Decimal] = None + max_concentration_percent: Optional[Decimal] = None + max_total_positions: Optional[int] = None + max_sector_concentration: Optional[Decimal] = None + per_symbol_limits: Dict[str, Dict[str, Decimal]] = field(default_factory=dict) + + def get_limit_for_symbol( + self, symbol: str, limit_type: str + ) -> Optional[Decimal]: + """Get specific limit for a symbol, falling back to default.""" + if symbol in self.per_symbol_limits: + if limit_type in self.per_symbol_limits[symbol]: + return self.per_symbol_limits[symbol][limit_type] + + return getattr(self, limit_type, None) + + +@dataclass +class LossLimits: + """Loss-related risk limits. + + Attributes: + max_daily_loss: Maximum loss allowed per day + max_daily_loss_percent: Maximum loss as % of equity per day + max_drawdown: Maximum drawdown from peak + max_drawdown_percent: Maximum drawdown as % of peak + max_single_trade_loss: Maximum loss on single trade + max_consecutive_losses: Maximum consecutive losing trades + cooling_off_period_minutes: Minutes to wait after hitting limit + """ + + max_daily_loss: Optional[Decimal] = None + max_daily_loss_percent: Optional[Decimal] = None + max_drawdown: Optional[Decimal] = None + max_drawdown_percent: Optional[Decimal] = None + max_single_trade_loss: Optional[Decimal] = None + max_consecutive_losses: Optional[int] = None + cooling_off_period_minutes: int = 0 + + +@dataclass +class PortfolioState: + """Current portfolio state for risk calculations. + + Attributes: + positions: Current positions keyed by symbol + cash: Available cash + equity: Total portfolio equity + buying_power: Available buying power + daily_pnl: Profit/loss for current day + peak_equity: Peak equity for drawdown calculations + consecutive_losses: Current consecutive losing trades + last_loss_time: When last loss occurred + """ + + positions: Dict[str, Position] = field(default_factory=dict) + cash: Decimal = Decimal("0") + equity: Decimal = Decimal("0") + buying_power: Decimal = Decimal("0") + daily_pnl: Decimal = Decimal("0") + peak_equity: Optional[Decimal] = None + consecutive_losses: int = 0 + last_loss_time: Optional[datetime] = None + + @property + def current_drawdown(self) -> Decimal: + """Calculate current drawdown from peak.""" + if self.peak_equity is None or self.peak_equity <= 0: + return Decimal("0") + return self.peak_equity - self.equity + + @property + def current_drawdown_percent(self) -> Decimal: + """Calculate current drawdown as percentage.""" + if self.peak_equity is None or self.peak_equity <= 0: + return Decimal("0") + return (self.current_drawdown / self.peak_equity) * Decimal("100") + + +class RiskManager: + """Manages pre-trade risk validation. + + The RiskManager validates orders against configured risk limits + before allowing them to be submitted to brokers. + + Example: + >>> risk_manager = RiskManager( + ... position_limits=PositionLimits(max_position_size=1000), + ... loss_limits=LossLimits(max_daily_loss=Decimal("500")), + ... ) + >>> result = risk_manager.validate_order(order_request, portfolio_state) + >>> if not result.passed: + ... print(f"Order rejected: {result.rejection_message}") + """ + + def __init__( + self, + position_limits: Optional[PositionLimits] = None, + loss_limits: Optional[LossLimits] = None, + custom_rules: Optional[List[Callable]] = None, + enabled: bool = True, + ) -> None: + """Initialize risk manager. + + Args: + position_limits: Position-related limits + loss_limits: Loss-related limits + custom_rules: Custom validation functions + enabled: Whether risk checks are enabled + """ + self._position_limits = position_limits or PositionLimits() + self._loss_limits = loss_limits or LossLimits() + self._custom_rules = custom_rules or [] + self._enabled = enabled + self._daily_pnl_by_date: Dict[date, Decimal] = {} + self._peak_equity: Optional[Decimal] = None + self._in_cooling_off = False + self._cooling_off_until: Optional[datetime] = None + + @property + def enabled(self) -> bool: + """Check if risk manager is enabled.""" + return self._enabled + + @enabled.setter + def enabled(self, value: bool) -> None: + """Enable or disable risk manager.""" + self._enabled = value + + @property + def position_limits(self) -> PositionLimits: + """Get position limits.""" + return self._position_limits + + @position_limits.setter + def position_limits(self, limits: PositionLimits) -> None: + """Set position limits.""" + self._position_limits = limits + + @property + def loss_limits(self) -> LossLimits: + """Get loss limits.""" + return self._loss_limits + + @loss_limits.setter + def loss_limits(self, limits: LossLimits) -> None: + """Set loss limits.""" + self._loss_limits = limits + + def add_custom_rule( + self, + rule: Callable[[OrderRequest, PortfolioState], Optional[RiskViolation]], + ) -> None: + """Add a custom validation rule. + + Args: + rule: Function that takes order and portfolio, returns violation or None + """ + self._custom_rules.append(rule) + + def validate_order( + self, + order: OrderRequest, + portfolio: PortfolioState, + estimated_fill_price: Optional[Decimal] = None, + ) -> RiskCheckResponse: + """Validate an order against all risk limits. + + Args: + order: Order request to validate + portfolio: Current portfolio state + estimated_fill_price: Expected fill price for value calculations + + Returns: + RiskCheckResponse with validation results + """ + response = RiskCheckResponse() + + if not self._enabled: + response.checked_rules.append("(risk checks disabled)") + return response + + # Check cooling off period + if self._in_cooling_off and self._cooling_off_until: + if datetime.now(timezone.utc) < self._cooling_off_until: + response.add_violation( + RiskViolation( + rule_type=RiskRuleType.DAILY_LOSS, + rule_name="cooling_off_period", + message=f"In cooling off period until {self._cooling_off_until}", + current_value=Decimal("0"), + limit_value=Decimal("0"), + ) + ) + return response + else: + self._in_cooling_off = False + self._cooling_off_until = None + + # Run all checks + self._check_position_size(order, portfolio, response) + self._check_position_value(order, portfolio, response, estimated_fill_price) + self._check_concentration(order, portfolio, response, estimated_fill_price) + self._check_total_positions(order, portfolio, response) + self._check_daily_loss(portfolio, response) + self._check_drawdown(portfolio, response) + self._check_single_trade_loss(order, portfolio, response, estimated_fill_price) + self._check_consecutive_losses(portfolio, response) + self._run_custom_rules(order, portfolio, response) + + return response + + def _check_position_size( + self, + order: OrderRequest, + portfolio: PortfolioState, + response: RiskCheckResponse, + ) -> None: + """Check position size limits.""" + response.checked_rules.append("position_size") + + limit = self._position_limits.get_limit_for_symbol( + order.symbol, "max_position_size" + ) + if limit is None: + return + + current_position = portfolio.positions.get(order.symbol) + current_qty = current_position.quantity if current_position else Decimal("0") + + if order.side == OrderSide.BUY: + new_qty = current_qty + order.quantity + else: + new_qty = current_qty - order.quantity + + if abs(new_qty) > limit: + response.add_violation( + RiskViolation( + rule_type=RiskRuleType.POSITION_SIZE, + rule_name="max_position_size", + message=( + f"Position size {abs(new_qty)} exceeds limit {limit} " + f"for {order.symbol}" + ), + current_value=abs(new_qty), + limit_value=limit, + ) + ) + + def _check_position_value( + self, + order: OrderRequest, + portfolio: PortfolioState, + response: RiskCheckResponse, + estimated_price: Optional[Decimal], + ) -> None: + """Check position value limits.""" + response.checked_rules.append("position_value") + + limit = self._position_limits.get_limit_for_symbol( + order.symbol, "max_position_value" + ) + if limit is None: + return + + if estimated_price is None: + # Can't check without price + return + + current_position = portfolio.positions.get(order.symbol) + current_qty = current_position.quantity if current_position else Decimal("0") + + if order.side == OrderSide.BUY: + new_qty = current_qty + order.quantity + else: + new_qty = current_qty - order.quantity + + new_value = abs(new_qty * estimated_price) + + if new_value > limit: + response.add_violation( + RiskViolation( + rule_type=RiskRuleType.POSITION_VALUE, + rule_name="max_position_value", + message=( + f"Position value ${new_value} exceeds limit ${limit} " + f"for {order.symbol}" + ), + current_value=new_value, + limit_value=limit, + ) + ) + + def _check_concentration( + self, + order: OrderRequest, + portfolio: PortfolioState, + response: RiskCheckResponse, + estimated_price: Optional[Decimal], + ) -> None: + """Check concentration limits.""" + response.checked_rules.append("concentration") + + limit = self._position_limits.max_concentration_percent + if limit is None: + return + + if estimated_price is None or portfolio.equity <= 0: + return + + current_position = portfolio.positions.get(order.symbol) + current_qty = current_position.quantity if current_position else Decimal("0") + + if order.side == OrderSide.BUY: + new_qty = current_qty + order.quantity + else: + new_qty = current_qty - order.quantity + + new_value = abs(new_qty * estimated_price) + concentration = (new_value / portfolio.equity) * Decimal("100") + + if concentration > limit: + response.add_violation( + RiskViolation( + rule_type=RiskRuleType.CONCENTRATION, + rule_name="max_concentration", + message=( + f"Concentration {concentration:.1f}% exceeds limit {limit}% " + f"for {order.symbol}" + ), + current_value=concentration, + limit_value=limit, + ) + ) + + def _check_total_positions( + self, + order: OrderRequest, + portfolio: PortfolioState, + response: RiskCheckResponse, + ) -> None: + """Check total positions limit.""" + response.checked_rules.append("total_positions") + + limit = self._position_limits.max_total_positions + if limit is None: + return + + # Only check for new positions (buys in symbols we don't have) + if order.side != OrderSide.BUY: + return + + if order.symbol in portfolio.positions: + return # Adding to existing position + + if len(portfolio.positions) >= limit: + response.add_violation( + RiskViolation( + rule_type=RiskRuleType.POSITION_SIZE, + rule_name="max_total_positions", + message=( + f"Total positions {len(portfolio.positions)} at limit {limit}" + ), + current_value=Decimal(str(len(portfolio.positions))), + limit_value=Decimal(str(limit)), + ) + ) + + def _check_daily_loss( + self, + portfolio: PortfolioState, + response: RiskCheckResponse, + ) -> None: + """Check daily loss limits.""" + response.checked_rules.append("daily_loss") + + # Check absolute daily loss + if self._loss_limits.max_daily_loss is not None: + if portfolio.daily_pnl < -self._loss_limits.max_daily_loss: + self._trigger_cooling_off() + response.add_violation( + RiskViolation( + rule_type=RiskRuleType.DAILY_LOSS, + rule_name="max_daily_loss", + message=( + f"Daily loss ${abs(portfolio.daily_pnl)} exceeds limit " + f"${self._loss_limits.max_daily_loss}" + ), + current_value=abs(portfolio.daily_pnl), + limit_value=self._loss_limits.max_daily_loss, + ) + ) + + # Check percentage daily loss + if self._loss_limits.max_daily_loss_percent is not None: + if portfolio.equity > 0: + daily_loss_pct = ( + abs(portfolio.daily_pnl) / portfolio.equity + ) * Decimal("100") + if ( + portfolio.daily_pnl < 0 + and daily_loss_pct > self._loss_limits.max_daily_loss_percent + ): + self._trigger_cooling_off() + response.add_violation( + RiskViolation( + rule_type=RiskRuleType.DAILY_LOSS, + rule_name="max_daily_loss_percent", + message=( + f"Daily loss {daily_loss_pct:.1f}% exceeds limit " + f"{self._loss_limits.max_daily_loss_percent}%" + ), + current_value=daily_loss_pct, + limit_value=self._loss_limits.max_daily_loss_percent, + ) + ) + + def _check_drawdown( + self, + portfolio: PortfolioState, + response: RiskCheckResponse, + ) -> None: + """Check drawdown limits.""" + response.checked_rules.append("drawdown") + + # Check absolute drawdown + if self._loss_limits.max_drawdown is not None: + if portfolio.current_drawdown > self._loss_limits.max_drawdown: + response.add_violation( + RiskViolation( + rule_type=RiskRuleType.DRAWDOWN, + rule_name="max_drawdown", + message=( + f"Drawdown ${portfolio.current_drawdown} exceeds limit " + f"${self._loss_limits.max_drawdown}" + ), + current_value=portfolio.current_drawdown, + limit_value=self._loss_limits.max_drawdown, + ) + ) + + # Check percentage drawdown + if self._loss_limits.max_drawdown_percent is not None: + if portfolio.current_drawdown_percent > self._loss_limits.max_drawdown_percent: + response.add_violation( + RiskViolation( + rule_type=RiskRuleType.DRAWDOWN, + rule_name="max_drawdown_percent", + message=( + f"Drawdown {portfolio.current_drawdown_percent:.1f}% " + f"exceeds limit {self._loss_limits.max_drawdown_percent}%" + ), + current_value=portfolio.current_drawdown_percent, + limit_value=self._loss_limits.max_drawdown_percent, + ) + ) + + def _check_single_trade_loss( + self, + order: OrderRequest, + portfolio: PortfolioState, + response: RiskCheckResponse, + estimated_price: Optional[Decimal], + ) -> None: + """Check single trade loss limit.""" + response.checked_rules.append("single_trade_loss") + + limit = self._loss_limits.max_single_trade_loss + if limit is None: + return + + if estimated_price is None: + return + + # Calculate max potential loss for the order + order_value = order.quantity * estimated_price + + # For sells, potential loss is limited + # For buys, assume worst case is total loss of order value + if order.side == OrderSide.BUY: + potential_loss = order_value + if potential_loss > limit: + response.add_violation( + RiskViolation( + rule_type=RiskRuleType.DAILY_LOSS, + rule_name="max_single_trade_loss", + message=( + f"Potential loss ${potential_loss} exceeds limit " + f"${limit}" + ), + current_value=potential_loss, + limit_value=limit, + severity="warning", # Warning, not blocking + ) + ) + + def _check_consecutive_losses( + self, + portfolio: PortfolioState, + response: RiskCheckResponse, + ) -> None: + """Check consecutive losses limit.""" + response.checked_rules.append("consecutive_losses") + + limit = self._loss_limits.max_consecutive_losses + if limit is None: + return + + if portfolio.consecutive_losses >= limit: + self._trigger_cooling_off() + response.add_violation( + RiskViolation( + rule_type=RiskRuleType.DAILY_LOSS, + rule_name="max_consecutive_losses", + message=( + f"Consecutive losses {portfolio.consecutive_losses} " + f"reached limit {limit}" + ), + current_value=Decimal(str(portfolio.consecutive_losses)), + limit_value=Decimal(str(limit)), + ) + ) + + def _run_custom_rules( + self, + order: OrderRequest, + portfolio: PortfolioState, + response: RiskCheckResponse, + ) -> None: + """Run custom validation rules.""" + for i, rule in enumerate(self._custom_rules): + response.checked_rules.append(f"custom_rule_{i}") + try: + violation = rule(order, portfolio) + if violation: + response.add_violation(violation) + except Exception: + # Don't let custom rule errors break validation + pass + + def _trigger_cooling_off(self) -> None: + """Trigger cooling off period.""" + if self._loss_limits.cooling_off_period_minutes > 0: + self._in_cooling_off = True + from datetime import timedelta + + self._cooling_off_until = datetime.now(timezone.utc) + timedelta( + minutes=self._loss_limits.cooling_off_period_minutes + ) + + def update_daily_pnl(self, pnl: Decimal, trade_date: date) -> None: + """Update daily P&L tracking. + + Args: + pnl: P&L for the trade + trade_date: Date of the trade + """ + if trade_date not in self._daily_pnl_by_date: + self._daily_pnl_by_date[trade_date] = Decimal("0") + self._daily_pnl_by_date[trade_date] += pnl + + def get_daily_pnl(self, trade_date: date) -> Decimal: + """Get daily P&L for a date. + + Args: + trade_date: Date to get P&L for + + Returns: + P&L for the date + """ + return self._daily_pnl_by_date.get(trade_date, Decimal("0")) + + def update_peak_equity(self, equity: Decimal) -> None: + """Update peak equity tracking. + + Args: + equity: Current equity + """ + if self._peak_equity is None or equity > self._peak_equity: + self._peak_equity = equity + + def reset_daily_limits(self) -> None: + """Reset daily tracking (call at start of each trading day).""" + self._in_cooling_off = False + self._cooling_off_until = None + + def reset_all(self) -> None: + """Reset all tracking state.""" + self._daily_pnl_by_date.clear() + self._peak_equity = None + self._in_cooling_off = False + self._cooling_off_until = None + + +# Export +__all__ = [ + "RiskManager", + "RiskCheckResult", + "RiskRuleType", + "RiskViolation", + "RiskCheckResponse", + "PositionLimits", + "LossLimits", + "PortfolioState", +]