TradingAgents/.claude/lib/auto_approval_engine.py

500 lines
15 KiB
Python

#!/usr/bin/env python3
"""
Auto-Approve Tool Hook - PreToolUse Hook for MCP Auto-Approval
This module implements the PreToolUse lifecycle hook that auto-approves
MCP tool calls from trusted subagents. It provides:
1. Subagent context detection (CLAUDE_AGENT_NAME env var)
2. Agent whitelist checking (trusted vs restricted agents)
3. User consent verification (opt-in design)
4. Tool call validation (whitelist/blacklist)
5. Circuit breaker logic (auto-disable after 10 consecutive denials)
6. Comprehensive audit logging (every approval/denial)
7. Graceful degradation (errors default to manual approval)
Security Architecture:
- Defense-in-depth: 6 layers of validation
1. Subagent context check (only auto-approve in subagent)
2. User consent check (must opt-in)
3. Agent whitelist check (only trusted agents)
4. Tool call validation (whitelist/blacklist)
5. Circuit breaker (auto-disable after repeated denials)
6. Audit logging (full trail of decisions)
- Conservative defaults: Deny unknown commands/paths
- Graceful degradation: Errors result in manual approval (safe failure)
- Zero trust: Every tool call is validated independently
Usage (Claude Code 2.0+ lifecycle hook):
# In plugin manifest (pyproject.toml or plugins.json):
[hooks]
PreToolUse = "autonomous_dev.hooks.unified_pre_tool_use:on_pre_tool_use"
# Claude Code will call on_pre_tool_use() before each MCP tool execution
# Returns: {"approved": true/false, "reason": "explanation"}
Date: 2025-11-15
Issue: #73 (MCP Auto-Approval for Subagent Tool Calls)
Agent: implementer
Phase: TDD Green (making tests pass)
See error-handling-patterns skill for exception hierarchy and error handling best practices.
"""
import os
import sys
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Any, Optional
# Add lib directory to path for imports
lib_dir = Path(__file__).parent.parent / "lib"
sys.path.insert(0, str(lib_dir))
# Import dependencies
from tool_validator import ToolValidator, load_policy
from tool_approval_audit import ToolApprovalAuditor
from auto_approval_consent import check_user_consent, get_auto_approval_mode
from user_state_manager import DEFAULT_STATE_FILE
# Import path_utils for policy file resolution
try:
    from path_utils import get_policy_file
except ImportError:
    # path_utils could not be imported: provide a local stand-in with the
    # same call signature so the rest of the module keeps working.
    def get_policy_file(use_cache: bool = True):
        """Return the policy path next to this package's ``config`` directory.

        Args:
            use_cache: Accepted so the signature matches the imported
                helper; this fallback does not use it.

        Returns:
            Path to ``config/auto_approve_policy.json`` two levels above
            this file.
        """
        config_dir = Path(__file__).parent.parent / "config"
        return config_dir / "auto_approve_policy.json"


# Policy file location, resolved once when the module is imported
DEFAULT_POLICY_FILE = get_policy_file()

# Number of consecutive denials after which the circuit breaker trips
CIRCUIT_BREAKER_THRESHOLD = 10

