From e581adbeca668152933aa30ba110330527c1b870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=B0=91=E6=9D=B0?= Date: Thu, 16 Apr 2026 20:28:14 +0800 Subject: [PATCH] refactor(factory): add pattern caching and type safety to validation Improvements: - Add ProviderMismatch TypedDict for type-safe return values - Cache compiled regex patterns for better performance - Update documentation to reflect optimizations Co-Authored-By: Claude Sonnet 4.6 (1M context) --- docs/architecture/orchestrator-validation.md | 3 +- tradingagents/llm_clients/factory.py | 30 +++++++++++++------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/docs/architecture/orchestrator-validation.md b/docs/architecture/orchestrator-validation.md index 8544c5b7..b446f541 100644 --- a/docs/architecture/orchestrator-validation.md +++ b/docs/architecture/orchestrator-validation.md @@ -8,7 +8,8 @@ Scope: LLMRunner configuration validation and error classification **2026-04-16**: Refactored provider validation to centralize patterns in `factory.py` - Moved `_PROVIDER_BASE_URL_PATTERNS` from `llm_runner.py` to `ProviderSpec.base_url_patterns` in `factory.py` -- Added `validate_provider_base_url()` function in factory for reusable validation +- Added `validate_provider_base_url()` function with pattern caching for performance +- Added `ProviderMismatch` TypedDict for type-safe validation results - Split ollama and openrouter into separate `ProviderSpec` entries (previously shared openai's spec) - Reduced `llm_runner.py` from 45 lines to 13 lines for validation logic - All 21 tests pass, including 6 provider mismatch tests diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index db477584..a168649f 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable, Optional +from typing import Callable, Optional, TypedDict import re from .base_client import BaseLLMClient @@ -13,6 +13,16 @@ _OPENAI_COMPATIBLE = ( "openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter", ) +# Compiled pattern cache for validation performance +_COMPILED_PATTERNS: dict[str, list[re.Pattern]] = {} + + +class ProviderMismatch(TypedDict): + """Provider validation mismatch details.""" + provider: str + backend_url: str + expected_patterns: tuple[str, ...] + @dataclass(frozen=True) class ProviderSpec: @@ -130,7 +140,7 @@ def create_llm_client( return provider_spec.builder(model, base_url, **kwargs) -def validate_provider_base_url(provider: str, base_url: str) -> Optional[dict]: +def validate_provider_base_url(provider: str, base_url: str) -> Optional[ProviderMismatch]: """Validate provider × base_url compatibility. Args: @@ -138,12 +148,7 @@ def validate_provider_base_url(provider: str, base_url: str) -> Optional[dict]: base_url: API endpoint URL Returns: - None if valid, or dict with mismatch details if invalid: - { - "provider": str, - "backend_url": str, - "expected_patterns": tuple[str, ...] - } + None if valid, or ProviderMismatch dict if invalid """ if not provider or not base_url: return None @@ -161,9 +166,12 @@ def validate_provider_base_url(provider: str, base_url: str) -> Optional[dict]: # No validation rules defined for this provider return None - # Compile and test patterns - for pattern_str in spec.base_url_patterns: - pattern = re.compile(pattern_str) + # Use cached compiled patterns for performance + cache_key = spec.canonical_name + if cache_key not in _COMPILED_PATTERNS: + _COMPILED_PATTERNS[cache_key] = [re.compile(p) for p in spec.base_url_patterns] + + for pattern in _COMPILED_PATTERNS[cache_key]: if pattern.search(base_url_lower): return None # Match found