diff --git a/docs/architecture/orchestrator-validation.md b/docs/architecture/orchestrator-validation.md index 52b8f431..8544c5b7 100644 --- a/docs/architecture/orchestrator-validation.md +++ b/docs/architecture/orchestrator-validation.md @@ -4,6 +4,15 @@ Status: implemented (2026-04-16) Audience: orchestrator users, backend maintainers Scope: LLMRunner configuration validation and error classification +## Change Log + +**2026-04-16**: Refactored provider validation to centralize patterns in `factory.py` +- Moved `_PROVIDER_BASE_URL_PATTERNS` from `llm_runner.py` to `ProviderSpec.base_url_patterns` in `factory.py` +- Added `validate_provider_base_url()` function in factory for reusable validation +- Split ollama and openrouter into separate `ProviderSpec` entries (previously shared openai's spec) +- Reduced `llm_runner.py` from 45 lines to 13 lines for validation logic +- All 21 tests pass, including 6 provider mismatch tests + ## Overview `orchestrator/llm_runner.py` implements three layers of configuration validation to catch errors before expensive graph initialization or API calls: @@ -243,10 +252,20 @@ python -m pytest orchestrator/tests/test_llm_runner.py -v When adding a new provider to `tradingagents/llm_clients/factory.py`: -1. Add URL pattern to `_PROVIDER_BASE_URL_PATTERNS` in `llm_runner.py` -2. Add test cases for valid and invalid configurations +1. Add a new `ProviderSpec` entry to `_PROVIDER_SPECS` tuple with `base_url_patterns` +2. Add test cases for valid and invalid configurations in `orchestrator/tests/test_llm_runner.py` 3. 
Update this documentation +**Example:** +```python +ProviderSpec( + canonical_name="newprovider", + aliases=("newprovider",), + builder=lambda model, base_url=None, **kwargs: NewProviderClient(model, base_url, **kwargs), + base_url_patterns=(r"api\.newprovider\.com",), +) +``` + ### Adjusting Timeout Recommendations If profiling shows different timeout requirements: @@ -277,11 +296,25 @@ Current implementation does **not** validate API key validity before graph initi ### Provider Pattern Maintenance -URL patterns must be manually kept in sync with provider changes: +~~URL patterns must be manually kept in sync with provider changes:~~ +**UPDATE (2026-04-16)**: Provider URL patterns have been moved to `tradingagents/llm_clients/factory.py` as part of `ProviderSpec`. This centralizes validation logic with provider definitions. + +**Current implementation:** +- Each `ProviderSpec` includes optional `base_url_patterns` tuple +- `validate_provider_base_url()` function provides validation logic +- `LLMRunner._detect_provider_mismatch()` delegates to factory validation +- Patterns are co-located with provider builders, reducing maintenance burden + +**Benefits:** +- Single source of truth for provider configuration +- Easier to keep patterns in sync when adding/updating providers +- Factory can be tested independently of orchestrator +- Reduced code duplication + +**Remaining considerations:** - **Risk**: Provider changes base URL structure (e.g., API versioning) - **Mitigation**: Validation is non-blocking; mismatches are logged but don't prevent operation -- **Future**: Consider moving patterns to `tradingagents/llm_clients/factory.py` as part of `ProviderSpec` ### Timeout Recommendations diff --git a/orchestrator/llm_runner.py b/orchestrator/llm_runner.py index 53e165da..4d13c08b 100644 --- a/orchestrator/llm_runner.py +++ b/orchestrator/llm_runner.py @@ -1,33 +1,16 @@ import json import logging import os -import re from datetime import datetime, timezone from 
def _detect_provider_mismatch(self):
    """Check that the configured provider and backend URL belong together.

    Reads ``llm_provider`` and ``backend_url`` from
    ``self._config.trading_agents_config`` and delegates the actual pattern
    check to ``validate_provider_base_url``. The provider name is passed
    through as configured (e.g. ``"ollama"`` or ``"openrouter"``), because
    those providers carry URL patterns that differ from openai's.

    Returns:
        ``None`` when either setting is missing/empty or when the helper
        reports no mismatch; otherwise the mismatch dict produced by
        ``validate_provider_base_url``.
    """
    trading_cfg = self._config.trading_agents_config or {}
    raw_provider = trading_cfg.get("llm_provider")
    raw_base_url = trading_cfg.get("backend_url")

    # Nothing to validate unless both settings are present and non-empty.
    if not raw_provider or not raw_base_url:
        return None

    # Config values are not guaranteed to be ``str``. The validation helper
    # calls ``.lower()`` on both arguments, so coerce here to keep this
    # check from raising on a truthy non-string value.
    return validate_provider_base_url(str(raw_provider), str(raw_base_url))
