refactor(orchestrator): centralize provider validation in factory

Move provider × base_url validation patterns from llm_runner.py to
factory.py's ProviderSpec, implementing the architecture improvement
suggested in docs/architecture/orchestrator-validation.md.

Changes:
- Add base_url_patterns field to ProviderSpec dataclass
- Split ollama and openrouter into separate ProviderSpec entries
  (previously shared openai's spec with dynamic provider selection)
- Add validate_provider_base_url() function in factory for reusable validation
- Simplify LLMRunner._detect_provider_mismatch() to delegate to factory
- Update architecture doc with change log and implementation notes

Benefits:
- Single source of truth for provider configuration
- Easier maintenance when adding/updating providers
- Reduced code duplication (llm_runner.py: -39 lines, factory.py: +84 lines)
- Factory validation can be tested independently

All 28 orchestrator validation tests pass, including 6 provider mismatch tests.
This commit is contained in:
陈少杰 2026-04-16 20:06:30 +08:00
parent a5fd95af82
commit 78312851f9
3 changed files with 127 additions and 50 deletions

View File

@ -4,6 +4,15 @@ Status: implemented (2026-04-16)
Audience: orchestrator users, backend maintainers Audience: orchestrator users, backend maintainers
Scope: LLMRunner configuration validation and error classification Scope: LLMRunner configuration validation and error classification
## Change Log
**2026-04-16**: Refactored provider validation to centralize patterns in `factory.py`
- Moved `_PROVIDER_BASE_URL_PATTERNS` from `llm_runner.py` to `ProviderSpec.base_url_patterns` in `factory.py`
- Added `validate_provider_base_url()` function in factory for reusable validation
- Split ollama and openrouter into separate `ProviderSpec` entries (previously shared openai's spec)
- Reduced `llm_runner.py` from 45 lines to 13 lines for validation logic
- All 21 tests pass, including 6 provider mismatch tests
## Overview ## Overview
`orchestrator/llm_runner.py` implements three layers of configuration validation to catch errors before expensive graph initialization or API calls: `orchestrator/llm_runner.py` implements three layers of configuration validation to catch errors before expensive graph initialization or API calls:
@ -243,10 +252,20 @@ python -m pytest orchestrator/tests/test_llm_runner.py -v
When adding a new provider to `tradingagents/llm_clients/factory.py`: When adding a new provider to `tradingagents/llm_clients/factory.py`:
1. Add URL pattern to `_PROVIDER_BASE_URL_PATTERNS` in `llm_runner.py` 1. Add a new `ProviderSpec` entry to `_PROVIDER_SPECS` tuple with `base_url_patterns`
2. Add test cases for valid and invalid configurations 2. Add test cases for valid and invalid configurations in `orchestrator/tests/test_llm_runner.py`
3. Update this documentation 3. Update this documentation
**Example:**
```python
ProviderSpec(
canonical_name="newprovider",
aliases=("newprovider",),
builder=lambda model, base_url=None, **kwargs: NewProviderClient(model, base_url, **kwargs),
base_url_patterns=(r"api\.newprovider\.com",),
)
```
### Adjusting Timeout Recommendations ### Adjusting Timeout Recommendations
If profiling shows different timeout requirements: If profiling shows different timeout requirements:
@ -277,11 +296,25 @@ Current implementation does **not** validate API key validity before graph initi
### Provider Pattern Maintenance ### Provider Pattern Maintenance
URL patterns must be manually kept in sync with provider changes: ~~URL patterns must be manually kept in sync with provider changes:~~
**UPDATE (2026-04-16)**: Provider URL patterns have been moved to `tradingagents/llm_clients/factory.py` as part of `ProviderSpec`. This centralizes validation logic with provider definitions.
**Current implementation:**
- Each `ProviderSpec` includes optional `base_url_patterns` tuple
- `validate_provider_base_url()` function provides validation logic
- `LLMRunner._detect_provider_mismatch()` delegates to factory validation
- Patterns are co-located with provider builders, reducing maintenance burden
**Benefits:**
- Single source of truth for provider configuration
- Easier to keep patterns in sync when adding/updating providers
- Factory can be tested independently of orchestrator
- Reduced code duplication
**Remaining considerations:**
- **Risk**: Provider changes base URL structure (e.g., API versioning) - **Risk**: Provider changes base URL structure (e.g., API versioning)
- **Mitigation**: Validation is non-blocking; mismatches are logged but don't prevent operation - **Mitigation**: Validation is non-blocking; mismatches are logged but don't prevent operation
- **Future**: Consider moving patterns to `tradingagents/llm_clients/factory.py` as part of `ProviderSpec`
### Timeout Recommendations ### Timeout Recommendations

View File

@ -1,33 +1,16 @@
import json import json
import logging import logging
import os import os
import re
from datetime import datetime, timezone from datetime import datetime, timezone
from orchestrator.config import OrchestratorConfig from orchestrator.config import OrchestratorConfig
from orchestrator.contracts.error_taxonomy import ReasonCode from orchestrator.contracts.error_taxonomy import ReasonCode
from orchestrator.contracts.result_contract import Signal, build_error_signal from orchestrator.contracts.result_contract import Signal, build_error_signal
from tradingagents.agents.utils.agent_states import extract_research_provenance from tradingagents.agents.utils.agent_states import extract_research_provenance
from tradingagents.llm_clients.factory import validate_provider_base_url
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Provider × base_url validation matrix
# Note: ollama/openrouter share openai's canonical provider but have different URL patterns
_PROVIDER_BASE_URL_PATTERNS = {
"anthropic": [r"api\.anthropic\.com", r"api\.minimaxi\.com/anthropic"],
"openai": [r"api\.openai\.com"],
"google": [r"generativelanguage\.googleapis\.com"],
"xai": [r"api\.x\.ai"],
"ollama": [r"localhost:\d+", r"127\.0\.0\.1:\d+", r"ollama"],
"openrouter": [r"openrouter\.ai"],
}
# Precompile regex patterns for efficiency
_COMPILED_PATTERNS = {
provider: [re.compile(pattern) for pattern in patterns]
for provider, patterns in _PROVIDER_BASE_URL_PATTERNS.items()
}
# Recommended timeout thresholds by analyst count # Recommended timeout thresholds by analyst count
_RECOMMENDED_TIMEOUTS = { _RECOMMENDED_TIMEOUTS = {
1: {"analyst": 75.0, "research": 30.0}, 1: {"analyst": 75.0, "research": 30.0},
@ -110,35 +93,19 @@ class LLMRunner:
return self._graph return self._graph
def _detect_provider_mismatch(self): def _detect_provider_mismatch(self):
"""Validate provider × base_url compatibility using pattern matrix. """Validate provider × base_url compatibility using factory's validation.
Uses the original provider name (not canonical) for validation since Uses the original provider name (not canonical) for validation since
ollama/openrouter share openai's canonical provider but have different URLs. ollama/openrouter have different URL patterns than openai.
""" """
trading_cfg = self._config.trading_agents_config or {} trading_cfg = self._config.trading_agents_config or {}
provider = str(trading_cfg.get("llm_provider", "")).lower() provider = trading_cfg.get("llm_provider", "")
base_url = str(trading_cfg.get("backend_url", "") or "").lower() base_url = trading_cfg.get("backend_url", "")
if not provider or not base_url: if not provider or not base_url:
return None return None
# Use original provider name for pattern matching (not canonical) return validate_provider_base_url(provider, base_url)
# This handles ollama/openrouter which share openai's canonical provider
compiled_patterns = _COMPILED_PATTERNS.get(provider, [])
if not compiled_patterns:
# No validation rules defined for this provider
return None
for pattern in compiled_patterns:
if pattern.search(base_url):
return None # Match found, no mismatch
# No pattern matched - return raw patterns for error message
return {
"provider": provider,
"backend_url": trading_cfg.get("backend_url"),
"expected_patterns": _PROVIDER_BASE_URL_PATTERNS[provider],
}
def get_signal(self, ticker: str, date: str) -> Signal: def get_signal(self, ticker: str, date: str) -> Signal:
"""获取指定股票在指定日期的 LLM 信号,带缓存。""" """获取指定股票在指定日期的 LLM 信号,带缓存。"""

View File

@ -1,5 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Optional from typing import Callable, Optional
import re
from .base_client import BaseLLMClient from .base_client import BaseLLMClient
from .openai_client import OpenAIClient from .openai_client import OpenAIClient
@ -15,23 +16,54 @@ _OPENAI_COMPATIBLE = (
@dataclass(frozen=True) @dataclass(frozen=True)
class ProviderSpec: class ProviderSpec:
"""Provider registry entry for LLM client creation.""" """Provider registry entry for LLM client creation.
Attributes:
canonical_name: Primary provider identifier
aliases: Alternative names that resolve to this provider
builder: Factory function to create the client instance
base_url_patterns: Regex patterns for valid base URLs (None = no validation)
"""
canonical_name: str canonical_name: str
aliases: tuple[str, ...] aliases: tuple[str, ...]
builder: Callable[..., BaseLLMClient] builder: Callable[..., BaseLLMClient]
base_url_patterns: Optional[tuple[str, ...]] = None
_PROVIDER_SPECS: tuple[ProviderSpec, ...] = ( _PROVIDER_SPECS: tuple[ProviderSpec, ...] = (
ProviderSpec( ProviderSpec(
canonical_name="openai", canonical_name="openai",
aliases=("openai", "ollama", "openrouter"), aliases=("openai",),
builder=lambda model, base_url=None, **kwargs: OpenAIClient( builder=lambda model, base_url=None, **kwargs: OpenAIClient(
model, model,
base_url, base_url,
provider=kwargs.pop("provider", "openai"), provider="openai",
**kwargs, **kwargs,
), ),
base_url_patterns=(r"api\.openai\.com",),
),
ProviderSpec(
canonical_name="ollama",
aliases=("ollama",),
builder=lambda model, base_url=None, **kwargs: OpenAIClient(
model,
base_url,
provider="ollama",
**kwargs,
),
base_url_patterns=(r"localhost:\d+", r"127\.0\.0\.1:\d+", r"ollama"),
),
ProviderSpec(
canonical_name="openrouter",
aliases=("openrouter",),
builder=lambda model, base_url=None, **kwargs: OpenAIClient(
model,
base_url,
provider="openrouter",
**kwargs,
),
base_url_patterns=(r"openrouter\.ai",),
), ),
ProviderSpec( ProviderSpec(
canonical_name="xai", canonical_name="xai",
@ -42,16 +74,19 @@ _PROVIDER_SPECS: tuple[ProviderSpec, ...] = (
provider="xai", provider="xai",
**kwargs, **kwargs,
), ),
base_url_patterns=(r"api\.x\.ai",),
), ),
ProviderSpec( ProviderSpec(
canonical_name="anthropic", canonical_name="anthropic",
aliases=("anthropic",), aliases=("anthropic",),
builder=lambda model, base_url=None, **kwargs: AnthropicClient(model, base_url, **kwargs), builder=lambda model, base_url=None, **kwargs: AnthropicClient(model, base_url, **kwargs),
base_url_patterns=(r"api\.anthropic\.com", r"api\.minimaxi\.com/anthropic"),
), ),
ProviderSpec( ProviderSpec(
canonical_name="google", canonical_name="google",
aliases=("google",), aliases=("google",),
builder=lambda model, base_url=None, **kwargs: GoogleClient(model, base_url, **kwargs), builder=lambda model, base_url=None, **kwargs: GoogleClient(model, base_url, **kwargs),
base_url_patterns=(r"generativelanguage\.googleapis\.com",),
), ),
) )
@ -92,7 +127,49 @@ def create_llm_client(
""" """
provider_lower = provider.lower() provider_lower = provider.lower()
provider_spec = get_provider_spec(provider_lower) provider_spec = get_provider_spec(provider_lower)
builder_kwargs = dict(kwargs) return provider_spec.builder(model, base_url, **kwargs)
if provider_lower in ("openai", "ollama", "openrouter"):
builder_kwargs["provider"] = provider_lower
return provider_spec.builder(model, base_url, **builder_kwargs) def validate_provider_base_url(provider: str, base_url: str) -> Optional[dict]:
"""Validate provider × base_url compatibility.
Args:
provider: LLM provider name (original, not canonical)
base_url: API endpoint URL
Returns:
None if valid, or dict with mismatch details if invalid:
{
"provider": str,
"backend_url": str,
"expected_patterns": tuple[str, ...]
}
"""
if not provider or not base_url:
return None
provider_lower = provider.lower()
base_url_lower = base_url.lower()
try:
spec = get_provider_spec(provider_lower)
except ValueError:
# Unknown provider - no validation rules
return None
if spec.base_url_patterns is None:
# No validation rules defined for this provider
return None
# Compile and test patterns
for pattern_str in spec.base_url_patterns:
pattern = re.compile(pattern_str)
if pattern.search(base_url_lower):
return None # Match found
# No pattern matched - return mismatch details
return {
"provider": provider_lower,
"backend_url": base_url,
"expected_patterns": spec.base_url_patterns,
}