TradingAgents/.claude/lib/context_skill_injector.py

319 lines
9.5 KiB
Python

#!/usr/bin/env python3
"""
Context-Triggered Skill Injection - Issue #154
Auto-injects relevant skills based on conversation context patterns,
not just agent frontmatter declarations.
Key Features:
- Pattern-based detection (fast regex, not LLM)
- Max 5 skills per context to prevent bloat
- <100ms latency requirement
- Graceful degradation (missing skills don't block)
- Reuses existing skill_loader.py infrastructure
Pattern Categories:
- security: auth, token, password, JWT, encryption
- api: REST, endpoint, API, HTTP
- database: SQL, migration, schema, ORM
- git: commit, push, branch, merge, PR
- testing: test, unittest, pytest, TDD
- python: Python, type hints, docstring, PEP
Usage:
    from context_skill_injector import (
        detect_context_patterns,
        get_context_skill_injection,
        select_skills_for_context,
    )

    # Get formatted skill content for a prompt
    skill_content = get_context_skill_injection("implement secure API endpoint")

    # Or use the individual functions
    patterns = detect_context_patterns("implement JWT auth")
    skills = select_skills_for_context("implement JWT auth")
"""
import re
import sys
from typing import Dict, List, Optional, Set
# ============================================================================
# Configuration
# ============================================================================
# Maximum skills to inject per context (prevents context bloat).
# Serves as the default `max_skills` argument of select_skills_for_context()
# and get_context_skill_injection().
MAX_CONTEXT_SKILLS: int = 5
# Pattern definitions - each category name maps to a list of regex strings.
# A category counts as detected as soon as any one of its regexes matches;
# detect_context_patterns() searches them with re.IGNORECASE.
# Uses word boundaries (\b) to prevent partial matches.
# `.?` between two words allows at most one arbitrary separator character
# (e.g. "api key", "api-key" and "apikey" all match `api.?key`).
CONTEXT_PATTERNS: Dict[str, List[str]] = {
    "security": [
        r"\b(auth|authenticat\w*|authoriz\w*)\b",  # Matches authenticate, authentication, authorize, etc.
        r"\b(token|jwt|oauth|api.?key)\b",
        r"\b(password|secret|credential|encrypt)\b",
        r"\b(secure|security|vulnerability|exploit)\b",
        r"\b(login|logout|session|cookie)\b",
    ],
    "api": [
        r"\b(api|rest|endpoint|http)\b",
        r"\b(request|response|route|handler)\b",
        # HTTP verb only counts when followed by request/endpoint/method.
        r"\b(get|post|put|delete|patch)\s+(request|endpoint|method)\b",
        r"\b(webhook|callback|async.?api)\b",
    ],
    "database": [
        r"\b(database|db|sql|query)\b",
        r"\b(migration|schema|table|column)\b",
        r"\b(orm|sqlalchemy|django.?orm|prisma)\b",
        # NOTE(review): the (into|from|where) group is optional, so this also
        # matches a verb followed by any word (e.g. "update the") - confirm
        # that breadth is intended.
        r"\b(insert|update|delete|select)\s+(into|from|where)?\b",
        r"\b(postgresql|mysql|sqlite|mongodb)\b",
    ],
    "git": [
        r"\b(git|commit|push|pull|merge)\b",
        r"\b(branch|checkout|rebase|cherry.?pick)\b",
        r"\b(pull.?request|pr|merge.?request)\b",
        r"\b(stash|reset|revert|diff)\b",
    ],
    "testing": [
        r"\b(test|tests|testing|unittest)\b",
        r"\b(pytest|jest|mocha|jasmine)\b",
        r"\b(tdd|test.?driven|coverage)\b",
        r"\b(mock|stub|fixture|assert)\b",
        r"\b(integration.?test|unit.?test|e2e)\b",
    ],
    "python": [
        r"\b(python|py|python3)\b",
        r"\b(type.?hint|typing|annotations)\b",
        r"\b(docstring|pep|pep8|pep.?484)\b",
        r"\b(class|def|async.?def|decorator)\b",
        # The bare `import` alternative already covers "from x import".
        r"\b(import|from\s+\w+\s+import)\b",
    ],
}
# Maps pattern categories (the keys of CONTEXT_PATTERNS) to skill names.
# Skills must exist in plugins/autonomous-dev/skills/{skill-name}/SKILL.md
# Within a category, skills are emitted in the order listed here; a category
# with no entry contributes no skills.
PATTERN_SKILL_MAP: Dict[str, List[str]] = {
    "security": ["security-patterns"],
    "api": ["api-design", "api-integration-patterns"],
    "database": ["database-design"],
    "git": ["git-workflow"],
    "testing": ["testing-guide"],
    "python": ["python-standards"],
}
# Priority order for skill selection. select_skills_for_context() always
# walks categories in this order, so when the skill limit is hit it is the
# later entries that get dropped. A category that is missing from this list
# is never selected, even if it is detected.
PATTERN_PRIORITY: List[str] = [
    "security",  # Security always first
    "testing",  # Tests are fundamental
    "api",  # API patterns common
    "database",  # Data layer
    "python",  # Language specifics
    "git",  # Operations
]
# ============================================================================
# Pattern Detection
# ============================================================================
def detect_context_patterns(user_prompt: Optional[str]) -> Set[str]:
    """
    Identify which pattern categories a prompt touches on.

    Each category in CONTEXT_PATTERNS is checked against the lower-cased
    prompt; a category is reported as soon as any one of its regexes
    matches.

    Args:
        user_prompt: User's prompt text. None or an empty string yields an
            empty result.

    Returns:
        Set of matching category names (e.g., {"security", "api"}).

    Example:
        >>> detect_context_patterns("implement JWT authentication")
        {'security'}
        >>> detect_context_patterns("create REST API endpoint")
        {'api'}
    """
    if not user_prompt:
        return set()

    lowered = user_prompt.lower()
    # any() stops at the first regex that matches, so a category never costs
    # more searches than it takes to find its first hit.
    return {
        name
        for name, regexes in CONTEXT_PATTERNS.items()
        if any(re.search(regex, lowered, re.IGNORECASE) for regex in regexes)
    }
