TradingAgents/tradingagents/execution/risk_controls.py

736 lines
25 KiB
Python

"""Risk Controls for order execution.
Issue #28: [EXEC-27] Risk controls - position limits, loss limits
This module provides pre-trade risk validation including:
- Position size limits (max shares, max notional value)
- Concentration limits (max % of portfolio in single position)
- Daily loss limits
- Drawdown limits
- Pre-trade validation framework
Example:
>>> from decimal import Decimal
>>> from tradingagents.execution import RiskManager, PositionLimits, LossLimits
>>>
>>> limits = PositionLimits(
... max_position_size=Decimal("10000"),
... max_position_value=Decimal("50000"),
... max_concentration_percent=Decimal("20"),
... )
>>> risk_manager = RiskManager(position_limits=limits)
>>> result = risk_manager.validate_order(order_request, portfolio)
"""
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

from .broker_base import (
    OrderRequest,
    OrderSide,
    Position,
)
class RiskCheckResult(Enum):
    """Result of a risk check.

    NOTE(review): in this file the enum is only exported via ``__all__``;
    RiskManager reports outcomes through ``RiskCheckResponse.passed`` and the
    violation ``severity`` strings instead. Confirm whether callers use it.
    """
    PASSED = "passed"
    FAILED = "failed"
    WARNING = "warning"
    SKIPPED = "skipped"
class RiskRuleType(Enum):
    """Type of risk rule, attached to every ``RiskViolation``."""
    # Used by RiskManager for both max_position_size and max_total_positions.
    POSITION_SIZE = "position_size"
    # Used by RiskManager for max_position_value.
    POSITION_VALUE = "position_value"
    # Used by RiskManager for max_concentration_percent.
    CONCENTRATION = "concentration"
    # Used by RiskManager for daily loss limits, and also for the cooling-off,
    # single-trade-loss and consecutive-loss rules.
    DAILY_LOSS = "daily_loss"
    # Used by RiskManager for absolute and percentage drawdown limits.
    DRAWDOWN = "drawdown"
    # Presumably intended for violations produced by user-supplied custom
    # rules — TODO confirm with rule authors.
    CUSTOM = "custom"
@dataclass
class RiskViolation:
    """One risk rule that fired for an order.

    A violation is blocking unless its ``severity`` is exactly ``"warning"``;
    ``RiskCheckResponse.add_violation`` uses that string to decide where the
    violation is recorded.

    Attributes:
        rule_type: Category of the rule that fired.
        rule_name: Identifier of the specific rule (e.g. ``"max_position_size"``).
        message: Human-readable explanation of the breach.
        current_value: The measured value that tripped the rule.
        limit_value: The configured threshold it was compared against.
        severity: ``"error"`` (default, blocking) or ``"warning"``.
        metadata: Free-form extra context; a fresh dict per instance.
    """

    rule_type: RiskRuleType
    rule_name: str
    message: str
    current_value: Decimal
    limit_value: Decimal
    severity: str = field(default="error")
    # default_factory keeps instances from sharing one mutable dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class RiskCheckResponse:
    """Aggregated outcome of a single risk-validation pass.

    Attributes:
        passed: Becomes ``False`` as soon as any blocking violation is added.
        violations: Blocking violations (any severity other than ``"warning"``).
        warnings: Non-blocking violations (severity exactly ``"warning"``).
        checked_rules: Names of the rules that were evaluated, in order.
        timestamp: UTC time at which this response object was created.
    """

    passed: bool = field(default=True)
    violations: List[RiskViolation] = field(default_factory=list)
    warnings: List[RiskViolation] = field(default_factory=list)
    checked_rules: List[str] = field(default_factory=list)
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def add_violation(self, violation: RiskViolation) -> None:
        """Record *violation*, routing it by its ``severity``.

        Warnings are stored separately and never affect ``passed``; every other
        severity is treated as blocking and marks the response as failed.
        """
        if violation.severity != "warning":
            self.violations.append(violation)
            self.passed = False
            return
        self.warnings.append(violation)

    @property
    def rejection_message(self) -> Optional[str]:
        """Blocking messages joined with ``"; "``, or ``None`` when passed."""
        return None if self.passed else "; ".join(v.message for v in self.violations)