# Audit log path: "logs/" under the fourth ancestor of this file's path
DEFAULT_AUDIT_LOG = (
    Path(__file__).parent.parent.parent.parent
    / "logs"
    / "tool_auto_approve_audit.log"
)
@dataclass
class AutoApprovalState:
    """Lock-protected counters backing the auto-approval circuit breaker.

    Tracks:
        - denial_count: Number of consecutive denials (for circuit breaker)
        - circuit_breaker_tripped: Whether circuit breaker has tripped

    Every method acquires ``_lock`` before reading or writing the fields, so
    individual operations are safe to call from multiple threads.

    The lock is an implementation detail rather than part of the state's
    value, so it is excluded from the generated ``__repr__`` and ``__eq__``:
    two states with the same counters compare equal.
    """

    denial_count: int = 0
    circuit_breaker_tripped: bool = False
    # Still accepted by __init__ (unchanged constructor signature), but kept
    # out of repr/eq because a Lock has no meaningful value semantics.
    _lock: threading.Lock = field(
        default_factory=threading.Lock, repr=False, compare=False
    )

    def increment_denial_count(self) -> int:
        """Increment denial count (thread-safe).

        Returns:
            New denial count
        """
        with self._lock:
            self.denial_count += 1
            return self.denial_count

    def reset_denial_count(self) -> None:
        """Reset denial count to zero (thread-safe)."""
        with self._lock:
            self.denial_count = 0

    def trip_circuit_breaker(self) -> None:
        """Trip circuit breaker (thread-safe)."""
        with self._lock:
            self.circuit_breaker_tripped = True

    def reset_circuit_breaker(self) -> None:
        """Clear the tripped flag and zero the denial count (thread-safe)."""
        with self._lock:
            self.circuit_breaker_tripped = False
            self.denial_count = 0

    def is_circuit_breaker_tripped(self) -> bool:
        """Check if circuit breaker is tripped (thread-safe).

        Returns:
            True if tripped, False otherwise
        """
        with self._lock:
            return self.circuit_breaker_tripped

    def get_denial_count(self) -> int:
        """Get current denial count (thread-safe).

        Returns:
            Current denial count
        """
        with self._lock:
            return self.denial_count

    def items(self):
        """Return state as items for dict-like interface.

        Both fields are read under a single lock acquisition, so the pair is
        a consistent snapshot.

        Returns:
            List of (key, value) tuples
        """
        with self._lock:
            return [
                ("denial_count", self.denial_count),
                ("circuit_breaker_tripped", self.circuit_breaker_tripped),
            ]
# Module-wide state singleton; created lazily by _get_global_state()
_global_state: Optional[AutoApprovalState] = None
_global_state_lock = threading.Lock()


def _get_global_state() -> AutoApprovalState:
    """Return the shared AutoApprovalState, creating it on first use.

    The check-and-create step runs while holding ``_global_state_lock``.

    Returns:
        The module-level AutoApprovalState instance
    """
    global _global_state
    with _global_state_lock:
        instance = _global_state
        if instance is None:
            instance = AutoApprovalState()
            _global_state = instance
    return instance
# Lazily-filled caches for the policy dict and the validator, guarded by one lock
_cached_policy: Optional[Dict[str, Any]] = None
_cached_validator: Optional[ToolValidator] = None
_cache_lock = threading.Lock()


def load_and_cache_policy(policy_file: Optional[Path] = None) -> Dict[str, Any]:
    """Return the policy dictionary, calling ``load_policy`` only when uncached.

    The lookup and the load both happen while holding ``_cache_lock``.

    Args:
        policy_file: Path to load from when nothing is cached yet. When
            omitted (or falsy), ``get_policy_file()`` supplies the path.
            Once a policy is cached this argument is not consulted; the
            cached policy is returned as-is.

    Returns:
        Policy dictionary
    """
    global _cached_policy
    with _cache_lock:
        if _cached_policy is not None:
            return _cached_policy
        source = policy_file or get_policy_file()
        _cached_policy = load_policy(source)
        return _cached_policy
def _get_cached_validator() -> ToolValidator:
    """Return the shared ToolValidator, constructing it on first use.

    The validator is built with the path returned by ``get_policy_file()``
    and stored in ``_cached_validator`` while holding ``_cache_lock``.

    Returns:
        Cached ToolValidator instance
    """
    global _cached_validator
    with _cache_lock:
        validator = _cached_validator
        if validator is None:
            validator = ToolValidator(policy_file=get_policy_file())
            _cached_validator = validator
    return validator
# Subagent context detection
def is_subagent_context() -> bool:
    """Report whether the CLAUDE_AGENT_NAME environment variable is set.

    A value that is missing, empty, or whitespace-only counts as "not in a
    subagent". The variable is expected to be set by Claude Code for
    subagent tasks (not verified by this function).

    Returns:
        True if in subagent context, False otherwise
    """
    raw_name = os.environ.get("CLAUDE_AGENT_NAME", "")
    return raw_name.strip() != ""
