feat(execution): add Risk Controls for pre-trade validation - Issue #28 (45 tests)
This commit is contained in:
parent
6863e3ed87
commit
9aee43312d
|
|
@ -0,0 +1,678 @@
|
|||
"""Tests for Risk Controls implementation.
|
||||
|
||||
Issue #28: [EXEC-27] Risk controls - position limits, loss limits
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timezone, timedelta, date
|
||||
import pytest
|
||||
|
||||
from tradingagents.execution import (
|
||||
RiskManager,
|
||||
RiskCheckResult,
|
||||
RiskRuleType,
|
||||
RiskViolation,
|
||||
RiskCheckResponse,
|
||||
PositionLimits,
|
||||
LossLimits,
|
||||
PortfolioState,
|
||||
OrderRequest,
|
||||
OrderSide,
|
||||
Position,
|
||||
PositionSide,
|
||||
)
|
||||
|
||||
|
||||
class TestRiskViolation:
|
||||
"""Test RiskViolation dataclass."""
|
||||
|
||||
def test_violation_creation(self):
|
||||
"""Test creating a violation."""
|
||||
violation = RiskViolation(
|
||||
rule_type=RiskRuleType.POSITION_SIZE,
|
||||
rule_name="max_position_size",
|
||||
message="Position too large",
|
||||
current_value=Decimal("1500"),
|
||||
limit_value=Decimal("1000"),
|
||||
)
|
||||
assert violation.rule_type == RiskRuleType.POSITION_SIZE
|
||||
assert violation.current_value == Decimal("1500")
|
||||
assert violation.severity == "error"
|
||||
|
||||
def test_violation_with_warning_severity(self):
|
||||
"""Test violation with warning severity."""
|
||||
violation = RiskViolation(
|
||||
rule_type=RiskRuleType.DAILY_LOSS,
|
||||
rule_name="potential_loss",
|
||||
message="Potential loss warning",
|
||||
current_value=Decimal("100"),
|
||||
limit_value=Decimal("200"),
|
||||
severity="warning",
|
||||
)
|
||||
assert violation.severity == "warning"
|
||||
|
||||
|
||||
class TestRiskCheckResponse:
|
||||
"""Test RiskCheckResponse dataclass."""
|
||||
|
||||
def test_default_passed(self):
|
||||
"""Test default response is passed."""
|
||||
response = RiskCheckResponse()
|
||||
assert response.passed is True
|
||||
assert response.violations == []
|
||||
assert response.warnings == []
|
||||
|
||||
def test_add_violation(self):
|
||||
"""Test adding violation fails response."""
|
||||
response = RiskCheckResponse()
|
||||
violation = RiskViolation(
|
||||
rule_type=RiskRuleType.POSITION_SIZE,
|
||||
rule_name="test",
|
||||
message="Test violation",
|
||||
current_value=Decimal("100"),
|
||||
limit_value=Decimal("50"),
|
||||
)
|
||||
response.add_violation(violation)
|
||||
assert response.passed is False
|
||||
assert len(response.violations) == 1
|
||||
|
||||
def test_add_warning(self):
|
||||
"""Test adding warning keeps passed."""
|
||||
response = RiskCheckResponse()
|
||||
violation = RiskViolation(
|
||||
rule_type=RiskRuleType.DAILY_LOSS,
|
||||
rule_name="test",
|
||||
message="Test warning",
|
||||
current_value=Decimal("100"),
|
||||
limit_value=Decimal("200"),
|
||||
severity="warning",
|
||||
)
|
||||
response.add_violation(violation)
|
||||
assert response.passed is True
|
||||
assert len(response.warnings) == 1
|
||||
|
||||
def test_rejection_message(self):
|
||||
"""Test rejection message formatting."""
|
||||
response = RiskCheckResponse()
|
||||
violation = RiskViolation(
|
||||
rule_type=RiskRuleType.POSITION_SIZE,
|
||||
rule_name="test",
|
||||
message="Position too large",
|
||||
current_value=Decimal("100"),
|
||||
limit_value=Decimal("50"),
|
||||
)
|
||||
response.add_violation(violation)
|
||||
assert "Position too large" in response.rejection_message
|
||||
|
||||
def test_rejection_message_none_when_passed(self):
|
||||
"""Test rejection message is None when passed."""
|
||||
response = RiskCheckResponse()
|
||||
assert response.rejection_message is None
|
||||
|
||||
|
||||
class TestPositionLimits:
|
||||
"""Test PositionLimits dataclass."""
|
||||
|
||||
def test_default_limits(self):
|
||||
"""Test default limits are None."""
|
||||
limits = PositionLimits()
|
||||
assert limits.max_position_size is None
|
||||
assert limits.max_position_value is None
|
||||
assert limits.max_concentration_percent is None
|
||||
|
||||
def test_custom_limits(self):
|
||||
"""Test setting custom limits."""
|
||||
limits = PositionLimits(
|
||||
max_position_size=Decimal("1000"),
|
||||
max_position_value=Decimal("50000"),
|
||||
max_concentration_percent=Decimal("20"),
|
||||
)
|
||||
assert limits.max_position_size == Decimal("1000")
|
||||
assert limits.max_position_value == Decimal("50000")
|
||||
assert limits.max_concentration_percent == Decimal("20")
|
||||
|
||||
def test_per_symbol_limits(self):
|
||||
"""Test per-symbol limits."""
|
||||
limits = PositionLimits(
|
||||
max_position_size=Decimal("1000"),
|
||||
per_symbol_limits={
|
||||
"AAPL": {"max_position_size": Decimal("500")},
|
||||
},
|
||||
)
|
||||
assert limits.get_limit_for_symbol("AAPL", "max_position_size") == Decimal("500")
|
||||
assert limits.get_limit_for_symbol("MSFT", "max_position_size") == Decimal("1000")
|
||||
|
||||
|
||||
class TestLossLimits:
|
||||
"""Test LossLimits dataclass."""
|
||||
|
||||
def test_default_limits(self):
|
||||
"""Test default limits are None."""
|
||||
limits = LossLimits()
|
||||
assert limits.max_daily_loss is None
|
||||
assert limits.max_drawdown is None
|
||||
assert limits.cooling_off_period_minutes == 0
|
||||
|
||||
def test_custom_limits(self):
|
||||
"""Test setting custom limits."""
|
||||
limits = LossLimits(
|
||||
max_daily_loss=Decimal("500"),
|
||||
max_daily_loss_percent=Decimal("5"),
|
||||
max_drawdown_percent=Decimal("20"),
|
||||
cooling_off_period_minutes=30,
|
||||
)
|
||||
assert limits.max_daily_loss == Decimal("500")
|
||||
assert limits.max_daily_loss_percent == Decimal("5")
|
||||
assert limits.cooling_off_period_minutes == 30
|
||||
|
||||
|
||||
class TestPortfolioState:
|
||||
"""Test PortfolioState dataclass."""
|
||||
|
||||
def test_default_state(self):
|
||||
"""Test default state."""
|
||||
state = PortfolioState()
|
||||
assert state.cash == Decimal("0")
|
||||
assert state.equity == Decimal("0")
|
||||
assert state.daily_pnl == Decimal("0")
|
||||
|
||||
def test_drawdown_calculation(self):
|
||||
"""Test drawdown calculation."""
|
||||
state = PortfolioState(
|
||||
equity=Decimal("90000"),
|
||||
peak_equity=Decimal("100000"),
|
||||
)
|
||||
assert state.current_drawdown == Decimal("10000")
|
||||
assert state.current_drawdown_percent == Decimal("10")
|
||||
|
||||
def test_no_drawdown_without_peak(self):
|
||||
"""Test no drawdown without peak."""
|
||||
state = PortfolioState(equity=Decimal("100000"))
|
||||
assert state.current_drawdown == Decimal("0")
|
||||
assert state.current_drawdown_percent == Decimal("0")
|
||||
|
||||
|
||||
class TestRiskManagerInit:
|
||||
"""Test RiskManager initialization."""
|
||||
|
||||
def test_default_initialization(self):
|
||||
"""Test default initialization."""
|
||||
manager = RiskManager()
|
||||
assert manager.enabled is True
|
||||
assert manager.position_limits is not None
|
||||
assert manager.loss_limits is not None
|
||||
|
||||
def test_disabled_initialization(self):
|
||||
"""Test disabled initialization."""
|
||||
manager = RiskManager(enabled=False)
|
||||
assert manager.enabled is False
|
||||
|
||||
def test_with_limits(self):
|
||||
"""Test initialization with limits."""
|
||||
pos_limits = PositionLimits(max_position_size=Decimal("1000"))
|
||||
loss_limits = LossLimits(max_daily_loss=Decimal("500"))
|
||||
manager = RiskManager(
|
||||
position_limits=pos_limits,
|
||||
loss_limits=loss_limits,
|
||||
)
|
||||
assert manager.position_limits.max_position_size == Decimal("1000")
|
||||
assert manager.loss_limits.max_daily_loss == Decimal("500")
|
||||
|
||||
|
||||
class TestRiskManagerPositionLimits:
|
||||
"""Test RiskManager position limit checks."""
|
||||
|
||||
def test_position_size_within_limit(self):
|
||||
"""Test position size within limit passes."""
|
||||
manager = RiskManager(
|
||||
position_limits=PositionLimits(max_position_size=Decimal("1000"))
|
||||
)
|
||||
portfolio = PortfolioState()
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("500"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is True
|
||||
|
||||
def test_position_size_exceeds_limit(self):
|
||||
"""Test position size exceeding limit fails."""
|
||||
manager = RiskManager(
|
||||
position_limits=PositionLimits(max_position_size=Decimal("1000"))
|
||||
)
|
||||
portfolio = PortfolioState()
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1500"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is False
|
||||
assert any(v.rule_name == "max_position_size" for v in result.violations)
|
||||
|
||||
def test_position_size_with_existing_position(self):
|
||||
"""Test position size check with existing position."""
|
||||
manager = RiskManager(
|
||||
position_limits=PositionLimits(max_position_size=Decimal("1000"))
|
||||
)
|
||||
portfolio = PortfolioState(
|
||||
positions={
|
||||
"AAPL": Position(
|
||||
symbol="AAPL",
|
||||
quantity=Decimal("600"),
|
||||
side=PositionSide.LONG,
|
||||
avg_entry_price=Decimal("100"),
|
||||
current_price=Decimal("100"),
|
||||
market_value=Decimal("60000"),
|
||||
cost_basis=Decimal("60000"),
|
||||
unrealized_pnl=Decimal("0"),
|
||||
unrealized_pnl_percent=Decimal("0"),
|
||||
)
|
||||
}
|
||||
)
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("500"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is False # 600 + 500 = 1100 > 1000
|
||||
|
||||
def test_position_value_limit(self):
|
||||
"""Test position value limit."""
|
||||
manager = RiskManager(
|
||||
position_limits=PositionLimits(max_position_value=Decimal("50000"))
|
||||
)
|
||||
portfolio = PortfolioState()
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("1000"))
|
||||
|
||||
# At $100, value is $100,000 which exceeds $50,000 limit
|
||||
result = manager.validate_order(order, portfolio, estimated_fill_price=Decimal("100"))
|
||||
assert result.passed is False
|
||||
assert any(v.rule_name == "max_position_value" for v in result.violations)
|
||||
|
||||
def test_concentration_limit(self):
|
||||
"""Test concentration limit."""
|
||||
manager = RiskManager(
|
||||
position_limits=PositionLimits(max_concentration_percent=Decimal("20"))
|
||||
)
|
||||
portfolio = PortfolioState(equity=Decimal("100000"))
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("300"))
|
||||
|
||||
# At $100, value is $30,000 = 30% of $100,000 equity
|
||||
result = manager.validate_order(order, portfolio, estimated_fill_price=Decimal("100"))
|
||||
assert result.passed is False
|
||||
assert any(v.rule_name == "max_concentration" for v in result.violations)
|
||||
|
||||
def test_total_positions_limit(self):
|
||||
"""Test total positions limit."""
|
||||
manager = RiskManager(
|
||||
position_limits=PositionLimits(max_total_positions=2)
|
||||
)
|
||||
portfolio = PortfolioState(
|
||||
positions={
|
||||
"AAPL": Position(
|
||||
symbol="AAPL",
|
||||
quantity=Decimal("100"),
|
||||
side=PositionSide.LONG,
|
||||
avg_entry_price=Decimal("100"),
|
||||
current_price=Decimal("100"),
|
||||
market_value=Decimal("10000"),
|
||||
cost_basis=Decimal("10000"),
|
||||
unrealized_pnl=Decimal("0"),
|
||||
unrealized_pnl_percent=Decimal("0"),
|
||||
),
|
||||
"MSFT": Position(
|
||||
symbol="MSFT",
|
||||
quantity=Decimal("100"),
|
||||
side=PositionSide.LONG,
|
||||
avg_entry_price=Decimal("100"),
|
||||
current_price=Decimal("100"),
|
||||
market_value=Decimal("10000"),
|
||||
cost_basis=Decimal("10000"),
|
||||
unrealized_pnl=Decimal("0"),
|
||||
unrealized_pnl_percent=Decimal("0"),
|
||||
),
|
||||
}
|
||||
)
|
||||
order = OrderRequest.market("GOOGL", OrderSide.BUY, Decimal("100"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is False
|
||||
assert any(v.rule_name == "max_total_positions" for v in result.violations)
|
||||
|
||||
|
||||
class TestRiskManagerLossLimits:
|
||||
"""Test RiskManager loss limit checks."""
|
||||
|
||||
def test_daily_loss_within_limit(self):
|
||||
"""Test daily loss within limit passes."""
|
||||
manager = RiskManager(
|
||||
loss_limits=LossLimits(max_daily_loss=Decimal("1000"))
|
||||
)
|
||||
portfolio = PortfolioState(daily_pnl=Decimal("-500"))
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is True
|
||||
|
||||
def test_daily_loss_exceeds_limit(self):
|
||||
"""Test daily loss exceeding limit fails."""
|
||||
manager = RiskManager(
|
||||
loss_limits=LossLimits(max_daily_loss=Decimal("1000"))
|
||||
)
|
||||
portfolio = PortfolioState(daily_pnl=Decimal("-1500"))
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is False
|
||||
assert any(v.rule_name == "max_daily_loss" for v in result.violations)
|
||||
|
||||
def test_daily_loss_percent_limit(self):
|
||||
"""Test daily loss percentage limit."""
|
||||
manager = RiskManager(
|
||||
loss_limits=LossLimits(max_daily_loss_percent=Decimal("5"))
|
||||
)
|
||||
portfolio = PortfolioState(
|
||||
equity=Decimal("100000"),
|
||||
daily_pnl=Decimal("-6000"), # 6% loss
|
||||
)
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is False
|
||||
assert any(v.rule_name == "max_daily_loss_percent" for v in result.violations)
|
||||
|
||||
def test_drawdown_limit(self):
|
||||
"""Test drawdown limit."""
|
||||
manager = RiskManager(
|
||||
loss_limits=LossLimits(max_drawdown=Decimal("15000"))
|
||||
)
|
||||
portfolio = PortfolioState(
|
||||
equity=Decimal("85000"),
|
||||
peak_equity=Decimal("100000"), # 15k drawdown
|
||||
)
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
|
||||
|
||||
# Exactly at limit should pass
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is True
|
||||
|
||||
# Beyond limit should fail
|
||||
portfolio.equity = Decimal("80000") # 20k drawdown
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is False
|
||||
assert any(v.rule_name == "max_drawdown" for v in result.violations)
|
||||
|
||||
def test_drawdown_percent_limit(self):
|
||||
"""Test drawdown percentage limit."""
|
||||
manager = RiskManager(
|
||||
loss_limits=LossLimits(max_drawdown_percent=Decimal("20"))
|
||||
)
|
||||
portfolio = PortfolioState(
|
||||
equity=Decimal("75000"),
|
||||
peak_equity=Decimal("100000"), # 25% drawdown
|
||||
)
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is False
|
||||
assert any(v.rule_name == "max_drawdown_percent" for v in result.violations)
|
||||
|
||||
def test_consecutive_losses_limit(self):
|
||||
"""Test consecutive losses limit."""
|
||||
manager = RiskManager(
|
||||
loss_limits=LossLimits(max_consecutive_losses=5)
|
||||
)
|
||||
portfolio = PortfolioState(consecutive_losses=5)
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is False
|
||||
assert any(v.rule_name == "max_consecutive_losses" for v in result.violations)
|
||||
|
||||
|
||||
class TestRiskManagerCoolingOff:
|
||||
"""Test RiskManager cooling off period."""
|
||||
|
||||
def test_cooling_off_triggered(self):
|
||||
"""Test cooling off is triggered on loss limit."""
|
||||
manager = RiskManager(
|
||||
loss_limits=LossLimits(
|
||||
max_daily_loss=Decimal("1000"),
|
||||
cooling_off_period_minutes=30,
|
||||
)
|
||||
)
|
||||
portfolio = PortfolioState(daily_pnl=Decimal("-1500"))
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is False
|
||||
assert manager._in_cooling_off is True
|
||||
|
||||
def test_order_blocked_during_cooling_off(self):
|
||||
"""Test orders blocked during cooling off."""
|
||||
manager = RiskManager(
|
||||
loss_limits=LossLimits(
|
||||
max_daily_loss=Decimal("1000"),
|
||||
cooling_off_period_minutes=30,
|
||||
)
|
||||
)
|
||||
# Trigger cooling off
|
||||
portfolio_loss = PortfolioState(daily_pnl=Decimal("-1500"))
|
||||
manager.validate_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")),
|
||||
portfolio_loss,
|
||||
)
|
||||
|
||||
# Try another order
|
||||
portfolio_ok = PortfolioState(daily_pnl=Decimal("0"))
|
||||
result = manager.validate_order(
|
||||
OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100")),
|
||||
portfolio_ok,
|
||||
)
|
||||
|
||||
assert result.passed is False
|
||||
assert any(v.rule_name == "cooling_off_period" for v in result.violations)
|
||||
|
||||
def test_cooling_off_reset(self):
|
||||
"""Test cooling off can be reset."""
|
||||
manager = RiskManager(
|
||||
loss_limits=LossLimits(
|
||||
max_daily_loss=Decimal("1000"),
|
||||
cooling_off_period_minutes=30,
|
||||
)
|
||||
)
|
||||
manager._in_cooling_off = True
|
||||
manager._cooling_off_until = datetime.now(timezone.utc)
|
||||
|
||||
manager.reset_daily_limits()
|
||||
|
||||
assert manager._in_cooling_off is False
|
||||
assert manager._cooling_off_until is None
|
||||
|
||||
|
||||
class TestRiskManagerDisabled:
|
||||
"""Test RiskManager when disabled."""
|
||||
|
||||
def test_disabled_passes_all(self):
|
||||
"""Test disabled manager passes all orders."""
|
||||
manager = RiskManager(
|
||||
position_limits=PositionLimits(max_position_size=Decimal("100")),
|
||||
enabled=False,
|
||||
)
|
||||
portfolio = PortfolioState()
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10000"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is True
|
||||
|
||||
def test_enable_disable(self):
|
||||
"""Test enabling and disabling."""
|
||||
manager = RiskManager()
|
||||
assert manager.enabled is True
|
||||
|
||||
manager.enabled = False
|
||||
assert manager.enabled is False
|
||||
|
||||
manager.enabled = True
|
||||
assert manager.enabled is True
|
||||
|
||||
|
||||
class TestRiskManagerCustomRules:
|
||||
"""Test RiskManager custom rules."""
|
||||
|
||||
def test_custom_rule_passing(self):
|
||||
"""Test custom rule that passes."""
|
||||
def custom_rule(order, portfolio):
|
||||
return None # Pass
|
||||
|
||||
manager = RiskManager(custom_rules=[custom_rule])
|
||||
portfolio = PortfolioState()
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is True
|
||||
assert "custom_rule_0" in result.checked_rules
|
||||
|
||||
def test_custom_rule_failing(self):
|
||||
"""Test custom rule that fails."""
|
||||
def custom_rule(order, portfolio):
|
||||
return RiskViolation(
|
||||
rule_type=RiskRuleType.CUSTOM,
|
||||
rule_name="custom_test",
|
||||
message="Custom rule failed",
|
||||
current_value=Decimal("0"),
|
||||
limit_value=Decimal("0"),
|
||||
)
|
||||
|
||||
manager = RiskManager(custom_rules=[custom_rule])
|
||||
portfolio = PortfolioState()
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
|
||||
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is False
|
||||
assert any(v.rule_name == "custom_test" for v in result.violations)
|
||||
|
||||
def test_custom_rule_error_handled(self):
|
||||
"""Test custom rule error doesn't break validation."""
|
||||
def bad_rule(order, portfolio):
|
||||
raise Exception("Rule error")
|
||||
|
||||
manager = RiskManager(custom_rules=[bad_rule])
|
||||
portfolio = PortfolioState()
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("100"))
|
||||
|
||||
# Should not raise
|
||||
result = manager.validate_order(order, portfolio)
|
||||
assert result.passed is True
|
||||
|
||||
def test_add_custom_rule(self):
|
||||
"""Test adding custom rule after init."""
|
||||
manager = RiskManager()
|
||||
|
||||
def custom_rule(order, portfolio):
|
||||
return None
|
||||
|
||||
manager.add_custom_rule(custom_rule)
|
||||
assert len(manager._custom_rules) == 1
|
||||
|
||||
|
||||
class TestRiskManagerTracking:
|
||||
"""Test RiskManager tracking methods."""
|
||||
|
||||
def test_update_daily_pnl(self):
|
||||
"""Test updating daily P&L."""
|
||||
manager = RiskManager()
|
||||
today = date.today()
|
||||
|
||||
manager.update_daily_pnl(Decimal("100"), today)
|
||||
assert manager.get_daily_pnl(today) == Decimal("100")
|
||||
|
||||
manager.update_daily_pnl(Decimal("50"), today)
|
||||
assert manager.get_daily_pnl(today) == Decimal("150")
|
||||
|
||||
def test_update_peak_equity(self):
|
||||
"""Test updating peak equity."""
|
||||
manager = RiskManager()
|
||||
|
||||
manager.update_peak_equity(Decimal("100000"))
|
||||
assert manager._peak_equity == Decimal("100000")
|
||||
|
||||
# Higher should update
|
||||
manager.update_peak_equity(Decimal("110000"))
|
||||
assert manager._peak_equity == Decimal("110000")
|
||||
|
||||
# Lower should not update
|
||||
manager.update_peak_equity(Decimal("105000"))
|
||||
assert manager._peak_equity == Decimal("110000")
|
||||
|
||||
def test_reset_all(self):
|
||||
"""Test resetting all state."""
|
||||
manager = RiskManager()
|
||||
manager.update_daily_pnl(Decimal("100"), date.today())
|
||||
manager.update_peak_equity(Decimal("100000"))
|
||||
manager._in_cooling_off = True
|
||||
|
||||
manager.reset_all()
|
||||
|
||||
assert manager.get_daily_pnl(date.today()) == Decimal("0")
|
||||
assert manager._peak_equity is None
|
||||
assert manager._in_cooling_off is False
|
||||
|
||||
|
||||
class TestRiskManagerRuleChecking:
|
||||
"""Test that all rules are checked."""
|
||||
|
||||
def test_all_rules_checked(self):
|
||||
"""Test all rules appear in checked list."""
|
||||
manager = RiskManager(
|
||||
position_limits=PositionLimits(
|
||||
max_position_size=Decimal("1000"),
|
||||
max_position_value=Decimal("50000"),
|
||||
max_concentration_percent=Decimal("20"),
|
||||
max_total_positions=10,
|
||||
),
|
||||
loss_limits=LossLimits(
|
||||
max_daily_loss=Decimal("1000"),
|
||||
max_drawdown=Decimal("10000"),
|
||||
max_single_trade_loss=Decimal("500"),
|
||||
max_consecutive_losses=5,
|
||||
),
|
||||
)
|
||||
portfolio = PortfolioState(equity=Decimal("100000"))
|
||||
order = OrderRequest.market("AAPL", OrderSide.BUY, Decimal("10"))
|
||||
|
||||
result = manager.validate_order(order, portfolio, estimated_fill_price=Decimal("100"))
|
||||
|
||||
expected_rules = [
|
||||
"position_size",
|
||||
"position_value",
|
||||
"concentration",
|
||||
"total_positions",
|
||||
"daily_loss",
|
||||
"drawdown",
|
||||
"single_trade_loss",
|
||||
"consecutive_losses",
|
||||
]
|
||||
for rule in expected_rules:
|
||||
assert rule in result.checked_rules
|
||||
|
||||
|
||||
class TestRiskRuleTypeEnum:
|
||||
"""Test RiskRuleType enum."""
|
||||
|
||||
def test_all_types_defined(self):
|
||||
"""Test all expected types are defined."""
|
||||
expected = [
|
||||
"POSITION_SIZE",
|
||||
"POSITION_VALUE",
|
||||
"CONCENTRATION",
|
||||
"DAILY_LOSS",
|
||||
"DRAWDOWN",
|
||||
"CUSTOM",
|
||||
]
|
||||
for type_name in expected:
|
||||
assert hasattr(RiskRuleType, type_name)
|
||||
|
||||
|
||||
class TestRiskCheckResultEnum:
|
||||
"""Test RiskCheckResult enum."""
|
||||
|
||||
def test_all_results_defined(self):
|
||||
"""Test all expected results are defined."""
|
||||
expected = ["PASSED", "FAILED", "WARNING", "SKIPPED"]
|
||||
for result_name in expected:
|
||||
assert hasattr(RiskCheckResult, result_name)
|
||||
|
|
@ -124,6 +124,17 @@ from .order_manager import (
|
|||
OPEN_STATES,
|
||||
)
|
||||
|
||||
from .risk_controls import (
|
||||
RiskManager,
|
||||
RiskCheckResult,
|
||||
RiskRuleType,
|
||||
RiskViolation,
|
||||
RiskCheckResponse,
|
||||
PositionLimits,
|
||||
LossLimits,
|
||||
PortfolioState,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Enums
|
||||
"AssetClass",
|
||||
|
|
@ -176,4 +187,13 @@ __all__ = [
|
|||
"VALID_TRANSITIONS",
|
||||
"TERMINAL_STATES",
|
||||
"OPEN_STATES",
|
||||
# Risk Controls
|
||||
"RiskManager",
|
||||
"RiskCheckResult",
|
||||
"RiskRuleType",
|
||||
"RiskViolation",
|
||||
"RiskCheckResponse",
|
||||
"PositionLimits",
|
||||
"LossLimits",
|
||||
"PortfolioState",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,735 @@
|
|||
"""Risk Controls for order execution.
|
||||
|
||||
Issue #28: [EXEC-27] Risk controls - position limits, loss limits
|
||||
|
||||
This module provides pre-trade risk validation including:
|
||||
- Position size limits (max shares, max notional value)
|
||||
- Concentration limits (max % of portfolio in single position)
|
||||
- Daily loss limits
|
||||
- Drawdown limits
|
||||
- Pre-trade validation framework
|
||||
|
||||
Example:
|
||||
>>> from tradingagents.execution import RiskManager, PositionLimits, LossLimits
|
||||
>>>
|
||||
>>> limits = PositionLimits(
|
||||
... max_position_size=Decimal("10000"),
|
||||
... max_position_value=Decimal("50000"),
|
||||
... max_concentration_percent=Decimal("20"),
|
||||
... )
|
||||
>>> risk_manager = RiskManager(position_limits=limits)
|
||||
>>> result = risk_manager.check_order(order_request, portfolio)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone, date
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
|
||||
from .broker_base import (
|
||||
OrderRequest,
|
||||
OrderSide,
|
||||
Position,
|
||||
)
|
||||
|
||||
|
||||
class RiskCheckResult(Enum):
|
||||
"""Result of a risk check."""
|
||||
|
||||
PASSED = "passed"
|
||||
FAILED = "failed"
|
||||
WARNING = "warning"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
class RiskRuleType(Enum):
|
||||
"""Type of risk rule."""
|
||||
|
||||
POSITION_SIZE = "position_size"
|
||||
POSITION_VALUE = "position_value"
|
||||
CONCENTRATION = "concentration"
|
||||
DAILY_LOSS = "daily_loss"
|
||||
DRAWDOWN = "drawdown"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RiskViolation:
|
||||
"""Details of a risk limit violation.
|
||||
|
||||
Attributes:
|
||||
rule_type: Type of rule violated
|
||||
rule_name: Name of the specific rule
|
||||
message: Human-readable violation message
|
||||
current_value: Current value that violated the limit
|
||||
limit_value: The limit that was exceeded
|
||||
severity: Violation severity (error, warning)
|
||||
metadata: Additional context
|
||||
"""
|
||||
|
||||
rule_type: RiskRuleType
|
||||
rule_name: str
|
||||
message: str
|
||||
current_value: Decimal
|
||||
limit_value: Decimal
|
||||
severity: str = "error"
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RiskCheckResponse:
|
||||
"""Response from risk validation.
|
||||
|
||||
Attributes:
|
||||
passed: Whether all risk checks passed
|
||||
violations: List of rule violations
|
||||
warnings: List of warnings (non-blocking)
|
||||
checked_rules: List of rules that were checked
|
||||
timestamp: When the check was performed
|
||||
"""
|
||||
|
||||
passed: bool = True
|
||||
violations: List[RiskViolation] = field(default_factory=list)
|
||||
warnings: List[RiskViolation] = field(default_factory=list)
|
||||
checked_rules: List[str] = field(default_factory=list)
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def add_violation(self, violation: RiskViolation) -> None:
|
||||
"""Add a violation to the response."""
|
||||
if violation.severity == "warning":
|
||||
self.warnings.append(violation)
|
||||
else:
|
||||
self.violations.append(violation)
|
||||
self.passed = False
|
||||
|
||||
@property
|
||||
def rejection_message(self) -> Optional[str]:
|
||||
"""Get formatted rejection message if failed."""
|
||||
if self.passed:
|
||||
return None
|
||||
messages = [v.message for v in self.violations]
|
||||
return "; ".join(messages)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PositionLimits:
|
||||
"""Position-related risk limits.
|
||||
|
||||
Attributes:
|
||||
max_position_size: Maximum shares/units in single position
|
||||
max_position_value: Maximum notional value of single position
|
||||
max_concentration_percent: Maximum % of portfolio in single position
|
||||
max_total_positions: Maximum number of open positions
|
||||
max_sector_concentration: Maximum % in single sector
|
||||
per_symbol_limits: Custom limits per symbol
|
||||
"""
|
||||
|
||||
max_position_size: Optional[Decimal] = None
|
||||
max_position_value: Optional[Decimal] = None
|
||||
max_concentration_percent: Optional[Decimal] = None
|
||||
max_total_positions: Optional[int] = None
|
||||
max_sector_concentration: Optional[Decimal] = None
|
||||
per_symbol_limits: Dict[str, Dict[str, Decimal]] = field(default_factory=dict)
|
||||
|
||||
def get_limit_for_symbol(
|
||||
self, symbol: str, limit_type: str
|
||||
) -> Optional[Decimal]:
|
||||
"""Get specific limit for a symbol, falling back to default."""
|
||||
if symbol in self.per_symbol_limits:
|
||||
if limit_type in self.per_symbol_limits[symbol]:
|
||||
return self.per_symbol_limits[symbol][limit_type]
|
||||
|
||||
return getattr(self, limit_type, None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LossLimits:
|
||||
"""Loss-related risk limits.
|
||||
|
||||
Attributes:
|
||||
max_daily_loss: Maximum loss allowed per day
|
||||
max_daily_loss_percent: Maximum loss as % of equity per day
|
||||
max_drawdown: Maximum drawdown from peak
|
||||
max_drawdown_percent: Maximum drawdown as % of peak
|
||||
max_single_trade_loss: Maximum loss on single trade
|
||||
max_consecutive_losses: Maximum consecutive losing trades
|
||||
cooling_off_period_minutes: Minutes to wait after hitting limit
|
||||
"""
|
||||
|
||||
max_daily_loss: Optional[Decimal] = None
|
||||
max_daily_loss_percent: Optional[Decimal] = None
|
||||
max_drawdown: Optional[Decimal] = None
|
||||
max_drawdown_percent: Optional[Decimal] = None
|
||||
max_single_trade_loss: Optional[Decimal] = None
|
||||
max_consecutive_losses: Optional[int] = None
|
||||
cooling_off_period_minutes: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class PortfolioState:
|
||||
"""Current portfolio state for risk calculations.
|
||||
|
||||
Attributes:
|
||||
positions: Current positions keyed by symbol
|
||||
cash: Available cash
|
||||
equity: Total portfolio equity
|
||||
buying_power: Available buying power
|
||||
daily_pnl: Profit/loss for current day
|
||||
peak_equity: Peak equity for drawdown calculations
|
||||
consecutive_losses: Current consecutive losing trades
|
||||
last_loss_time: When last loss occurred
|
||||
"""
|
||||
|
||||
positions: Dict[str, Position] = field(default_factory=dict)
|
||||
cash: Decimal = Decimal("0")
|
||||
equity: Decimal = Decimal("0")
|
||||
buying_power: Decimal = Decimal("0")
|
||||
daily_pnl: Decimal = Decimal("0")
|
||||
peak_equity: Optional[Decimal] = None
|
||||
consecutive_losses: int = 0
|
||||
last_loss_time: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def current_drawdown(self) -> Decimal:
|
||||
"""Calculate current drawdown from peak."""
|
||||
if self.peak_equity is None or self.peak_equity <= 0:
|
||||
return Decimal("0")
|
||||
return self.peak_equity - self.equity
|
||||
|
||||
@property
|
||||
def current_drawdown_percent(self) -> Decimal:
|
||||
"""Calculate current drawdown as percentage."""
|
||||
if self.peak_equity is None or self.peak_equity <= 0:
|
||||
return Decimal("0")
|
||||
return (self.current_drawdown / self.peak_equity) * Decimal("100")
|
||||
|
||||
|
||||
class RiskManager:
|
||||
"""Manages pre-trade risk validation.
|
||||
|
||||
The RiskManager validates orders against configured risk limits
|
||||
before allowing them to be submitted to brokers.
|
||||
|
||||
Example:
|
||||
>>> risk_manager = RiskManager(
|
||||
... position_limits=PositionLimits(max_position_size=1000),
|
||||
... loss_limits=LossLimits(max_daily_loss=Decimal("500")),
|
||||
... )
|
||||
>>> result = risk_manager.validate_order(order_request, portfolio_state)
|
||||
>>> if not result.passed:
|
||||
... print(f"Order rejected: {result.rejection_message}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
position_limits: Optional[PositionLimits] = None,
|
||||
loss_limits: Optional[LossLimits] = None,
|
||||
custom_rules: Optional[List[Callable]] = None,
|
||||
enabled: bool = True,
|
||||
) -> None:
|
||||
"""Initialize risk manager.
|
||||
|
||||
Args:
|
||||
position_limits: Position-related limits
|
||||
loss_limits: Loss-related limits
|
||||
custom_rules: Custom validation functions
|
||||
enabled: Whether risk checks are enabled
|
||||
"""
|
||||
self._position_limits = position_limits or PositionLimits()
|
||||
self._loss_limits = loss_limits or LossLimits()
|
||||
self._custom_rules = custom_rules or []
|
||||
self._enabled = enabled
|
||||
self._daily_pnl_by_date: Dict[date, Decimal] = {}
|
||||
self._peak_equity: Optional[Decimal] = None
|
||||
self._in_cooling_off = False
|
||||
self._cooling_off_until: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""Check if risk manager is enabled."""
|
||||
return self._enabled
|
||||
|
||||
@enabled.setter
|
||||
def enabled(self, value: bool) -> None:
|
||||
"""Enable or disable risk manager."""
|
||||
self._enabled = value
|
||||
|
||||
@property
|
||||
def position_limits(self) -> PositionLimits:
|
||||
"""Get position limits."""
|
||||
return self._position_limits
|
||||
|
||||
@position_limits.setter
|
||||
def position_limits(self, limits: PositionLimits) -> None:
|
||||
"""Set position limits."""
|
||||
self._position_limits = limits
|
||||
|
||||
@property
|
||||
def loss_limits(self) -> LossLimits:
|
||||
"""Get loss limits."""
|
||||
return self._loss_limits
|
||||
|
||||
@loss_limits.setter
|
||||
def loss_limits(self, limits: LossLimits) -> None:
|
||||
"""Set loss limits."""
|
||||
self._loss_limits = limits
|
||||
|
||||
def add_custom_rule(
|
||||
self,
|
||||
rule: Callable[[OrderRequest, PortfolioState], Optional[RiskViolation]],
|
||||
) -> None:
|
||||
"""Add a custom validation rule.
|
||||
|
||||
Args:
|
||||
rule: Function that takes order and portfolio, returns violation or None
|
||||
"""
|
||||
self._custom_rules.append(rule)
|
||||
|
||||
def validate_order(
|
||||
self,
|
||||
order: OrderRequest,
|
||||
portfolio: PortfolioState,
|
||||
estimated_fill_price: Optional[Decimal] = None,
|
||||
) -> RiskCheckResponse:
|
||||
"""Validate an order against all risk limits.
|
||||
|
||||
Args:
|
||||
order: Order request to validate
|
||||
portfolio: Current portfolio state
|
||||
estimated_fill_price: Expected fill price for value calculations
|
||||
|
||||
Returns:
|
||||
RiskCheckResponse with validation results
|
||||
"""
|
||||
response = RiskCheckResponse()
|
||||
|
||||
if not self._enabled:
|
||||
response.checked_rules.append("(risk checks disabled)")
|
||||
return response
|
||||
|
||||
# Check cooling off period
|
||||
if self._in_cooling_off and self._cooling_off_until:
|
||||
if datetime.now(timezone.utc) < self._cooling_off_until:
|
||||
response.add_violation(
|
||||
RiskViolation(
|
||||
rule_type=RiskRuleType.DAILY_LOSS,
|
||||
rule_name="cooling_off_period",
|
||||
message=f"In cooling off period until {self._cooling_off_until}",
|
||||
current_value=Decimal("0"),
|
||||
limit_value=Decimal("0"),
|
||||
)
|
||||
)
|
||||
return response
|
||||
else:
|
||||
self._in_cooling_off = False
|
||||
self._cooling_off_until = None
|
||||
|
||||
# Run all checks
|
||||
self._check_position_size(order, portfolio, response)
|
||||
self._check_position_value(order, portfolio, response, estimated_fill_price)
|
||||
self._check_concentration(order, portfolio, response, estimated_fill_price)
|
||||
self._check_total_positions(order, portfolio, response)
|
||||
self._check_daily_loss(portfolio, response)
|
||||
self._check_drawdown(portfolio, response)
|
||||
self._check_single_trade_loss(order, portfolio, response, estimated_fill_price)
|
||||
self._check_consecutive_losses(portfolio, response)
|
||||
self._run_custom_rules(order, portfolio, response)
|
||||
|
||||
return response
|
||||
|
||||
def _check_position_size(
|
||||
self,
|
||||
order: OrderRequest,
|
||||
portfolio: PortfolioState,
|
||||
response: RiskCheckResponse,
|
||||
) -> None:
|
||||
"""Check position size limits."""
|
||||
response.checked_rules.append("position_size")
|
||||
|
||||
limit = self._position_limits.get_limit_for_symbol(
|
||||
order.symbol, "max_position_size"
|
||||
)
|
||||
if limit is None:
|
||||
return
|
||||
|
||||
current_position = portfolio.positions.get(order.symbol)
|
||||
current_qty = current_position.quantity if current_position else Decimal("0")
|
||||
|
||||
if order.side == OrderSide.BUY:
|
||||
new_qty = current_qty + order.quantity
|
||||
else:
|
||||
new_qty = current_qty - order.quantity
|
||||
|
||||
if abs(new_qty) > limit:
|
||||
response.add_violation(
|
||||
RiskViolation(
|
||||
rule_type=RiskRuleType.POSITION_SIZE,
|
||||
rule_name="max_position_size",
|
||||
message=(
|
||||
f"Position size {abs(new_qty)} exceeds limit {limit} "
|
||||
f"for {order.symbol}"
|
||||
),
|
||||
current_value=abs(new_qty),
|
||||
limit_value=limit,
|
||||
)
|
||||
)
|
||||
|
||||
def _check_position_value(
|
||||
self,
|
||||
order: OrderRequest,
|
||||
portfolio: PortfolioState,
|
||||
response: RiskCheckResponse,
|
||||
estimated_price: Optional[Decimal],
|
||||
) -> None:
|
||||
"""Check position value limits."""
|
||||
response.checked_rules.append("position_value")
|
||||
|
||||
limit = self._position_limits.get_limit_for_symbol(
|
||||
order.symbol, "max_position_value"
|
||||
)
|
||||
if limit is None:
|
||||
return
|
||||
|
||||
if estimated_price is None:
|
||||
# Can't check without price
|
||||
return
|
||||
|
||||
current_position = portfolio.positions.get(order.symbol)
|
||||
current_qty = current_position.quantity if current_position else Decimal("0")
|
||||
|
||||
if order.side == OrderSide.BUY:
|
||||
new_qty = current_qty + order.quantity
|
||||
else:
|
||||
new_qty = current_qty - order.quantity
|
||||
|
||||
new_value = abs(new_qty * estimated_price)
|
||||
|
||||
if new_value > limit:
|
||||
response.add_violation(
|
||||
RiskViolation(
|
||||
rule_type=RiskRuleType.POSITION_VALUE,
|
||||
rule_name="max_position_value",
|
||||
message=(
|
||||
f"Position value ${new_value} exceeds limit ${limit} "
|
||||
f"for {order.symbol}"
|
||||
),
|
||||
current_value=new_value,
|
||||
limit_value=limit,
|
||||
)
|
||||
)
|
||||
|
||||
def _check_concentration(
|
||||
self,
|
||||
order: OrderRequest,
|
||||
portfolio: PortfolioState,
|
||||
response: RiskCheckResponse,
|
||||
estimated_price: Optional[Decimal],
|
||||
) -> None:
|
||||
"""Check concentration limits."""
|
||||
response.checked_rules.append("concentration")
|
||||
|
||||
limit = self._position_limits.max_concentration_percent
|
||||
if limit is None:
|
||||
return
|
||||
|
||||
if estimated_price is None or portfolio.equity <= 0:
|
||||
return
|
||||
|
||||
current_position = portfolio.positions.get(order.symbol)
|
||||
current_qty = current_position.quantity if current_position else Decimal("0")
|
||||
|
||||
if order.side == OrderSide.BUY:
|
||||
new_qty = current_qty + order.quantity
|
||||
else:
|
||||
new_qty = current_qty - order.quantity
|
||||
|
||||
new_value = abs(new_qty * estimated_price)
|
||||
concentration = (new_value / portfolio.equity) * Decimal("100")
|
||||
|
||||
if concentration > limit:
|
||||
response.add_violation(
|
||||
RiskViolation(
|
||||
rule_type=RiskRuleType.CONCENTRATION,
|
||||
rule_name="max_concentration",
|
||||
message=(
|
||||
f"Concentration {concentration:.1f}% exceeds limit {limit}% "
|
||||
f"for {order.symbol}"
|
||||
),
|
||||
current_value=concentration,
|
||||
limit_value=limit,
|
||||
)
|
||||
)
|
||||
|
||||
def _check_total_positions(
|
||||
self,
|
||||
order: OrderRequest,
|
||||
portfolio: PortfolioState,
|
||||
response: RiskCheckResponse,
|
||||
) -> None:
|
||||
"""Check total positions limit."""
|
||||
response.checked_rules.append("total_positions")
|
||||
|
||||
limit = self._position_limits.max_total_positions
|
||||
if limit is None:
|
||||
return
|
||||
|
||||
# Only check for new positions (buys in symbols we don't have)
|
||||
if order.side != OrderSide.BUY:
|
||||
return
|
||||
|
||||
if order.symbol in portfolio.positions:
|
||||
return # Adding to existing position
|
||||
|
||||
if len(portfolio.positions) >= limit:
|
||||
response.add_violation(
|
||||
RiskViolation(
|
||||
rule_type=RiskRuleType.POSITION_SIZE,
|
||||
rule_name="max_total_positions",
|
||||
message=(
|
||||
f"Total positions {len(portfolio.positions)} at limit {limit}"
|
||||
),
|
||||
current_value=Decimal(str(len(portfolio.positions))),
|
||||
limit_value=Decimal(str(limit)),
|
||||
)
|
||||
)
|
||||
|
||||
def _check_daily_loss(
|
||||
self,
|
||||
portfolio: PortfolioState,
|
||||
response: RiskCheckResponse,
|
||||
) -> None:
|
||||
"""Check daily loss limits."""
|
||||
response.checked_rules.append("daily_loss")
|
||||
|
||||
# Check absolute daily loss
|
||||
if self._loss_limits.max_daily_loss is not None:
|
||||
if portfolio.daily_pnl < -self._loss_limits.max_daily_loss:
|
||||
self._trigger_cooling_off()
|
||||
response.add_violation(
|
||||
RiskViolation(
|
||||
rule_type=RiskRuleType.DAILY_LOSS,
|
||||
rule_name="max_daily_loss",
|
||||
message=(
|
||||
f"Daily loss ${abs(portfolio.daily_pnl)} exceeds limit "
|
||||
f"${self._loss_limits.max_daily_loss}"
|
||||
),
|
||||
current_value=abs(portfolio.daily_pnl),
|
||||
limit_value=self._loss_limits.max_daily_loss,
|
||||
)
|
||||
)
|
||||
|
||||
# Check percentage daily loss
|
||||
if self._loss_limits.max_daily_loss_percent is not None:
|
||||
if portfolio.equity > 0:
|
||||
daily_loss_pct = (
|
||||
abs(portfolio.daily_pnl) / portfolio.equity
|
||||
) * Decimal("100")
|
||||
if (
|
||||
portfolio.daily_pnl < 0
|
||||
and daily_loss_pct > self._loss_limits.max_daily_loss_percent
|
||||
):
|
||||
self._trigger_cooling_off()
|
||||
response.add_violation(
|
||||
RiskViolation(
|
||||
rule_type=RiskRuleType.DAILY_LOSS,
|
||||
rule_name="max_daily_loss_percent",
|
||||
message=(
|
||||
f"Daily loss {daily_loss_pct:.1f}% exceeds limit "
|
||||
f"{self._loss_limits.max_daily_loss_percent}%"
|
||||
),
|
||||
current_value=daily_loss_pct,
|
||||
limit_value=self._loss_limits.max_daily_loss_percent,
|
||||
)
|
||||
)
|
||||
|
||||
def _check_drawdown(
|
||||
self,
|
||||
portfolio: PortfolioState,
|
||||
response: RiskCheckResponse,
|
||||
) -> None:
|
||||
"""Check drawdown limits."""
|
||||
response.checked_rules.append("drawdown")
|
||||
|
||||
# Check absolute drawdown
|
||||
if self._loss_limits.max_drawdown is not None:
|
||||
if portfolio.current_drawdown > self._loss_limits.max_drawdown:
|
||||
response.add_violation(
|
||||
RiskViolation(
|
||||
rule_type=RiskRuleType.DRAWDOWN,
|
||||
rule_name="max_drawdown",
|
||||
message=(
|
||||
f"Drawdown ${portfolio.current_drawdown} exceeds limit "
|
||||
f"${self._loss_limits.max_drawdown}"
|
||||
),
|
||||
current_value=portfolio.current_drawdown,
|
||||
limit_value=self._loss_limits.max_drawdown,
|
||||
)
|
||||
)
|
||||
|
||||
# Check percentage drawdown
|
||||
if self._loss_limits.max_drawdown_percent is not None:
|
||||
if portfolio.current_drawdown_percent > self._loss_limits.max_drawdown_percent:
|
||||
response.add_violation(
|
||||
RiskViolation(
|
||||
rule_type=RiskRuleType.DRAWDOWN,
|
||||
rule_name="max_drawdown_percent",
|
||||
message=(
|
||||
f"Drawdown {portfolio.current_drawdown_percent:.1f}% "
|
||||
f"exceeds limit {self._loss_limits.max_drawdown_percent}%"
|
||||
),
|
||||
current_value=portfolio.current_drawdown_percent,
|
||||
limit_value=self._loss_limits.max_drawdown_percent,
|
||||
)
|
||||
)
|
||||
|
||||
def _check_single_trade_loss(
|
||||
self,
|
||||
order: OrderRequest,
|
||||
portfolio: PortfolioState,
|
||||
response: RiskCheckResponse,
|
||||
estimated_price: Optional[Decimal],
|
||||
) -> None:
|
||||
"""Check single trade loss limit."""
|
||||
response.checked_rules.append("single_trade_loss")
|
||||
|
||||
limit = self._loss_limits.max_single_trade_loss
|
||||
if limit is None:
|
||||
return
|
||||
|
||||
if estimated_price is None:
|
||||
return
|
||||
|
||||
# Calculate max potential loss for the order
|
||||
order_value = order.quantity * estimated_price
|
||||
|
||||
# For sells, potential loss is limited
|
||||
# For buys, assume worst case is total loss of order value
|
||||
if order.side == OrderSide.BUY:
|
||||
potential_loss = order_value
|
||||
if potential_loss > limit:
|
||||
response.add_violation(
|
||||
RiskViolation(
|
||||
rule_type=RiskRuleType.DAILY_LOSS,
|
||||
rule_name="max_single_trade_loss",
|
||||
message=(
|
||||
f"Potential loss ${potential_loss} exceeds limit "
|
||||
f"${limit}"
|
||||
),
|
||||
current_value=potential_loss,
|
||||
limit_value=limit,
|
||||
severity="warning", # Warning, not blocking
|
||||
)
|
||||
)
|
||||
|
||||
def _check_consecutive_losses(
|
||||
self,
|
||||
portfolio: PortfolioState,
|
||||
response: RiskCheckResponse,
|
||||
) -> None:
|
||||
"""Check consecutive losses limit."""
|
||||
response.checked_rules.append("consecutive_losses")
|
||||
|
||||
limit = self._loss_limits.max_consecutive_losses
|
||||
if limit is None:
|
||||
return
|
||||
|
||||
if portfolio.consecutive_losses >= limit:
|
||||
self._trigger_cooling_off()
|
||||
response.add_violation(
|
||||
RiskViolation(
|
||||
rule_type=RiskRuleType.DAILY_LOSS,
|
||||
rule_name="max_consecutive_losses",
|
||||
message=(
|
||||
f"Consecutive losses {portfolio.consecutive_losses} "
|
||||
f"reached limit {limit}"
|
||||
),
|
||||
current_value=Decimal(str(portfolio.consecutive_losses)),
|
||||
limit_value=Decimal(str(limit)),
|
||||
)
|
||||
)
|
||||
|
||||
def _run_custom_rules(
|
||||
self,
|
||||
order: OrderRequest,
|
||||
portfolio: PortfolioState,
|
||||
response: RiskCheckResponse,
|
||||
) -> None:
|
||||
"""Run custom validation rules."""
|
||||
for i, rule in enumerate(self._custom_rules):
|
||||
response.checked_rules.append(f"custom_rule_{i}")
|
||||
try:
|
||||
violation = rule(order, portfolio)
|
||||
if violation:
|
||||
response.add_violation(violation)
|
||||
except Exception:
|
||||
# Don't let custom rule errors break validation
|
||||
pass
|
||||
|
||||
def _trigger_cooling_off(self) -> None:
|
||||
"""Trigger cooling off period."""
|
||||
if self._loss_limits.cooling_off_period_minutes > 0:
|
||||
self._in_cooling_off = True
|
||||
from datetime import timedelta
|
||||
|
||||
self._cooling_off_until = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=self._loss_limits.cooling_off_period_minutes
|
||||
)
|
||||
|
||||
def update_daily_pnl(self, pnl: Decimal, trade_date: date) -> None:
|
||||
"""Update daily P&L tracking.
|
||||
|
||||
Args:
|
||||
pnl: P&L for the trade
|
||||
trade_date: Date of the trade
|
||||
"""
|
||||
if trade_date not in self._daily_pnl_by_date:
|
||||
self._daily_pnl_by_date[trade_date] = Decimal("0")
|
||||
self._daily_pnl_by_date[trade_date] += pnl
|
||||
|
||||
def get_daily_pnl(self, trade_date: date) -> Decimal:
|
||||
"""Get daily P&L for a date.
|
||||
|
||||
Args:
|
||||
trade_date: Date to get P&L for
|
||||
|
||||
Returns:
|
||||
P&L for the date
|
||||
"""
|
||||
return self._daily_pnl_by_date.get(trade_date, Decimal("0"))
|
||||
|
||||
def update_peak_equity(self, equity: Decimal) -> None:
|
||||
"""Update peak equity tracking.
|
||||
|
||||
Args:
|
||||
equity: Current equity
|
||||
"""
|
||||
if self._peak_equity is None or equity > self._peak_equity:
|
||||
self._peak_equity = equity
|
||||
|
||||
def reset_daily_limits(self) -> None:
|
||||
"""Reset daily tracking (call at start of each trading day)."""
|
||||
self._in_cooling_off = False
|
||||
self._cooling_off_until = None
|
||||
|
||||
def reset_all(self) -> None:
|
||||
"""Reset all tracking state."""
|
||||
self._daily_pnl_by_date.clear()
|
||||
self._peak_equity = None
|
||||
self._in_cooling_off = False
|
||||
self._cooling_off_until = None
|
||||
|
||||
|
||||
# Export
|
||||
__all__ = [
|
||||
"RiskManager",
|
||||
"RiskCheckResult",
|
||||
"RiskRuleType",
|
||||
"RiskViolation",
|
||||
"RiskCheckResponse",
|
||||
"PositionLimits",
|
||||
"LossLimits",
|
||||
"PortfolioState",
|
||||
]
|
||||
Loading…
Reference in New Issue