@dataclass
class PositionLimits:
    """Caps on the size and spread of individual positions.

    RiskManager skips a check whose limit is ``None``, so every limit here
    defaults to "not enforced".

    Attributes:
        max_position_size: Largest quantity (shares/units) allowed in one symbol.
        max_position_value: Largest notional value allowed in one symbol.
        max_concentration_percent: Largest share of portfolio equity, in
            percent, that one symbol may represent.
        max_total_positions: Largest number of distinct open positions.
        max_sector_concentration: Largest share, in percent, for one sector
            (not currently checked by RiskManager in this file).
        per_symbol_limits: Overrides keyed by symbol, then by limit attribute
            name, e.g. ``{"AAPL": {"max_position_size": Decimal("500")}}``.
    """

    max_position_size: Optional[Decimal] = field(default=None)
    max_position_value: Optional[Decimal] = field(default=None)
    max_concentration_percent: Optional[Decimal] = field(default=None)
    max_total_positions: Optional[int] = field(default=None)
    max_sector_concentration: Optional[Decimal] = field(default=None)
    per_symbol_limits: Dict[str, Dict[str, Decimal]] = field(default_factory=dict)

    def get_limit_for_symbol(
        self, symbol: str, limit_type: str
    ) -> Optional[Decimal]:
        """Resolve the effective limit called *limit_type* for *symbol*.

        A per-symbol override takes precedence. Otherwise the attribute of the
        same name on this object is returned, or ``None`` if there is none.
        """
        overrides = self.per_symbol_limits.get(symbol, {})
        if limit_type in overrides:
            return overrides[limit_type]
        return getattr(self, limit_type, None)
@dataclass
class LossLimits:
    """Loss-side risk limits; ``None`` means the limit is not enforced.

    Attributes:
        max_daily_loss: Largest absolute loss tolerated for the day.
        max_daily_loss_percent: Largest daily loss as a percent of equity.
        max_drawdown: Largest absolute decline from peak equity.
        max_drawdown_percent: Largest decline as a percent of peak equity.
        max_single_trade_loss: Cap applied by RiskManager to a buy order's
            notional value (quantity x estimated price); breaches are reported
            as warnings, not rejections.
        max_consecutive_losses: Losing-streak length at which orders are
            rejected.
        cooling_off_period_minutes: Length of the rejection window started
            when a daily-loss or consecutive-loss limit is hit; ``0`` disables
            cooling off.
    """

    max_daily_loss: Optional[Decimal] = field(default=None)
    max_daily_loss_percent: Optional[Decimal] = field(default=None)
    max_drawdown: Optional[Decimal] = field(default=None)
    max_drawdown_percent: Optional[Decimal] = field(default=None)
    max_single_trade_loss: Optional[Decimal] = field(default=None)
    max_consecutive_losses: Optional[int] = field(default=None)
    cooling_off_period_minutes: int = field(default=0)