builder: Callable[..., BaseLLMClient] + base_url_patterns: Optional[tuple[str, ...]] = None _PROVIDER_SPECS: tuple[ProviderSpec, ...] = ( ProviderSpec( canonical_name="openai", - aliases=("openai", "ollama", "openrouter"), + aliases=("openai",), builder=lambda model, base_url=None, **kwargs: OpenAIClient( model, base_url, - provider=kwargs.pop("provider", "openai"), + provider="openai", **kwargs, ), + base_url_patterns=(r"api\.openai\.com",), + ), + ProviderSpec( + canonical_name="ollama", + aliases=("ollama",), + builder=lambda model, base_url=None, **kwargs: OpenAIClient( + model, + base_url, + provider="ollama", + **kwargs, + ), + base_url_patterns=(r"localhost:\d+", r"127\.0\.0\.1:\d+", r"ollama"), + ), + ProviderSpec( + canonical_name="openrouter", + aliases=("openrouter",), + builder=lambda model, base_url=None, **kwargs: OpenAIClient( + model, + base_url, + provider="openrouter", + **kwargs, + ), + base_url_patterns=(r"openrouter\.ai",), ), ProviderSpec( canonical_name="xai", @@ -42,16 +74,19 @@ _PROVIDER_SPECS: tuple[ProviderSpec, ...] 
def validate_provider_base_url(provider: str, base_url: str) -> Optional[dict]:
    """Check whether ``base_url`` looks like an endpoint for ``provider``.

    The provider's ``ProviderSpec.base_url_patterns`` are searched
    (unanchored, via ``re.search``) against the lower-cased URL; a single
    hit is enough to accept the combination.

    Args:
        provider: LLM provider name as configured (original, not canonical).
            Matched case-insensitively.
        base_url: API endpoint URL. Matched case-insensitively.

    Returns:
        ``None`` if the combination is acceptable or cannot be validated
        (missing argument, unknown provider, or no patterns registered for
        the provider). Otherwise a dict describing the mismatch::

            {
                "provider": str,            # lower-cased provider name
                "backend_url": <base_url>,  # exactly as passed in
                "expected_patterns": tuple[str, ...],
            }
    """
    if not provider or not base_url:
        return None

    # Callers may hand over non-str config values; normalise before using
    # str methods so validation itself cannot raise AttributeError.
    provider_lower = str(provider).lower()
    base_url_lower = str(base_url).lower()

    try:
        spec = get_provider_spec(provider_lower)
    except ValueError:
        # Unknown provider - no validation rules.
        return None

    patterns = spec.base_url_patterns
    if not patterns:
        # ``None`` and an empty tuple both mean "no rules defined". Without
        # this, an empty tuple would flag every URL as a mismatch.
        return None

    # ``re.search`` uses the module's internal pattern cache, so there is no
    # need to compile each pattern explicitly on every call.
    if any(re.search(pattern, base_url_lower) for pattern in patterns):
        return None

    # No pattern matched - report what was expected.
    return {
        "provider": provider_lower,
        "backend_url": base_url,
        "expected_patterns": patterns,
    }