feat(execution): add Risk Controls for pre-trade validation - Issue #28 (45 tests)

This commit is contained in:
Andrew Kaszubski 2025-12-26 21:29:14 +11:00
parent 6863e3ed87
commit 9aee43312d
3 changed files with 1433 additions and 0 deletions

View File

@ -0,0 +1,678 @@
"""Tests for Risk Controls implementation.
Issue #28: [EXEC-27] Risk controls - position limits, loss limits
"""
from decimal import Decimal
from datetime import datetime, timezone, timedelta, date
import pytest
from tradingagents.execution import (
RiskManager,
RiskCheckResult,
RiskRuleType,
RiskViolation,
RiskCheckResponse,
PositionLimits,
LossLimits,
PortfolioState,
OrderRequest,
OrderSide,
Position,
PositionSide,
)
class TestRiskViolation:
"""Test RiskViolation dataclass."""
def test_violation_creation(self):
"""Test creating a violation."""
violation = RiskViolation(
rule_type=RiskRuleType.POSITION_SIZE,
rule_name="max_position_size",
message="Position too large",
current_value=Decimal("1500"),
limit_value=Decimal("1000"),
)
assert violation.rule_type == RiskRuleType.POSITION_SIZE
assert violation.current_value == Decimal("1500")
assert violation.severity == "error"
def test_violation_with_warning_severity(self):
"""Test violation with warning severity."""
violation = RiskViolation(
rule_type=RiskRuleType.DAILY_LOSS,
rule_name="potential_loss",
message="Potential loss warning",
current_value=Decimal("100"),
limit_value=Decimal("200"),
severity="warning",
)
assert violation.severity == "warning"
class TestRiskCheckResponse:
"""Test RiskCheckResponse dataclass."""
def test_default_passed(self):
"""Test default response is passed."""
response = RiskCheckResponse()
assert response.passed is True
assert response.violations == []
assert response.warnings == []
def test_add_violation(self):
"""Test adding violation fails response."""
response = RiskCheckResponse()
violation = RiskViolation(
rule_type=RiskRuleType.POSITION_SIZE,
rule_name="test",
message="Test violation",
current_value=Decimal("100"),
limit_value=Decimal("50"),
)
response.add_violation(violation)
assert response.passed is False
assert len(response.violations) == 1
def test_add_warning(self):
"""Test adding warning keeps passed."""
response = RiskCheckResponse()
violation = RiskViolation(
rule_type=RiskRuleType.DAILY_LOSS,
rule_name="test",
message="Test warning",
current_value=Decimal("100"),
limit_value=Decimal("200"),
severity="warning",
)
response.add_violation(violation)
assert response.passed is True
assert len(response.warnings) == 1
def test_rejection_message(self):
"""Test rejection message formatting."""
response = RiskCheckResponse()
violation = RiskViolation(
rule_type=RiskRuleType.POSITION_SIZE,
rule_name="test",
message="Position too large",
current_value=Decimal("100"),
limit_value=Decimal("50"),
)
response.add_violation(violation)
assert "Position too large" in response.rejection_message
def test_rejection_message_none_when_passed(self):
"""Test rejection message is None when passed."""
response = RiskCheckResponse()
assert response.rejection_message is None
class TestPositionLimits:
"""Test PositionLimits dataclass."""
def test_default_limits(self):
"""Test default limits are None."""
limits = PositionLimits()
assert limits.max_position_size is None
assert limits.max_position_value is None
assert limits.max_concentration_percent is None
def test_custom_limits(self):
"""Test setting custom limits."""
limits = PositionLimits(
max_position_size=Decimal("1000"),
max_position_value=Decimal("50000"),
max_concentration_percent=Decimal("20"),
)
assert limits.max_position_size == Decimal("1000")
assert limits.max_position_value == Decimal("50000")
assert limits.max_concentration_percent == Decimal("20")
def test_per_symbol_limits(self):
"""Test per-symbol limits."""
limits = PositionLimits(
max_position_size=Decimal("1000"),
per_symbol_limits={
"AAPL": {"max_position_size": Decimal("500")},
},
)
assert limits.get_limit_for_symbol("AAPL", "max_position_size") == Decimal("500")
assert limits.get_limit_for_symbol("MSFT", "max_position_size") == Decimal("1000")
class TestLossLimits:
"""Test LossLimits dataclass."""
def test_default_limits(self):
"""Test default limits are None."""
limits = LossLimits()
assert limits.max_daily_loss is None
assert limits.max_drawdown is None
assert limits.cooling_off_period_minutes == 0
def test_custom_limits(self):
"""Test setting custom limits."""
limits = LossLimits(
max_daily_loss=Decimal("500"),
max_daily_loss_percent=Decimal("5"),
max_drawdown_percent=Decimal("20"),
cooling_off_period_minutes=30,
)
assert limits.max_daily_loss == Decimal("500")
assert limits.max_daily_loss_percent == Decimal("5")
assert limits.cooling_off_period_minutes == 30
class TestPortfolioState:
"""Test PortfolioState dataclass."""
def test_default_state(self):
"""Test default state."""
state = PortfolioState()
assert state.cash == Decimal("0")
assert state.equity == Decimal("0")
assert state.daily_pnl == Decimal("0")
def test_drawdown_calculation(self):
"""Test drawdown calculation."""
state = PortfolioState(
equity=Decimal("90000"),
peak_equity=Decimal("100000"),
)
assert state.current_drawdown == Decimal("10000")
assert state.current_drawdown_percent == Decimal("10")
def test_no_drawdown_without_peak(self):
"""Test no drawdown without peak."""
state = PortfolioState(equity=Decimal("100000"))
assert state.current_drawdown == Decimal("0")
assert state.current_drawdown_percent == Decimal("0")
class TestRiskManagerInit:
"""Test RiskManager initialization."""
def test_default_initialization(self):
"""Test default initialization."""
manager = RiskManager()
assert manager.enabled is True
assert manager.position_limits is not None
assert manager.loss_limits is not None
def test_disabled_initialization(self):
"""Test disabled initialization."""
manager = RiskManager(enabled=False)
assert manager.enabled is False
def test_with_limits(self):
"""Test initialization with limits."""
pos_limits = PositionLimits(max_position_size=Decimal("1000"))
loss_limits = LossLimits(max_daily_loss=Decimal("500"))
manager = RiskManager(
position_limits=pos_limits,
loss_limits=loss_limits,
)
assert manager.position_limits.max_position_size == Decimal("1000")
assert manager.loss_limits.max_daily_loss == Decimal("500")
class TestRiskManagerPositionLimits:
"""Test RiskManager position limit checks."""
def test_position_size_within_limit(self):
"""Test position size within limit passes."""
manager = RiskManager(
position_limits=PositionLimits(max_position_size=Decimal("1000"))
)
portfolio = PortfolioState()
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("500"))
result = manager.validate_order(order, portfolio)
assert result.passed is True
def test_position_size_exceeds_limit(self):
"""Test position size exceeding limit fails."""
manager = RiskManager(
position_limits=PositionLimits(max_position_size=Decimal("1000"))
)
portfolio = PortfolioState()
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1500"))
result = manager.validate_order(order, portfolio)
assert result.passed is False
assert any(v.rule_name == "max_position_size" for v in result.violations)
def test_position_size_with_existing_position(self):
"""Test position size check with existing position."""
manager = RiskManager(
position_limits=PositionLimits(max_position_size=Decimal("1000"))
)
portfolio = PortfolioState(
positions={
"AAPL": Position(
symbol="AAPL",
quantity=Decimal("600"),
side=PositionSide.LONG,
avg_entry_price=Decimal("100"),
current_price=Decimal("100"),
market_value=Decimal("60000"),
cost_basis=Decimal("60000"),
unrealized_pnl=Decimal("0"),
unrealized_pnl_percent=Decimal("0"),
)
}
)
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("500"))
result = manager.validate_order(order, portfolio)
assert result.passed is False # 600 + 500 = 1100 > 1000
def test_position_value_limit(self):
"""Test position value limit."""
manager = RiskManager(
position_limits=PositionLimits(max_position_value=Decimal("50000"))
)
portfolio = PortfolioState()
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1000"))
# At $100, value is $100,000 which exceeds $50,000 limit
result = manager.validate_order(order, portfolio, estimated_fill_price=Decimal("100"))
assert result.passed is False
assert any(v.rule_name == "max_position_value" for v in result.violations)
def test_concentration_limit(self):
"""Test concentration limit."""
manager = RiskManager(
position_limits=PositionLimits(max_concentration_percent=Decimal("20"))
)
portfolio = PortfolioState(equity=Decimal("100000"))
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("300"))
# At $100, value is $30,000 = 30% of $100,000 equity
result = manager.validate_order(order, portfolio, estimated_fill_price=Decimal("100"))
assert result.passed is False
assert any(v.rule_name == "max_concentration" for v in result.violations)
def test_total_positions_limit(self):
"""Test total positions limit."""
manager = RiskManager(
position_limits=PositionLimits(max_total_positions=2)
)
portfolio = PortfolioState(
positions={
"AAPL": Position(
symbol="AAPL",
quantity=Decimal("100"),
side=PositionSide.LONG,
avg_entry_price=Decimal("100"),
current_price=Decimal("100"),
market_value=Decimal("10000"),
cost_basis=Decimal("10000"),
unrealized_pnl=Decimal("0"),
unrealized_pnl_percent=Decimal("0"),
),
"MSFT": Position(
symbol="MSFT",
quantity=Decimal("100"),
side=PositionSide.LONG,
avg_entry_price=Decimal("100"),
current_price=Decimal("100"),
market_value=Decimal("10000"),
cost_basis=Decimal("10000"),
unrealized_pnl=Decimal("0"),
unrealized_pnl_percent=Decimal("0"),
),
}
)
order = OrderRequest.market("GOOGL", OrderSide.BUY, Decimal("100"))
result = manager.validate_order(order, portfolio)
assert result.passed is False
assert any(v.rule_name == "max_total_positions" for v in result.violations)
class TestRiskManagerLossLimits:
"""Test RiskManager loss limit checks."""
def test_daily_loss_within_limit(self):
"""Test daily loss within limit passes."""
manager = RiskManager(
loss_limits=LossLimits(max_daily_loss=Decimal("1000"))
)
portfolio = PortfolioState(daily_pnl=Decimal("-500"))
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
result = manager.validate_order(order, portfolio)
assert result.passed is True
def test_daily_loss_exceeds_limit(self):
"""Test daily loss exceeding limit fails."""
manager = RiskManager(
loss_limits=LossLimits(max_daily_loss=Decimal("1000"))
)
portfolio = PortfolioState(daily_pnl=Decimal("-1500"))
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
result = manager.validate_order(order, portfolio)
assert result.passed is False
assert any(v.rule_name == "max_daily_loss" for v in result.violations)
def test_daily_loss_percent_limit(self):
"""Test daily loss percentage limit."""
manager = RiskManager(
loss_limits=LossLimits(max_daily_loss_percent=Decimal("5"))
)
portfolio = PortfolioState(
equity=Decimal("100000"),
daily_pnl=Decimal("-6000"), # 6% loss
)
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
result = manager.validate_order(order, portfolio)
assert result.passed is False
assert any(v.rule_name == "max_daily_loss_percent" for v in result.violations)
def test_drawdown_limit(self):
"""Test drawdown limit."""
manager = RiskManager(
loss_limits=LossLimits(max_drawdown=Decimal("15000"))
)
portfolio = PortfolioState(
equity=Decimal("85000"),
peak_equity=Decimal("100000"), # 15k drawdown
)
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
# Exactly at limit should pass
result = manager.validate_order(order, portfolio)
assert result.passed is True
# Beyond limit should fail
portfolio.equity = Decimal("80000") # 20k drawdown
result = manager.validate_order(order, portfolio)
assert result.passed is False
assert any(v.rule_name == "max_drawdown" for v in result.violations)
def test_drawdown_percent_limit(self):
"""Test drawdown percentage limit."""
manager = RiskManager(
loss_limits=LossLimits(max_drawdown_percent=Decimal("20"))
)
portfolio = PortfolioState(
equity=Decimal("75000"),
peak_equity=Decimal("100000"), # 25% drawdown
)
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
result = manager.validate_order(order, portfolio)
assert result.passed is False
assert any(v.rule_name == "max_drawdown_percent" for v in result.violations)
def test_consecutive_losses_limit(self):
"""Test consecutive losses limit."""
manager = RiskManager(
loss_limits=LossLimits(max_consecutive_losses=5)
)
portfolio = PortfolioState(consecutive_losses=5)
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
result = manager.validate_order(order, portfolio)
assert result.passed is False
assert any(v.rule_name == "max_consecutive_losses" for v in result.violations)
class TestRiskManagerCoolingOff:
"""Test RiskManager cooling off period."""
def test_cooling_off_triggered(self):
"""Test cooling off is triggered on loss limit."""
manager = RiskManager(
loss_limits=LossLimits(
max_daily_loss=Decimal("1000"),
cooling_off_period_minutes=30,
)
)
portfolio = PortfolioState(daily_pnl=Decimal("-1500"))
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
result = manager.validate_order(order, portfolio)
assert result.passed is False
assert manager._in_cooling_off is True
def test_order_blocked_during_cooling_off(self):
"""Test orders blocked during cooling off."""
manager = RiskManager(
loss_limits=LossLimits(
max_daily_loss=Decimal("1000"),
cooling_off_period_minutes=30,
)
)
# Trigger cooling off
portfolio_loss = PortfolioState(daily_pnl=Decimal("-1500"))
manager.validate_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")),
portfolio_loss,
)
# Try another order
portfolio_ok = PortfolioState(daily_pnl=Decimal("0"))
result = manager.validate_order(
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")),
portfolio_ok,
)
assert result.passed is False
assert any(v.rule_name == "cooling_off_period" for v in result.violations)
def test_cooling_off_reset(self):
"""Test cooling off can be reset."""
manager = RiskManager(
loss_limits=LossLimits(
max_daily_loss=Decimal("1000"),
cooling_off_period_minutes=30,
)
)
manager._in_cooling_off = True
manager._cooling_off_until = datetime.now(timezone.utc)
manager.reset_daily_limits()
assert manager._in_cooling_off is False
assert manager._cooling_off_until is None
class TestRiskManagerDisabled:
"""Test RiskManager when disabled."""
def test_disabled_passes_all(self):
"""Test disabled manager passes all orders."""
manager = RiskManager(
position_limits=PositionLimits(max_position_size=Decimal("100")),
enabled=False,
)
portfolio = PortfolioState()
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10000"))
result = manager.validate_order(order, portfolio)
assert result.passed is True
def test_enable_disable(self):
"""Test enabling and disabling."""
manager = RiskManager()
assert manager.enabled is True
manager.enabled = False
assert manager.enabled is False
manager.enabled = True
assert manager.enabled is True
class TestRiskManagerCustomRules:
"""Test RiskManager custom rules."""
def test_custom_rule_passing(self):
"""Test custom rule that passes."""
def custom_rule(order, portfolio):
return None # Pass
manager = RiskManager(custom_rules=[custom_rule])
portfolio = PortfolioState()
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
result = manager.validate_order(order, portfolio)
assert result.passed is True
assert "custom_rule_0" in result.checked_rules
def test_custom_rule_failing(self):
"""Test custom rule that fails."""
def custom_rule(order, portfolio):
return RiskViolation(
rule_type=RiskRuleType.CUSTOM,
rule_name="custom_test",
message="Custom rule failed",
current_value=Decimal("0"),
limit_value=Decimal("0"),
)
manager = RiskManager(custom_rules=[custom_rule])
portfolio = PortfolioState()
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
result = manager.validate_order(order, portfolio)
assert result.passed is False
assert any(v.rule_name == "custom_test" for v in result.violations)
def test_custom_rule_error_handled(self):
"""Test custom rule error doesn't break validation."""
def bad_rule(order, portfolio):
raise Exception("Rule error")
manager = RiskManager(custom_rules=[bad_rule])
portfolio = PortfolioState()
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
# Should not raise
result = manager.validate_order(order, portfolio)
assert result.passed is True
def test_add_custom_rule(self):
"""Test adding custom rule after init."""
manager = RiskManager()
def custom_rule(order, portfolio):
return None
manager.add_custom_rule(custom_rule)
assert len(manager._custom_rules) == 1
class TestRiskManagerTracking:
"""Test RiskManager tracking methods."""
def test_update_daily_pnl(self):
"""Test updating daily P&L."""
manager = RiskManager()
today = date.today()
manager.update_daily_pnl(Decimal("100"), today)
assert manager.get_daily_pnl(today) == Decimal("100")
manager.update_daily_pnl(Decimal("50"), today)
assert manager.get_daily_pnl(today) == Decimal("150")
def test_update_peak_equity(self):
"""Test updating peak equity."""
manager = RiskManager()
manager.update_peak_equity(Decimal("100000"))
assert manager._peak_equity == Decimal("100000")
# Higher should update
manager.update_peak_equity(Decimal("110000"))
assert manager._peak_equity == Decimal("110000")
# Lower should not update
manager.update_peak_equity(Decimal("105000"))
assert manager._peak_equity == Decimal("110000")
def test_reset_all(self):
"""Test resetting all state."""
manager = RiskManager()
manager.update_daily_pnl(Decimal("100"), date.today())
manager.update_peak_equity(Decimal("100000"))
manager._in_cooling_off = True
manager.reset_all()
assert manager.get_daily_pnl(date.today()) == Decimal("0")
assert manager._peak_equity is None
assert manager._in_cooling_off is False
class TestRiskManagerRuleChecking:
"""Test that all rules are checked."""
def test_all_rules_checked(self):
"""Test all rules appear in checked list."""
manager = RiskManager(
position_limits=PositionLimits(
max_position_size=Decimal("1000"),
max_position_value=Decimal("50000"),
max_concentration_percent=Decimal("20"),
max_total_positions=10,
),
loss_limits=LossLimits(
max_daily_loss=Decimal("1000"),
max_drawdown=Decimal("10000"),
max_single_trade_loss=Decimal("500"),
max_consecutive_losses=5,
),
)
portfolio = PortfolioState(equity=Decimal("100000"))
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
result = manager.validate_order(order, portfolio, estimated_fill_price=Decimal("100"))
expected_rules = [
"position_size",
"position_value",
"concentration",
"total_positions",
"daily_loss",
"drawdown",
"single_trade_loss",
"consecutive_losses",
]
for rule in expected_rules:
assert rule in result.checked_rules
class TestRiskRuleTypeEnum:
"""Test RiskRuleType enum."""
def test_all_types_defined(self):
"""Test all expected types are defined."""
expected = [
"POSITION_SIZE",
"POSITION_VALUE",
"CONCENTRATION",
"DAILY_LOSS",
"DRAWDOWN",
"CUSTOM",
]
for type_name in expected:
assert hasattr(RiskRuleType, type_name)
class TestRiskCheckResultEnum:
"""Test RiskCheckResult enum."""
def test_all_results_defined(self):
"""Test all expected results are defined."""
expected = ["PASSED", "FAILED", "WARNING", "SKIPPED"]
for result_name in expected:
assert hasattr(RiskCheckResult, result_name)

