TradingAgents/.claude/hooks/auto_tdd_enforcer.py

326 lines
10 KiB
Python
Executable File

#!/usr/bin/env python3
"""
TDD Enforcer - Ensures tests are written BEFORE implementation.
Blocks implementation if:
1. No test file exists for the feature
2. Test file exists but all tests passing (tests should fail first in TDD!)
Allows implementation if:
1. Tests exist and are failing (proper TDD workflow)
2. User explicitly requests to skip TDD
Auto-invokes tester subagent to write failing tests first.
Hook Integration:
- Event: PreToolUse (before Write/Edit on src/ files)
- Trigger: Writing to src/**/*.py
- Action: Check if tests exist and are failing
"""
import subprocess
import sys
from pathlib import Path
from typing import Optional, Tuple
# ============================================================================
# Configuration
# ============================================================================
PROJECT_ROOT = Path(__file__).parent.parent.parent
SRC_DIR = PROJECT_ROOT / "src" / "[project_name]"
TESTS_DIR = PROJECT_ROOT / "tests"
UNIT_TESTS_DIR = TESTS_DIR / "unit"
INTEGRATION_TESTS_DIR = TESTS_DIR / "integration"
# Patterns that indicate implementation (not just refactoring)
IMPLEMENTATION_KEYWORDS = [
"implement",
"add feature",
"create new",
"new function",
"new class",
"add method",
]
# Patterns that DON'T require TDD (refactoring, docs, etc.)
SKIP_TDD_KEYWORDS = [
"refactor",
"rename",
"format",
"typo",
"comment",
"docstring",
"fix bug", # Bug fixes can have tests after
"update docs",
]
# ============================================================================
# Helper Functions
# ============================================================================
def get_test_file_for_module(module_path: Path) -> Path:
"""Get corresponding test file for source module.
Example:
src/[project_name]/trainer.py → tests/unit/test_trainer.py
src/[project_name]/core/adapter.py → tests/unit/test_adapter.py
"""
# Get the module name (last part of path before .py)
module_name = module_path.stem
# Test file naming convention: test_{module_name}.py
test_name = f"test_{module_name}.py"
# Try unit tests first, then integration tests
unit_test_path = UNIT_TESTS_DIR / test_name
integration_test_path = INTEGRATION_TESTS_DIR / test_name
# Return unit test path (even if doesn't exist - it's the expected location)
return unit_test_path
def tests_exist(test_file: Path) -> bool:
"""Check if test file exists."""
return test_file.exists()
def run_tests(test_file: Path) -> Tuple[bool, str]:
"""Run tests and return (passing, output).
Returns:
(True, output) if tests pass
(False, output) if tests fail
"""
if not test_file.exists():
return (False, "Test file does not exist")
try:
result = subprocess.run(
["python", "-m", "pytest", str(test_file), "-v", "--tb=short"],
cwd=PROJECT_ROOT,
capture_output=True,
text=True,
timeout=30, # 30 second timeout
)
output = result.stdout + result.stderr
# Tests PASSING = returncode 0
# Tests FAILING = returncode != 0
passing = (result.returncode == 0)
return (passing, output)
except subprocess.TimeoutExpired:
return (False, "Tests timed out (>30 seconds)")
except Exception as e:
return (False, f"Error running tests: {e}")
def should_skip_tdd(user_prompt: str) -> bool:
"""Check if user request suggests we should skip TDD enforcement.
Skip TDD for:
- Refactoring
- Renaming
- Formatting
- Documentation
- Bug fixes (tests can come after for bugs)
"""
prompt_lower = user_prompt.lower()
for keyword in SKIP_TDD_KEYWORDS:
if keyword in prompt_lower:
return True
return False
def is_implementation(user_prompt: str) -> bool:
"""Check if user request is implementing new functionality.
Returns True for:
- "implement X"
- "add feature Y"
- "create new Z"
"""
prompt_lower = user_prompt.lower()
for keyword in IMPLEMENTATION_KEYWORDS:
if keyword in prompt_lower:
return True
return False
def detect_target_module(file_path: str) -> Optional[Path]:
"""Detect which module is being modified from file path.
Args:
file_path: Path to file being written (from $CLAUDE_FILE_PATHS)
Returns:
Path object if it's a source file, None otherwise
"""
path = Path(file_path)
# Only enforce TDD for source files in src/[project_name]/
if "src/[project_name]" not in str(path):
return None
# Ignore test files
if "test_" in path.name:
return None
# Ignore __init__.py (usually just imports)
if path.name == "__init__.py":
return None
return path
def suggest_tester_invocation(feature_request: str, target_module: Path) -> str:
"""Generate suggestion for invoking tester subagent.
Returns:
Formatted message suggesting how to invoke tester
"""
test_file = get_test_file_for_module(target_module)
return f"""
╭─────────────────────────────────────────────────────────╮
│ 🧪 TDD ENFORCEMENT: Tests Required Before Implementation │
╰─────────────────────────────────────────────────────────╯
❌ No tests found for: {target_module.name}
Expected test file: {test_file.relative_to(PROJECT_ROOT)}
┌─────────────────────────────────────────────────────────┐
│ 📋 TDD Workflow (Required): │
│ │
│ 1. Write FAILING tests first (tester subagent) │
│ 2. Run tests (should FAIL - not implemented yet) │
│ 3. Implement feature (make tests PASS) │
│ 4. Refactor if needed │
└─────────────────────────────────────────────────────────┘
🤖 AUTO-INVOKE TESTER SUBAGENT:
The tester subagent can automatically:
✓ Write failing tests for: {feature_request}
✓ Create test file: {test_file.name}
✓ Run tests (will fail - not implemented)
✓ Commit tests
✓ Allow implementation to proceed
To invoke tester subagent, tell Claude:
"Invoke tester subagent to write tests for {feature_request}"
Or manually create tests first:
→ Create {test_file.relative_to(PROJECT_ROOT)}
→ Write tests that will fail (feature not implemented)
→ Run: pytest {test_file.relative_to(PROJECT_ROOT)} -v
→ Verify tests FAIL
→ Then proceed with implementation
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
TDD = Test-Driven Development (Tests First, Then Code)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
"""
# ============================================================================
# Main TDD Enforcement Logic
# ============================================================================
def enforce_tdd(user_prompt: str, file_path: str) -> int:
"""Enforce TDD workflow.
Args:
user_prompt: User's request
file_path: File being written to
Returns:
0 = Allow implementation (tests exist and failing)
1 = Block implementation (no tests or tests passing)
2 = Suggest tester subagent (no tests, can auto-create)
"""
# Detect target module
target_module = detect_target_module(file_path)
if target_module is None:
# Not a source file, allow
return 0
# Check if we should skip TDD enforcement
if should_skip_tdd(user_prompt):
print(f"⏭️ Skipping TDD enforcement (refactoring/docs/bug fix)")
return 0
# Check if this is new implementation
if not is_implementation(user_prompt):
# Not implementing new features, allow
return 0
# Get corresponding test file
test_file = get_test_file_for_module(target_module)
# Check if tests exist
if not tests_exist(test_file):
# No tests - suggest tester subagent
print(suggest_tester_invocation(user_prompt, target_module))
return 2
# Tests exist - check if they're failing (proper TDD)
passing, output = run_tests(test_file)
if not passing:
# Tests failing = proper TDD workflow ✅
print(f"✅ TDD Compliant: Tests exist and failing")
print(f" Test file: {test_file.relative_to(PROJECT_ROOT)}")
print(f" → Proceed with implementation to make tests pass")
return 0
# Tests passing = NOT proper TDD ❌
print(f"⚠️ TDD Violation: Tests exist but all passing")
print(f" Test file: {test_file.relative_to(PROJECT_ROOT)}")
print()
print("In TDD, tests should FAIL before implementation:")
print("1. Write tests that will fail (feature not implemented)")
print("2. Run tests (verify they FAIL)")
print("3. Implement feature (make tests PASS)")
print()
print("Your tests are passing, which means either:")
print("a) Feature is already implemented (refactoring, not new feature)")
print("b) Tests are not comprehensive enough")
print()
print("If this is refactoring, ignore this warning.")
print("If this is NEW functionality, add FAILING tests first.")
return 1
def main():
"""Main entry point."""
# Parse arguments
if len(sys.argv) < 3:
# Not enough arguments - allow (might be manual invocation)
return 0
user_prompt = sys.argv[1]
file_path = sys.argv[2]
# Enforce TDD
exit_code = enforce_tdd(user_prompt, file_path)
return exit_code
if __name__ == "__main__":
sys.exit(main())