@dataclass
class PortfolioState:
    """Snapshot of the portfolio used as input to risk calculations.

    Attributes:
        positions: Current positions keyed by symbol.
        cash: Available cash.
        equity: Total portfolio equity.
        buying_power: Available buying power.
        daily_pnl: Profit/loss for the current day (negative for a loss).
        peak_equity: Highest equity seen, used for drawdown. ``None`` (or a
            non-positive value) makes drawdown report zero.
        consecutive_losses: Current number of consecutive losing trades.
        last_loss_time: When the last loss occurred.
    """

    positions: Dict[str, Position] = field(default_factory=dict)
    cash: Decimal = Decimal("0")
    equity: Decimal = Decimal("0")
    buying_power: Decimal = Decimal("0")
    daily_pnl: Decimal = Decimal("0")
    peak_equity: Optional[Decimal] = None
    consecutive_losses: int = 0
    last_loss_time: Optional[datetime] = None

    @property
    def current_drawdown(self) -> Decimal:
        """Return how far ``equity`` sits below ``peak_equity``; never negative.

        Returns ``Decimal("0")`` when no positive peak is known. It also returns
        zero when equity has risen above a stale ``peak_equity``, because a
        drawdown cannot be negative.
        """
        if self.peak_equity is None or self.peak_equity <= 0:
            return Decimal("0")
        # Clamp: equity above a not-yet-updated peak is a new high, not a
        # negative drawdown.
        return max(Decimal("0"), self.peak_equity - self.equity)

    @property
    def current_drawdown_percent(self) -> Decimal:
        """Return ``current_drawdown`` as a percentage of ``peak_equity``."""
        if self.peak_equity is None or self.peak_equity <= 0:
            return Decimal("0")
        return (self.current_drawdown / self.peak_equity) * Decimal("100")