View File

@ -124,6 +124,17 @@ from .order_manager import (
OPEN_STATES,
)
from .risk_controls import (
RiskManager,
RiskCheckResult,
RiskRuleType,
RiskViolation,
RiskCheckResponse,
PositionLimits,
LossLimits,
PortfolioState,
)
__all__ = [
# Enums
"AssetClass",
@ -176,4 +187,13 @@ __all__ = [
"VALID_TRANSITIONS",
"TERMINAL_STATES",
"OPEN_STATES",
# Risk Controls
"RiskManager",
"RiskCheckResult",
"RiskRuleType",
"RiskViolation",
"RiskCheckResponse",
"PositionLimits",
"LossLimits",
"PortfolioState",
]

View File

@ -0,0 +1,735 @@
"""Risk Controls for order execution.
Issue #28: [EXEC-27] Risk controls - position limits, loss limits
This module provides pre-trade risk validation including:
- Position size limits (max shares, max notional value)
- Concentration limits (max % of portfolio in single position)
- Daily loss limits
- Drawdown limits
- Pre-trade validation framework
Example:
>>> from tradingagents.execution import RiskManager, PositionLimits, LossLimits
>>>
>>> limits = PositionLimits(
... max_position_size=Decimal("10000"),
... max_position_value=Decimal("50000"),
... max_concentration_percent=Decimal("20"),
... )
>>> risk_manager = RiskManager(position_limits=limits)
>>> result = risk_manager.check_order(order_request, portfolio)
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone, date
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, List, Optional, Callable
from .broker_base import (
OrderRequest,
OrderSide,
Position,
)
class RiskCheckResult(Enum):
"""Result of a risk check."""
PASSED = "passed"
FAILED = "failed"
WARNING = "warning"
SKIPPED = "skipped"
class RiskRuleType(Enum):
"""Type of risk rule."""
POSITION_SIZE = "position_size"
POSITION_VALUE = "position_value"
CONCENTRATION = "concentration"
DAILY_LOSS = "daily_loss"
DRAWDOWN = "drawdown"
CUSTOM = "custom"
@dataclass
class RiskViolation:
"""Details of a risk limit violation.
Attributes:
rule_type: Type of rule violated
rule_name: Name of the specific rule
message: Human-readable violation message
current_value: Current value that violated the limit
limit_value: The limit that was exceeded
severity: Violation severity (error, warning)
metadata: Additional context
"""
rule_type: RiskRuleType
rule_name: str
message: str
current_value: Decimal
limit_value: Decimal
severity: str = "error"
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class RiskCheckResponse:
"""Response from risk validation.
Attributes:
passed: Whether all risk checks passed
violations: List of rule violations
warnings: List of warnings (non-blocking)
checked_rules: List of rules that were checked
timestamp: When the check was performed
"""
passed: bool = True
violations: List[RiskViolation] = field(default_factory=list)
warnings: List[RiskViolation] = field(default_factory=list)
checked_rules: List[str] = field(default_factory=list)
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
def add_violation(self, violation: RiskViolation) -> None:
"""Add a violation to the response."""
if violation.severity == "warning":
self.warnings.append(violation)
else:
self.violations.append(violation)
self.passed = False
@property
def rejection_message(self) -> Optional[str]:
"""Get formatted rejection message if failed."""
if self.passed:
return None
messages = [v.message for v in self.violations]
return "; ".join(messages)
@dataclass
class PositionLimits:
"""Position-related risk limits.
Attributes:
max_position_size: Maximum shares/units in single position
max_position_value: Maximum notional value of single position
max_concentration_percent: Maximum % of portfolio in single position
max_total_positions: Maximum number of open positions
max_sector_concentration: Maximum % in single sector
per_symbol_limits: Custom limits per symbol
"""
max_position_size: Optional[Decimal] = None
max_position_value: Optional[Decimal] = None
max_concentration_percent: Optional[Decimal] = None
max_total_positions: Optional[int] = None
max_sector_concentration: Optional[Decimal] = None
per_symbol_limits: Dict[str, Dict[str, Decimal]] = field(default_factory=dict)
def get_limit_for_symbol(
self, symbol: str, limit_type: str
) -> Optional[Decimal]:
"""Get specific limit for a symbol, falling back to default."""
if symbol in self.per_symbol_limits:
if limit_type in self.per_symbol_limits[symbol]:
return self.per_symbol_limits[symbol][limit_type]
return getattr(self, limit_type, None)
@dataclass
class LossLimits:
"""Loss-related risk limits.
Attributes:
max_daily_loss: Maximum loss allowed per day
max_daily_loss_percent: Maximum loss as % of equity per day
max_drawdown: Maximum drawdown from peak
max_drawdown_percent: Maximum drawdown as % of peak
max_single_trade_loss: Maximum loss on single trade
max_consecutive_losses: Maximum consecutive losing trades
cooling_off_period_minutes: Minutes to wait after hitting limit
"""
max_daily_loss: Optional[Decimal] = None
max_daily_loss_percent: Optional[Decimal] = None
max_drawdown: Optional[Decimal] = None
max_drawdown_percent: Optional[Decimal] = None
max_single_trade_loss: Optional[Decimal] = None
max_consecutive_losses: Optional[int] = None
cooling_off_period_minutes: int = 0
@dataclass
class PortfolioState:
"""Current portfolio state for risk calculations.
Attributes:
positions: Current positions keyed by symbol
cash: Available cash
equity: Total portfolio equity
buying_power: Available buying power
daily_pnl: Profit/loss for current day
peak_equity: Peak equity for drawdown calculations
consecutive_losses: Current consecutive losing trades
last_loss_time: When last loss occurred
"""
positions: Dict[str, Position] = field(default_factory=dict)
cash: Decimal = Decimal("0")
equity: Decimal = Decimal("0")
buying_power: Decimal = Decimal("0")
daily_pnl: Decimal = Decimal("0")
peak_equity: Optional[Decimal] = None
consecutive_losses: int = 0
last_loss_time: Optional[datetime] = None
@property
def current_drawdown(self) -> Decimal:
"""Calculate current drawdown from peak."""
if self.peak_equity is None or self.peak_equity <= 0:
return Decimal("0")
return self.peak_equity - self.equity
@property
def current_drawdown_percent(self) -> Decimal:
"""Calculate current drawdown as percentage."""
if self.peak_equity is None or self.peak_equity <= 0:
return Decimal("0")
return (self.current_drawdown / self.peak_equity) * Decimal("100")
class RiskManager:
"""Manages pre-trade risk validation.
The RiskManager validates orders against configured risk limits
before allowing them to be submitted to brokers.
Example:
>>> risk_manager = RiskManager(
... position_limits=PositionLimits(max_position_size=1000),
... loss_limits=LossLimits(max_daily_loss=Decimal("500")),
... )
>>> result = risk_manager.validate_order(order_request, portfolio_state)
>>> if not result.passed:
... print(f"Order rejected: {result.rejection_message}")
"""
def __init__(
self,
position_limits: Optional[PositionLimits] = None,
loss_limits: Optional[LossLimits] = None,
custom_rules: Optional[List[Callable]] = None,
enabled: bool = True,
) -> None:
"""Initialize risk manager.
Args:
position_limits: Position-related limits
loss_limits: Loss-related limits
custom_rules: Custom validation functions
enabled: Whether risk checks are enabled
"""
self._position_limits = position_limits or PositionLimits()
self._loss_limits = loss_limits or LossLimits()
self._custom_rules = custom_rules or []
self._enabled = enabled
self._daily_pnl_by_date: Dict[date, Decimal] = {}
self._peak_equity: Optional[Decimal] = None
self._in_cooling_off = False
self._cooling_off_until: Optional[datetime] = None
@property
def enabled(self) -> bool:
"""Check if risk manager is enabled."""
return self._enabled
@enabled.setter
def enabled(self, value: bool) -> None:
"""Enable or disable risk manager."""
self._enabled = value
@property
def position_limits(self) -> PositionLimits:
"""Get position limits."""
return self._position_limits
@position_limits.setter
def position_limits(self, limits: PositionLimits) -> None:
"""Set position limits."""
self._position_limits = limits
@property
def loss_limits(self) -> LossLimits:
"""Get loss limits."""
return self._loss_limits
@loss_limits.setter
def loss_limits(self, limits: LossLimits) -> None:
"""Set loss limits."""
self._loss_limits = limits
def add_custom_rule(
self,
rule: Callable[[OrderRequest, PortfolioState], Optional[RiskViolation]],
) -> None:
"""Add a custom validation rule.
Args:
rule: Function that takes order and portfolio, returns violation or None
"""
self._custom_rules.append(rule)
def validate_order(
self,
order: OrderRequest,
portfolio: PortfolioState,
estimated_fill_price: Optional[Decimal] = None,
) -> RiskCheckResponse:
"""Validate an order against all risk limits.
Args:
order: Order request to validate
portfolio: Current portfolio state
estimated_fill_price: Expected fill price for value calculations
Returns:
RiskCheckResponse with validation results
"""
response = RiskCheckResponse()
if not self._enabled:
response.checked_rules.append("(risk checks disabled)")
return response
# Check cooling off period
if self._in_cooling_off and self._cooling_off_until:
if datetime.now(timezone.utc) < self._cooling_off_until:
response.add_violation(
RiskViolation(
rule_type=RiskRuleType.DAILY_LOSS,
rule_name="cooling_off_period",
message=f"In cooling off period until {self._cooling_off_until}",
current_value=Decimal("0"),
limit_value=Decimal("0"),
)
)
return response
else:
self._in_cooling_off = False
self._cooling_off_until = None
# Run all checks
self._check_position_size(order, portfolio, response)
self._check_position_value(order, portfolio, response, estimated_fill_price)
self._check_concentration(order, portfolio, response, estimated_fill_price)
self._check_total_positions(order, portfolio, response)
self._check_daily_loss(portfolio, response)
self._check_drawdown(portfolio, response)
self._check_single_trade_loss(order, portfolio, response, estimated_fill_price)
self._check_consecutive_losses(portfolio, response)
self._run_custom_rules(order, portfolio, response)
return response
def _check_position_size(
self,
order: OrderRequest,
portfolio: PortfolioState,
response: RiskCheckResponse,
) -> None:
"""Check position size limits."""
response.checked_rules.append("position_size")
limit = self._position_limits.get_limit_for_symbol(
order.symbol, "max_position_size"
)
if limit is None:
return
current_position = portfolio.positions.get(order.symbol)
current_qty = current_position.quantity if current_position else Decimal("0")
if order.side == OrderSide.BUY:
new_qty = current_qty + order.quantity
else:
new_qty = current_qty - order.quantity
if abs(new_qty) > limit:
response.add_violation(
RiskViolation(
rule_type=RiskRuleType.POSITION_SIZE,
rule_name="max_position_size",
message=(
f"Position size {abs(new_qty)} exceeds limit {limit} "
f"for {order.symbol}"
),
current_value=abs(new_qty),
limit_value=limit,
)
)
def _check_position_value(
self,
order: OrderRequest,
portfolio: PortfolioState,
response: RiskCheckResponse,
estimated_price: Optional[Decimal],
) -> None:
"""Check position value limits."""
response.checked_rules.append("position_value")
limit = self._position_limits.get_limit_for_symbol(
order.symbol, "max_position_value"
)
if limit is None:
return
if estimated_price is None:
# Can't check without price
return
current_position = portfolio.positions.get(order.symbol)
current_qty = current_position.quantity if current_position else Decimal("0")
if order.side == OrderSide.BUY:
new_qty = current_qty + order.quantity
else:
new_qty = current_qty - order.quantity
new_value = abs(new_qty * estimated_price)
if new_value > limit:
response.add_violation(
RiskViolation(
rule_type=RiskRuleType.POSITION_VALUE,
rule_name="max_position_value",
message=(
f"Position value ${new_value} exceeds limit ${limit} "
f"for {order.symbol}"
),
current_value=new_value,
limit_value=limit,
)
)
def _check_concentration(
self,
order: OrderRequest,
portfolio: PortfolioState,
response: RiskCheckResponse,
estimated_price: Optional[Decimal],
) -> None:
"""Check concentration limits."""
response.checked_rules.append("concentration")
limit = self._position_limits.max_concentration_percent
if limit is None:
return
if estimated_price is None or portfolio.equity <= 0:
return
current_position = portfolio.positions.get(order.symbol)
current_qty = current_position.quantity if current_position else Decimal("0")
if order.side == OrderSide.BUY:
new_qty = current_qty + order.quantity
else:
new_qty = current_qty - order.quantity
new_value = abs(new_qty * estimated_price)
concentration = (new_value / portfolio.equity) * Decimal("100")
if concentration > limit:
response.add_violation(
RiskViolation(
rule_type=RiskRuleType.CONCENTRATION,
rule_name="max_concentration",
message=(
f"Concentration {concentration:.1f}% exceeds limit {limit}% "
f"for {order.symbol}"
),
current_value=concentration,
limit_value=limit,
)
)
def _check_total_positions(
self,
order: OrderRequest,
portfolio: PortfolioState,
response: RiskCheckResponse,
) -> None:
"""Check total positions limit."""
response.checked_rules.append("total_positions")
limit = self._position_limits.max_total_positions
if limit is None:
return
# Only check for new positions (buys in symbols we don't have)
if order.side != OrderSide.BUY:
return
if order.symbol in portfolio.positions:
return # Adding to existing position
if len(portfolio.positions) >= limit:
response.add_violation(
RiskViolation(
rule_type=RiskRuleType.POSITION_SIZE,
rule_name="max_total_positions",
message=(
f"Total positions {len(portfolio.positions)} at limit {limit}"
),
current_value=Decimal(str(len(portfolio.positions))),
limit_value=Decimal(str(limit)),
)
)
def _check_daily_loss(
self,
portfolio: PortfolioState,
response: RiskCheckResponse,
) -> None:
"""Check daily loss limits."""
response.checked_rules.append("daily_loss")
# Check absolute daily loss
if self._loss_limits.max_daily_loss is not None:
if portfolio.daily_pnl < -self._loss_limits.max_daily_loss:
self._trigger_cooling_off()
response.add_violation(
RiskViolation(
rule_type=RiskRuleType.DAILY_LOSS,
rule_name="max_daily_loss",
message=(
f"Daily loss ${abs(portfolio.daily_pnl)} exceeds limit "
f"${self._loss_limits.max_daily_loss}"
),
current_value=abs(portfolio.daily_pnl),
limit_value=self._loss_limits.max_daily_loss,
)
)
# Check percentage daily loss
if self._loss_limits.max_daily_loss_percent is not None:
if portfolio.equity > 0:
daily_loss_pct = (
abs(portfolio.daily_pnl) / portfolio.equity
) * Decimal("100")
if (
portfolio.daily_pnl < 0
and daily_loss_pct > self._loss_limits.max_daily_loss_percent
):
self._trigger_cooling_off()
response.add_violation(
RiskViolation(
rule_type=RiskRuleType.DAILY_LOSS,
rule_name="max_daily_loss_percent",
message=(
f"Daily loss {daily_loss_pct:.1f}% exceeds limit "
f"{self._loss_limits.max_daily_loss_percent}%"
),
current_value=daily_loss_pct,
limit_value=self._loss_limits.max_daily_loss_percent,
)
)
def _check_drawdown(
self,
portfolio: PortfolioState,
response: RiskCheckResponse,
) -> None:
"""Check drawdown limits."""
response.checked_rules.append("drawdown")
# Check absolute drawdown
if self._loss_limits.max_drawdown is not None:
if portfolio.current_drawdown > self._loss_limits.max_drawdown:
response.add_violation(
RiskViolation(
rule_type=RiskRuleType.DRAWDOWN,
rule_name="max_drawdown",
message=(
f"Drawdown ${portfolio.current_drawdown} exceeds limit "
f"${self._loss_limits.max_drawdown}"
),
current_value=portfolio.current_drawdown,
limit_value=self._loss_limits.max_drawdown,
)
)
# Check percentage drawdown
if self._loss_limits.max_drawdown_percent is not None:
if portfolio.current_drawdown_percent > self._loss_limits.max_drawdown_percent:
response.add_violation(
RiskViolation(
rule_type=RiskRuleType.DRAWDOWN,
rule_name="max_drawdown_percent",
message=(
f"Drawdown {portfolio.current_drawdown_percent:.1f}% "
f"exceeds limit {self._loss_limits.max_drawdown_percent}%"
),
current_value=portfolio.current_drawdown_percent,
limit_value=self._loss_limits.max_drawdown_percent,
)
)
def _check_single_trade_loss(
self,
order: OrderRequest,
portfolio: PortfolioState,
response: RiskCheckResponse,
estimated_price: Optional[Decimal],
) -> None:
"""Check single trade loss limit."""
response.checked_rules.append("single_trade_loss")
limit = self._loss_limits.max_single_trade_loss
if limit is None:
return
if estimated_price is None:
return
# Calculate max potential loss for the order
order_value = order.quantity * estimated_price
# For sells, potential loss is limited
# For buys, assume worst case is total loss of order value
if order.side == OrderSide.BUY:
potential_loss = order_value
if potential_loss > limit:
response.add_violation(
RiskViolation(
rule_type=RiskRuleType.DAILY_LOSS,
rule_name="max_single_trade_loss",
message=(
f"Potential loss ${potential_loss} exceeds limit "
f"${limit}"
),
current_value=potential_loss,
limit_value=limit,
severity="warning", # Warning, not blocking
)
)
def _check_consecutive_losses(
self,
portfolio: PortfolioState,
response: RiskCheckResponse,
) -> None:
"""Check consecutive losses limit."""
response.checked_rules.append("consecutive_losses")
limit = self._loss_limits.max_consecutive_losses
if limit is None:
return
if portfolio.consecutive_losses >= limit:
self._trigger_cooling_off()
response.add_violation(
RiskViolation(
rule_type=RiskRuleType.DAILY_LOSS,
rule_name="max_consecutive_losses",
message=(
f"Consecutive losses {portfolio.consecutive_losses} "
f"reached limit {limit}"
),
current_value=Decimal(str(portfolio.consecutive_losses)),
limit_value=Decimal(str(limit)),
)
)
def _run_custom_rules(
self,
order: OrderRequest,
portfolio: PortfolioState,
response: RiskCheckResponse,
) -> None:
"""Run custom validation rules."""
for i, rule in enumerate(self._custom_rules):
response.checked_rules.append(f"custom_rule_{i}")
try:
violation = rule(order, portfolio)
if violation:
response.add_violation(violation)
except Exception:
# Don't let custom rule errors break validation
pass
def _trigger_cooling_off(self) -> None:
"""Trigger cooling off period."""
if self._loss_limits.cooling_off_period_minutes > 0:
self._in_cooling_off = True
from datetime import timedelta
self._cooling_off_until = datetime.now(timezone.utc) + timedelta(
minutes=self._loss_limits.cooling_off_period_minutes
)
def update_daily_pnl(self, pnl: Decimal, trade_date: date) -> None:
"""Update daily P&L tracking.
Args:
pnl: P&L for the trade
trade_date: Date of the trade
"""
if trade_date not in self._daily_pnl_by_date:
self._daily_pnl_by_date[trade_date] = Decimal("0")
self._daily_pnl_by_date[trade_date] += pnl
def get_daily_pnl(self, trade_date: date) -> Decimal:
"""Get daily P&L for a date.
Args:
trade_date: Date to get P&L for
Returns:
P&L for the date
"""
return self._daily_pnl_by_date.get(trade_date, Decimal("0"))
def update_peak_equity(self, equity: Decimal) -> None:
"""Update peak equity tracking.
Args:
equity: Current equity
"""
if self._peak_equity is None or equity > self._peak_equity:
self._peak_equity = equity
def reset_daily_limits(self) -> None:
"""Reset daily tracking (call at start of each trading day)."""
self._in_cooling_off = False
self._cooling_off_until = None
def reset_all(self) -> None:
"""Reset all tracking state."""
self._daily_pnl_by_date.clear()
self._peak_equity = None
self._in_cooling_off = False
self._cooling_off_until = None
# Export
__all__ = [
"RiskManager",
"RiskCheckResult",
"RiskRuleType",
"RiskViolation",
"RiskCheckResponse",
"PositionLimits",
"LossLimits",
"PortfolioState",
]