def get_agent_name() -> Optional[str]:
    """Return the sanitized agent name from CLAUDE_AGENT_NAME, if any.

    The value ends up in audit log lines, so characters that could forge or
    corrupt a log entry are removed (CWE-117 prevention):

    - control characters (Unicode category Cc: 0x00-0x1f, DEL 0x7f and the
      C1 range 0x80-0x9f)
    - format characters and lone surrogates (categories Cf and Cs)
    - Unicode line and paragraph separators (categories Zl and Zp)

    Surrounding whitespace is stripped both before and after sanitization,
    so a value that contains nothing printable yields None.

    Returns:
        Sanitized agent name if set and non-empty after cleaning, None
        otherwise
    """
    # Function-scope import keeps this block self-contained.
    import unicodedata

    raw_name = os.getenv("CLAUDE_AGENT_NAME", "").strip()
    if not raw_name:
        return None

    unsafe_categories = frozenset({"Cc", "Cf", "Cs", "Zl", "Zp"})
    sanitized = "".join(
        char
        for char in raw_name
        if unicodedata.category(char) not in unsafe_categories
    ).strip()  # removing characters can expose new leading/trailing spaces
    return sanitized or None
# Agent whitelist checking
def is_trusted_agent(agent_name: Optional[str]) -> bool:
    """Report whether ``agent_name`` appears in the policy's trusted list.

    The trusted list is read from ``policy["agents"]["trusted"]`` (missing
    keys are treated as an empty list) and compared case-insensitively.

    Args:
        agent_name: Agent name to check; None or empty is never trusted

    Returns:
        True if trusted, False otherwise
    """
    if not agent_name:
        return False

    agents_section = load_and_cache_policy().get("agents", {})
    normalized_trusted = {
        entry.lower() for entry in agents_section.get("trusted", [])
    }
    return agent_name.lower() in normalized_trusted
# User consent checking
def check_user_consent_cached(state_file: Path = DEFAULT_STATE_FILE) -> bool:
    """Forward to ``auto_approval_consent.check_user_consent``.

    This thin wrapper exists so tests have a module-local name to target.
    No caching is performed in this function itself; it passes
    ``state_file`` straight through and returns the result.

    Args:
        state_file: Path to user state file

    Returns:
        True if user consented, False otherwise
    """
    consented = check_user_consent(state_file)
    return consented
# Circuit breaker logic
def increment_denial_count(state: Optional[AutoApprovalState] = None) -> int:
    """Bump the denial counter on ``state`` (or on the global state).

    Args:
        state: AutoApprovalState instance (default: global state)

    Returns:
        New denial count
    """
    target = _get_global_state() if state is None else state
    return target.increment_denial_count()
def should_trip_circuit_breaker(state: Optional[AutoApprovalState] = None) -> bool:
    """Report whether the denial count has reached the trip threshold.

    Compares the state's current denial count against
    CIRCUIT_BREAKER_THRESHOLD; it does not modify the state.

    Args:
        state: AutoApprovalState instance (default: global state)

    Returns:
        True if should trip, False otherwise
    """
    target = _get_global_state() if state is None else state
    current_denials = target.get_denial_count()
    return current_denials >= CIRCUIT_BREAKER_THRESHOLD
def reset_circuit_breaker(state: Optional[AutoApprovalState] = None) -> None:
    """Reset the circuit breaker on ``state`` (or on the global state).

    Delegates to ``AutoApprovalState.reset_circuit_breaker``.

    Args:
        state: AutoApprovalState instance (default: global state)
    """
    target = _get_global_state() if state is None else state
    target.reset_circuit_breaker()