# ============================================================================
# Skill Selection
# ============================================================================
def select_skills_for_context(
    user_prompt: Optional[str],
    max_skills: int = MAX_CONTEXT_SKILLS,
) -> List[str]:
    """
    Select relevant skills based on context patterns in prompt.

    Detects pattern categories in the prompt, maps them to skills via
    PATTERN_SKILL_MAP, and returns a de-duplicated list ordered by
    PATTERN_PRIORITY and truncated to max_skills to prevent context bloat.

    Args:
        user_prompt: User's prompt text. None or empty yields no skills.
        max_skills: Maximum number of skills to return (default: 5).
            Zero or a negative value yields no skills.

    Returns:
        List of skill names to inject, ordered by priority.

    Example:
        >>> select_skills_for_context("implement secure API")
        ['security-patterns', 'api-design', 'api-integration-patterns']
    """
    # A non-positive limit means "inject nothing". It has to be checked up
    # front: testing the limit only after appending a skill would let one
    # skill through for max_skills <= 0.
    if not user_prompt or max_skills <= 0:
        return []

    detected = detect_context_patterns(user_prompt)
    if not detected:
        return []

    # Walk categories in priority order so that truncation drops the
    # lowest-priority skills first. Categories absent from PATTERN_PRIORITY
    # are not considered.
    candidates = [
        skill
        for category in PATTERN_PRIORITY
        if category in detected
        for skill in PATTERN_SKILL_MAP.get(category, [])
    ]

    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(candidates))[:max_skills]
# ============================================================================
# Integration with skill_loader
# ============================================================================
def get_context_skill_injection(
    user_prompt: Optional[str],
    max_skills: int = MAX_CONTEXT_SKILLS,
) -> str:
    """
    Build the skill text to inject for a given prompt.

    Single entry point that chains skill selection
    (select_skills_for_context), content loading
    (skill_loader.load_skill_content) and formatting
    (skill_loader.format_skills_for_prompt).

    Args:
        user_prompt: User's prompt text.
        max_skills: Maximum number of skills to inject.

    Returns:
        The string produced by format_skills_for_prompt for the skills whose
        content could be loaded. An empty string is returned when the prompt
        is empty, no skill is selected, skill_loader cannot be imported, or
        no selected skill yields content.

    Example (output format is determined by skill_loader; shown as
    documented by the original author):
        >>> content = get_context_skill_injection("implement JWT auth")
        >>> print(content[:50])
        <skills>
        <skill name="security-patterns">...
    """
    if not user_prompt:
        return ""

    chosen = select_skills_for_context(user_prompt, max_skills)
    if not chosen:
        return ""

    # Imported lazily, and only once there is something to load, so that a
    # missing skill_loader module degrades to "no injection".
    try:
        from skill_loader import load_skill_content, format_skills_for_prompt
    except ImportError:
        return ""

    # Skills whose content comes back empty/falsy are skipped.
    # NOTE(review): an exception raised by load_skill_content propagates to
    # the caller - confirm that skill_loader handles missing skills itself.
    fetched = ((name, load_skill_content(name)) for name in chosen)
    loaded = {name: body for name, body in fetched if body}

    return format_skills_for_prompt(loaded) if loaded else ""
# ============================================================================
# Utility Functions
# ============================================================================
def get_pattern_categories() -> List[str]:
    """
    List the pattern categories this module can detect.

    Returns:
        The keys of CONTEXT_PATTERNS as a new list, in definition order.
    """
    # Iterating a dict yields its keys.
    category_names = list(CONTEXT_PATTERNS)
    return category_names
def get_skills_for_pattern(pattern: str) -> List[str]:
    """
    Get skills mapped to a specific pattern category.

    Args:
        pattern: Pattern category name (e.g., "security").

    Returns:
        A new list of skill names for that category, or an empty list when
        the category has no entry in PATTERN_SKILL_MAP.
    """
    # Return a copy rather than the list stored in PATTERN_SKILL_MAP, so a
    # caller that mutates the result cannot alter the module-level mapping.
    # Unknown categories already produced a fresh list; known ones now
    # behave the same way.
    return list(PATTERN_SKILL_MAP.get(pattern, ()))
# ============================================================================
# CLI Entry Point
# ============================================================================
def main():
    """
    CLI entry point for testing.

    Joins all command-line arguments into one prompt and prints the detected
    pattern categories, the selected skills and a preview of the injected
    skill content. With no arguments it prints usage and exits with status 1.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python context_skill_injector.py <prompt>")
        print("Example: python context_skill_injector.py 'implement JWT auth'")
        sys.exit(1)

    prompt = " ".join(cli_args)
    print(f"Prompt: {prompt}")
    print()

    patterns = detect_context_patterns(prompt)
    print(f"Detected patterns: {patterns}")

    skills = select_skills_for_context(prompt)
    print(f"Selected skills: {skills}")

    content = get_context_skill_injection(prompt)
    if not content:
        print("\nNo skill content loaded")
        return

    # Only a short preview is printed; the full text can be large.
    print(f"\nSkill content length: {len(content)} chars")
    print(f"First 200 chars:\n{content[:200]}...")


if __name__ == "__main__":
    main()