class RiskManager:
    """Validate orders against configured risk limits before submission.

    ``validate_order`` runs a fixed sequence of built-in checks and then any
    user-supplied custom rules. The built-in checks are position size,
    position value, concentration, open-position count, daily loss,
    drawdown, single-trade loss and consecutive losses. The outcome is
    collected in a ``RiskCheckResponse``.

    Breaching a daily-loss or consecutive-loss limit starts a cooling-off
    period, but only when ``LossLimits.cooling_off_period_minutes > 0``.
    While that period is active, every order is rejected without running the
    other checks.

    Example:
        >>> risk_manager = RiskManager(
        ...     position_limits=PositionLimits(max_position_size=Decimal("1000")),
        ...     loss_limits=LossLimits(max_daily_loss=Decimal("500")),
        ... )
        >>> result = risk_manager.validate_order(order_request, portfolio_state)
        >>> if not result.passed:
        ...     print(f"Order rejected: {result.rejection_message}")
    """

    def __init__(
        self,
        position_limits: Optional[PositionLimits] = None,
        loss_limits: Optional[LossLimits] = None,
        custom_rules: Optional[List[Callable]] = None,
        enabled: bool = True,
    ) -> None:
        """Initialize risk manager.

        Args:
            position_limits: Position-related limits (default: none enforced).
            loss_limits: Loss-related limits (default: none enforced).
            custom_rules: Custom validation functions. Each is called as
                ``rule(order, portfolio)`` and returns a ``RiskViolation`` or
                ``None``. The sequence is copied, so ``add_custom_rule`` never
                mutates the caller's list.
            enabled: Whether risk checks are enabled.
        """
        self._position_limits = position_limits or PositionLimits()
        self._loss_limits = loss_limits or LossLimits()
        # Copy so add_custom_rule() does not append to the caller's list.
        self._custom_rules: List[Callable] = list(custom_rules) if custom_rules else []
        self._enabled = enabled
        self._daily_pnl_by_date: Dict[date, Decimal] = {}
        self._peak_equity: Optional[Decimal] = None
        self._in_cooling_off = False
        self._cooling_off_until: Optional[datetime] = None

    @property
    def enabled(self) -> bool:
        """Check if risk manager is enabled."""
        return self._enabled

    @enabled.setter
    def enabled(self, value: bool) -> None:
        """Enable or disable risk manager."""
        self._enabled = value

    @property
    def position_limits(self) -> PositionLimits:
        """Get position limits."""
        return self._position_limits

    @position_limits.setter
    def position_limits(self, limits: PositionLimits) -> None:
        """Set position limits."""
        self._position_limits = limits

    @property
    def loss_limits(self) -> LossLimits:
        """Get loss limits."""
        return self._loss_limits

    @loss_limits.setter
    def loss_limits(self, limits: LossLimits) -> None:
        """Set loss limits."""
        self._loss_limits = limits

    def add_custom_rule(
        self,
        rule: Callable[[OrderRequest, PortfolioState], Optional[RiskViolation]],
    ) -> None:
        """Add a custom validation rule.

        Args:
            rule: Function that takes order and portfolio, returns violation or None.
        """
        self._custom_rules.append(rule)

    def validate_order(
        self,
        order: OrderRequest,
        portfolio: PortfolioState,
        estimated_fill_price: Optional[Decimal] = None,
    ) -> RiskCheckResponse:
        """Validate an order against all risk limits.

        Price-dependent checks (position value, concentration, single-trade
        loss) are skipped when ``estimated_fill_price`` is ``None``. Side
        effect: hitting a daily-loss or consecutive-loss limit may start a
        cooling-off period that rejects subsequent orders.

        Args:
            order: Order request to validate.
            portfolio: Current portfolio state.
            estimated_fill_price: Expected fill price for value calculations.

        Returns:
            RiskCheckResponse with validation results.
        """
        response = RiskCheckResponse()
        if not self._enabled:
            response.checked_rules.append("(risk checks disabled)")
            return response

        # An active cooling-off period short-circuits every other check.
        if self._check_cooling_off(response):
            return response

        self._check_position_size(order, portfolio, response)
        self._check_position_value(order, portfolio, response, estimated_fill_price)
        self._check_concentration(order, portfolio, response, estimated_fill_price)
        self._check_total_positions(order, portfolio, response)
        self._check_daily_loss(portfolio, response)
        self._check_drawdown(portfolio, response)
        self._check_single_trade_loss(order, portfolio, response, estimated_fill_price)
        self._check_consecutive_losses(portfolio, response)
        self._run_custom_rules(order, portfolio, response)
        return response

    def _check_cooling_off(self, response: RiskCheckResponse) -> bool:
        """Reject via *response* if a cooling-off period is still active.

        Expired cooling-off state is cleared as a side effect.

        Returns:
            True if a cooling-off violation was recorded.
        """
        if not (self._in_cooling_off and self._cooling_off_until):
            return False
        if datetime.now(timezone.utc) < self._cooling_off_until:
            response.add_violation(
                RiskViolation(
                    rule_type=RiskRuleType.DAILY_LOSS,
                    rule_name="cooling_off_period",
                    message=f"In cooling off period until {self._cooling_off_until}",
                    current_value=Decimal("0"),
                    limit_value=Decimal("0"),
                )
            )
            return True
        # The window has elapsed: clear it and let normal validation run.
        self._in_cooling_off = False
        self._cooling_off_until = None
        return False

    @staticmethod
    def _projected_quantity(order: OrderRequest, portfolio: PortfolioState) -> Decimal:
        """Return the signed quantity held in ``order.symbol`` after a full fill."""
        current_position = portfolio.positions.get(order.symbol)
        current_qty = current_position.quantity if current_position else Decimal("0")
        if order.side == OrderSide.BUY:
            return current_qty + order.quantity
        return current_qty - order.quantity

    def _check_position_size(
        self,
        order: OrderRequest,
        portfolio: PortfolioState,
        response: RiskCheckResponse,
    ) -> None:
        """Reject if the post-fill position would exceed ``max_position_size``.

        The absolute quantity is compared, so long and short positions are
        treated alike. Per-symbol overrides take precedence.
        """
        response.checked_rules.append("position_size")
        limit = self._position_limits.get_limit_for_symbol(
            order.symbol, "max_position_size"
        )
        if limit is None:
            return

        new_qty = abs(self._projected_quantity(order, portfolio))
        if new_qty > limit:
            response.add_violation(
                RiskViolation(
                    rule_type=RiskRuleType.POSITION_SIZE,
                    rule_name="max_position_size",
                    message=(
                        f"Position size {new_qty} exceeds limit {limit} "
                        f"for {order.symbol}"
                    ),
                    current_value=new_qty,
                    limit_value=limit,
                )
            )

    def _check_position_value(
        self,
        order: OrderRequest,
        portfolio: PortfolioState,
        response: RiskCheckResponse,
        estimated_price: Optional[Decimal],
    ) -> None:
        """Reject if the post-fill notional value would exceed ``max_position_value``.

        Skipped when no estimated price is supplied.
        """
        response.checked_rules.append("position_value")
        limit = self._position_limits.get_limit_for_symbol(
            order.symbol, "max_position_value"
        )
        if limit is None or estimated_price is None:
            return

        new_value = abs(self._projected_quantity(order, portfolio) * estimated_price)
        if new_value > limit:
            response.add_violation(
                RiskViolation(
                    rule_type=RiskRuleType.POSITION_VALUE,
                    rule_name="max_position_value",
                    message=(
                        f"Position value ${new_value} exceeds limit ${limit} "
                        f"for {order.symbol}"
                    ),
                    current_value=new_value,
                    limit_value=limit,
                )
            )

    def _check_concentration(
        self,
        order: OrderRequest,
        portfolio: PortfolioState,
        response: RiskCheckResponse,
        estimated_price: Optional[Decimal],
    ) -> None:
        """Reject if the post-fill position would exceed the concentration limit.

        Concentration is the position's notional value as a percentage of
        ``portfolio.equity``. Like the size and value checks, this honours
        per-symbol overrides. Skipped without a price or with non-positive
        equity.
        """
        response.checked_rules.append("concentration")
        limit = self._position_limits.get_limit_for_symbol(
            order.symbol, "max_concentration_percent"
        )
        if limit is None:
            return
        if estimated_price is None or portfolio.equity <= 0:
            return

        new_value = abs(self._projected_quantity(order, portfolio) * estimated_price)
        concentration = (new_value / portfolio.equity) * Decimal("100")
        if concentration > limit:
            response.add_violation(
                RiskViolation(
                    rule_type=RiskRuleType.CONCENTRATION,
                    rule_name="max_concentration",
                    message=(
                        f"Concentration {concentration:.1f}% exceeds limit {limit}% "
                        f"for {order.symbol}"
                    ),
                    current_value=concentration,
                    limit_value=limit,
                )
            )

    def _check_total_positions(
        self,
        order: OrderRequest,
        portfolio: PortfolioState,
        response: RiskCheckResponse,
    ) -> None:
        """Reject a buy that would open a new position beyond ``max_total_positions``.

        Sells and buys that add to an existing position are never rejected
        here.
        """
        response.checked_rules.append("total_positions")
        limit = self._position_limits.max_total_positions
        if limit is None:
            return
        if order.side != OrderSide.BUY or order.symbol in portfolio.positions:
            return

        open_positions = len(portfolio.positions)
        if open_positions >= limit:
            response.add_violation(
                RiskViolation(
                    rule_type=RiskRuleType.POSITION_SIZE,
                    rule_name="max_total_positions",
                    message=(
                        f"Total positions {open_positions} at limit {limit}"
                    ),
                    current_value=Decimal(str(open_positions)),
                    limit_value=Decimal(str(limit)),
                )
            )

    def _check_daily_loss(
        self,
        portfolio: PortfolioState,
        response: RiskCheckResponse,
    ) -> None:
        """Reject if today's loss breaches the absolute or percentage limit.

        Reads ``portfolio.daily_pnl``. Values recorded through
        ``update_daily_pnl`` are not consulted here. A breach also triggers
        cooling off.
        """
        response.checked_rules.append("daily_loss")
        max_loss = self._loss_limits.max_daily_loss
        max_loss_pct = self._loss_limits.max_daily_loss_percent

        if max_loss is not None and portfolio.daily_pnl < -max_loss:
            self._trigger_cooling_off()
            response.add_violation(
                RiskViolation(
                    rule_type=RiskRuleType.DAILY_LOSS,
                    rule_name="max_daily_loss",
                    message=(
                        f"Daily loss ${abs(portfolio.daily_pnl)} exceeds limit "
                        f"${max_loss}"
                    ),
                    current_value=abs(portfolio.daily_pnl),
                    limit_value=max_loss,
                )
            )

        if max_loss_pct is not None and portfolio.equity > 0:
            # Percentage is measured against current equity.
            daily_loss_pct = (abs(portfolio.daily_pnl) / portfolio.equity) * Decimal("100")
            if portfolio.daily_pnl < 0 and daily_loss_pct > max_loss_pct:
                self._trigger_cooling_off()
                response.add_violation(
                    RiskViolation(
                        rule_type=RiskRuleType.DAILY_LOSS,
                        rule_name="max_daily_loss_percent",
                        message=(
                            f"Daily loss {daily_loss_pct:.1f}% exceeds limit "
                            f"{max_loss_pct}%"
                        ),
                        current_value=daily_loss_pct,
                        limit_value=max_loss_pct,
                    )
                )

    def _check_drawdown(
        self,
        portfolio: PortfolioState,
        response: RiskCheckResponse,
    ) -> None:
        """Reject if the decline from peak equity breaches a drawdown limit.

        The peak is ``portfolio.peak_equity`` when provided. Otherwise it falls
        back to the peak recorded via ``update_peak_equity``, so that tracking
        is not silently ignored. Equity above the peak counts as zero
        drawdown.
        """
        response.checked_rules.append("drawdown")
        max_dd = self._loss_limits.max_drawdown
        max_dd_pct = self._loss_limits.max_drawdown_percent
        if max_dd is None and max_dd_pct is None:
            return

        peak = (
            portfolio.peak_equity
            if portfolio.peak_equity is not None
            else self._peak_equity
        )
        if peak is None or peak <= 0:
            drawdown = Decimal("0")
            drawdown_pct = Decimal("0")
        else:
            drawdown = max(Decimal("0"), peak - portfolio.equity)
            drawdown_pct = (drawdown / peak) * Decimal("100")

        if max_dd is not None and drawdown > max_dd:
            response.add_violation(
                RiskViolation(
                    rule_type=RiskRuleType.DRAWDOWN,
                    rule_name="max_drawdown",
                    message=(
                        f"Drawdown ${drawdown} exceeds limit "
                        f"${max_dd}"
                    ),
                    current_value=drawdown,
                    limit_value=max_dd,
                )
            )

        if max_dd_pct is not None and drawdown_pct > max_dd_pct:
            response.add_violation(
                RiskViolation(
                    rule_type=RiskRuleType.DRAWDOWN,
                    rule_name="max_drawdown_percent",
                    message=(
                        f"Drawdown {drawdown_pct:.1f}% "
                        f"exceeds limit {max_dd_pct}%"
                    ),
                    current_value=drawdown_pct,
                    limit_value=max_dd_pct,
                )
            )

    def _check_single_trade_loss(
        self,
        order: OrderRequest,
        portfolio: PortfolioState,
        response: RiskCheckResponse,
        estimated_price: Optional[Decimal],
    ) -> None:
        """Warn when a buy's full notional value exceeds ``max_single_trade_loss``.

        The worst case for a buy is assumed to be losing the entire order
        value. Sells are not checked. The result is a non-blocking warning.
        """
        response.checked_rules.append("single_trade_loss")
        limit = self._loss_limits.max_single_trade_loss
        if limit is None or estimated_price is None:
            return
        if order.side != OrderSide.BUY:
            return

        potential_loss = order.quantity * estimated_price
        if potential_loss > limit:
            response.add_violation(
                RiskViolation(
                    rule_type=RiskRuleType.DAILY_LOSS,
                    rule_name="max_single_trade_loss",
                    message=(
                        f"Potential loss ${potential_loss} exceeds limit "
                        f"${limit}"
                    ),
                    current_value=potential_loss,
                    limit_value=limit,
                    severity="warning",  # Warning, not blocking
                )
            )

    def _check_consecutive_losses(
        self,
        portfolio: PortfolioState,
        response: RiskCheckResponse,
    ) -> None:
        """Reject once the losing streak reaches the limit; also triggers cooling off."""
        response.checked_rules.append("consecutive_losses")
        limit = self._loss_limits.max_consecutive_losses
        if limit is None:
            return

        if portfolio.consecutive_losses >= limit:
            self._trigger_cooling_off()
            response.add_violation(
                RiskViolation(
                    rule_type=RiskRuleType.DAILY_LOSS,
                    rule_name="max_consecutive_losses",
                    message=(
                        f"Consecutive losses {portfolio.consecutive_losses} "
                        f"reached limit {limit}"
                    ),
                    current_value=Decimal(str(portfolio.consecutive_losses)),
                    limit_value=Decimal(str(limit)),
                )
            )

    def _run_custom_rules(
        self,
        order: OrderRequest,
        portfolio: PortfolioState,
        response: RiskCheckResponse,
    ) -> None:
        """Run custom validation rules, recording any violation they return.

        A rule that raises does not abort validation. Its failure is recorded
        as a non-blocking ``CUSTOM`` warning instead of being silently
        discarded.
        """
        for i, rule in enumerate(self._custom_rules):
            rule_name = f"custom_rule_{i}"
            response.checked_rules.append(rule_name)
            try:
                violation = rule(order, portfolio)
                if violation:
                    response.add_violation(violation)
            except Exception as exc:  # best-effort by design; see docstring
                response.add_violation(
                    RiskViolation(
                        rule_type=RiskRuleType.CUSTOM,
                        rule_name=rule_name,
                        message=(
                            f"Custom rule {rule_name} raised "
                            f"{type(exc).__name__}: {exc}"
                        ),
                        current_value=Decimal("0"),
                        limit_value=Decimal("0"),
                        severity="warning",
                        metadata={"exception": repr(exc)},
                    )
                )

    def _trigger_cooling_off(self) -> None:
        """Start a cooling-off window if one is configured (minutes > 0)."""
        minutes = self._loss_limits.cooling_off_period_minutes
        if minutes > 0:
            self._in_cooling_off = True
            self._cooling_off_until = datetime.now(timezone.utc) + timedelta(
                minutes=minutes
            )

    def update_daily_pnl(self, pnl: Decimal, trade_date: date) -> None:
        """Accumulate P&L for *trade_date*.

        Args:
            pnl: P&L for the trade.
            trade_date: Date of the trade.
        """
        self._daily_pnl_by_date[trade_date] = (
            self._daily_pnl_by_date.get(trade_date, Decimal("0")) + pnl
        )

    def get_daily_pnl(self, trade_date: date) -> Decimal:
        """Get daily P&L for a date.

        Args:
            trade_date: Date to get P&L for.

        Returns:
            Accumulated P&L for the date, or zero if none was recorded.
        """
        return self._daily_pnl_by_date.get(trade_date, Decimal("0"))

    def update_peak_equity(self, equity: Decimal) -> None:
        """Raise the tracked peak equity if *equity* is a new high.

        The tracked peak is used by the drawdown check when the supplied
        ``PortfolioState`` has no ``peak_equity``.

        Args:
            equity: Current equity.
        """
        if self._peak_equity is None or equity > self._peak_equity:
            self._peak_equity = equity

    def reset_daily_limits(self) -> None:
        """Clear any cooling-off state (call at start of each trading day)."""
        self._in_cooling_off = False
        self._cooling_off_until = None

    def reset_all(self) -> None:
        """Reset all tracking state."""
        self._daily_pnl_by_date.clear()
        self._peak_equity = None
        self._in_cooling_off = False
        self._cooling_off_until = None
# Public API of this module: the names exposed by ``from ... import *``.
__all__ = [
    "RiskManager",
    "RiskCheckResult",
    "RiskRuleType",
    "RiskViolation",
    "RiskCheckResponse",
    "PositionLimits",
    "LossLimits",
    "PortfolioState",
]