refactor(factory): add pattern caching and type safety to validation

Improvements:
- Add ProviderMismatch TypedDict for type-safe return values
- Cache compiled regex patterns for better performance
- Update documentation to reflect optimizations

Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
陈少杰 2026-04-16 20:28:14 +08:00
parent 78312851f9
commit e581adbeca
2 changed files with 21 additions and 12 deletions

View File

@ -8,7 +8,8 @@ Scope: LLMRunner configuration validation and error classification
**2026-04-16**: Refactored provider validation to centralize patterns in `factory.py`
- Moved `_PROVIDER_BASE_URL_PATTERNS` from `llm_runner.py` to `ProviderSpec.base_url_patterns` in `factory.py`
- Added `validate_provider_base_url()` function in factory for reusable validation
- Added `validate_provider_base_url()` function with pattern caching for performance
- Added `ProviderMismatch` TypedDict for type-safe validation results
- Split ollama and openrouter into separate `ProviderSpec` entries (previously shared openai's spec)
- Reduced `llm_runner.py` from 45 lines to 13 lines for validation logic
- All 21 tests pass, including 6 provider mismatch tests

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Callable, Optional
from typing import Callable, Optional, TypedDict
import re
from .base_client import BaseLLMClient
@ -13,6 +13,16 @@ _OPENAI_COMPATIBLE = (
"openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter",
)
# Compiled pattern cache for validation performance
_COMPILED_PATTERNS: dict[str, list[re.Pattern]] = {}
class ProviderMismatch(TypedDict):
"""Provider validation mismatch details."""
provider: str
backend_url: str
expected_patterns: tuple[str, ...]
@dataclass(frozen=True)
class ProviderSpec:
@ -130,7 +140,7 @@ def create_llm_client(
return provider_spec.builder(model, base_url, **kwargs)
def validate_provider_base_url(provider: str, base_url: str) -> Optional[dict]:
def validate_provider_base_url(provider: str, base_url: str) -> Optional[ProviderMismatch]:
"""Validate provider × base_url compatibility.
Args:
@ -138,12 +148,7 @@ def validate_provider_base_url(provider: str, base_url: str) -> Optional[dict]:
base_url: API endpoint URL
Returns:
None if valid, or dict with mismatch details if invalid:
{
"provider": str,
"backend_url": str,
"expected_patterns": tuple[str, ...]
}
None if valid, or ProviderMismatch dict if invalid
"""
if not provider or not base_url:
return None
@ -161,9 +166,12 @@ def validate_provider_base_url(provider: str, base_url: str) -> Optional[dict]:
# No validation rules defined for this provider
return None
# Compile and test patterns
for pattern_str in spec.base_url_patterns:
pattern = re.compile(pattern_str)
# Use cached compiled patterns for performance
cache_key = spec.canonical_name
if cache_key not in _COMPILED_PATTERNS:
_COMPILED_PATTERNS[cache_key] = [re.compile(p) for p in spec.base_url_patterns]
for pattern in _COMPILED_PATTERNS[cache_key]:
if pattern.search(base_url_lower):
return None # Match found