# Main auto-approval logic
def should_auto_approve(
    tool: str,
    parameters: Dict[str, Any],
    agent_name: Optional[str] = None,
) -> tuple[bool, str]:
    """Determine if tool call should be auto-approved.

    Decision logic:
        1. Deny if the circuit breaker is already tripped
        2. Read the auto-approval mode; deny if it is "disabled"
        3. In "subagent_only" mode, deny outside a subagent context and
           deny agents that are not in the trusted whitelist
        4. Validate the tool call with the cached ToolValidator
        5. On approval reset the denial count; on denial increment it and
           trip the circuit breaker once CIRCUIT_BREAKER_THRESHOLD is reached

    Only validator denials (step 4) count towards the circuit breaker; the
    early denials in steps 1-3 leave the counter untouched.

    Args:
        tool: Tool name (Bash, Read, Write, etc.)
        parameters: Tool parameters
        agent_name: Agent name (from CLAUDE_AGENT_NAME env var)

    Returns:
        Tuple of (approved: bool, reason: str)
    """
    state = _get_global_state()

    # 1. A tripped breaker short-circuits every other check
    if state.is_circuit_breaker_tripped():
        return False, "Circuit breaker tripped (too many denials)"

    # 2. Mode gate
    mode = get_auto_approval_mode()
    if mode == "disabled":
        return False, "Auto-approval disabled (MCP_AUTO_APPROVE not enabled)"

    # 3. Context and whitelist requirements apply only in subagent_only mode.
    # NOTE(review): any mode other than "disabled"/"subagent_only" skips these
    # checks and goes straight to validation - confirm get_auto_approval_mode()
    # can only return the documented values.
    if mode == "subagent_only":
        if not is_subagent_context():
            return False, "Mode is 'subagent_only' but not in subagent context"
        if not is_trusted_agent(agent_name):
            return False, f"Agent '{agent_name}' is not in trusted whitelist (subagent_only mode)"

    # 4. Validate the tool call itself
    validator = _get_cached_validator()
    result = validator.validate_tool_call(tool, parameters, agent_name)

    # 5. Update circuit breaker state
    if result.approved:
        # An approval ends the current run of consecutive denials
        state.reset_denial_count()
        return result.approved, result.reason

    denial_count = increment_denial_count(state)
    # Decide on the count returned by the increment itself rather than
    # re-reading the shared counter: another thread may change it in between,
    # which would make the trip decision and the logged count disagree.
    if denial_count >= CIRCUIT_BREAKER_THRESHOLD:
        state.trip_circuit_breaker()
        trip_reason = f"Circuit breaker tripped after {denial_count} denials"
        ToolApprovalAuditor().log_circuit_breaker_trip(
            agent_name=agent_name or "unknown",
            denial_count=denial_count,
            reason=trip_reason,
        )
        return False, trip_reason

    return result.approved, result.reason
# PreToolUse hook entry point
def on_pre_tool_use(tool: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
    """PreToolUse lifecycle hook for MCP auto-approval.

    Resolves the agent name, asks ``should_auto_approve`` for a decision,
    records that decision through ``ToolApprovalAuditor`` and returns it.

    Args:
        tool: Tool name (Bash, Read, Write, Edit, Grep, etc.)
        parameters: Tool parameters dictionary

    Returns:
        Dictionary with:
            - approved: bool (True = auto-approve, False = manual approval)
            - reason: str (human-readable explanation)

    Error Handling:
        - Graceful degradation: any exception raised while deciding or while
          logging the decision results in ``approved=False``.
        - The error is audit-logged on a best-effort basis. If the audit
          logger itself fails, that failure is appended to the returned
          reason instead of being raised, so the hook never propagates an
          exception to its caller.
    """
    # Bound before the try so the error path can use whatever was resolved.
    agent_name: Optional[str] = None
    try:
        agent_name = get_agent_name()

        approved, reason = should_auto_approve(tool, parameters, agent_name)

        auditor = ToolApprovalAuditor()
        if approved:
            auditor.log_approval(
                agent_name=agent_name or "unknown",
                tool=tool,
                parameters=parameters,
                reason=reason,
            )
        else:
            reason_lower = reason.lower()
            auditor.log_denial(
                agent_name=agent_name or "unknown",
                tool=tool,
                parameters=parameters,
                reason=reason,
                # Flag denials whose reason mentions a blacklist hit or injection
                security_risk="blacklist" in reason_lower or "injection" in reason_lower,
            )

        return {
            "approved": approved,
            "reason": reason,
        }
    except Exception as e:
        # Graceful degradation - deny on error
        failure_reason = f"Auto-approval error (defaulting to manual): {e}"
        try:
            # Best effort only: the auditor may be the very thing that failed.
            ToolApprovalAuditor().log_denial(
                agent_name=agent_name or "unknown",
                tool=tool,
                parameters=parameters,
                reason=f"Error in auto-approval logic: {e}",
                security_risk=False,
            )
        except Exception as audit_error:
            # Raising here would break the safe-failure contract; surface the
            # audit problem to the caller through the reason text instead.
            failure_reason += f" (audit logging also failed: {audit_error})"
        return {
            "approved": False,
            "reason": failure_reason,
        }
# Exported convenience function for testing
def prompt_user_for_consent(state_file: Path = DEFAULT_STATE_FILE) -> bool:
    """Forward to ``auto_approval_consent.prompt_user_for_consent`` (for testing).

    The target is imported when this function is called rather than at
    module import time.

    Args:
        state_file: Path to user state file

    Returns:
        True if user consented, False otherwise
    """
    from auto_approval_consent import prompt_user_for_consent as _consent_prompt

    consented = _consent_prompt(state_file)
    return consented