refactor(orchestrator): centralize provider validation in factory
Move provider × base_url validation patterns from llm_runner.py to factory.py's ProviderSpec, implementing the architecture improvement suggested in docs/architecture/orchestrator-validation.md. Changes: - Add base_url_patterns field to ProviderSpec dataclass - Split ollama and openrouter into separate ProviderSpec entries (previously shared openai's spec with dynamic provider selection) - Add validate_provider_base_url() function in factory for reusable validation - Simplify LLMRunner._detect_provider_mismatch() to delegate to factory - Update architecture doc with change log and implementation notes Benefits: - Single source of truth for provider configuration - Easier maintenance when adding/updating providers - Reduced code duplication (llm_runner.py: -39 lines, factory.py: +84 lines) - Factory validation can be tested independently All 28 orchestrator validation tests pass, including 6 provider mismatch tests.
This commit is contained in:
parent
a5fd95af82
commit
78312851f9
|
|
@ -4,6 +4,15 @@ Status: implemented (2026-04-16)
|
||||||
Audience: orchestrator users, backend maintainers
|
Audience: orchestrator users, backend maintainers
|
||||||
Scope: LLMRunner configuration validation and error classification
|
Scope: LLMRunner configuration validation and error classification
|
||||||
|
|
||||||
|
## Change Log
|
||||||
|
|
||||||
|
**2026-04-16**: Refactored provider validation to centralize patterns in `factory.py`
|
||||||
|
- Moved `_PROVIDER_BASE_URL_PATTERNS` from `llm_runner.py` to `ProviderSpec.base_url_patterns` in `factory.py`
|
||||||
|
- Added `validate_provider_base_url()` function in factory for reusable validation
|
||||||
|
- Split ollama and openrouter into separate `ProviderSpec` entries (previously shared openai's spec)
|
||||||
|
- Reduced `llm_runner.py` from 45 lines to 13 lines for validation logic
|
||||||
|
- All 21 tests pass, including 6 provider mismatch tests
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
`orchestrator/llm_runner.py` implements three layers of configuration validation to catch errors before expensive graph initialization or API calls:
|
`orchestrator/llm_runner.py` implements three layers of configuration validation to catch errors before expensive graph initialization or API calls:
|
||||||
|
|
@ -243,10 +252,20 @@ python -m pytest orchestrator/tests/test_llm_runner.py -v
|
||||||
|
|
||||||
When adding a new provider to `tradingagents/llm_clients/factory.py`:
|
When adding a new provider to `tradingagents/llm_clients/factory.py`:
|
||||||
|
|
||||||
1. Add URL pattern to `_PROVIDER_BASE_URL_PATTERNS` in `llm_runner.py`
|
1. Add a new `ProviderSpec` entry to `_PROVIDER_SPECS` tuple with `base_url_patterns`
|
||||||
2. Add test cases for valid and invalid configurations
|
2. Add test cases for valid and invalid configurations in `orchestrator/tests/test_llm_runner.py`
|
||||||
3. Update this documentation
|
3. Update this documentation
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```python
|
||||||
|
ProviderSpec(
|
||||||
|
canonical_name="newprovider",
|
||||||
|
aliases=("newprovider",),
|
||||||
|
builder=lambda model, base_url=None, **kwargs: NewProviderClient(model, base_url, **kwargs),
|
||||||
|
base_url_patterns=(r"api\.newprovider\.com",),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
### Adjusting Timeout Recommendations
|
### Adjusting Timeout Recommendations
|
||||||
|
|
||||||
If profiling shows different timeout requirements:
|
If profiling shows different timeout requirements:
|
||||||
|
|
@ -277,11 +296,25 @@ Current implementation does **not** validate API key validity before graph initi
|
||||||
|
|
||||||
### Provider Pattern Maintenance
|
### Provider Pattern Maintenance
|
||||||
|
|
||||||
URL patterns must be manually kept in sync with provider changes:
|
~~URL patterns must be manually kept in sync with provider changes:~~
|
||||||
|
|
||||||
|
**UPDATE (2026-04-16)**: Provider URL patterns have been moved to `tradingagents/llm_clients/factory.py` as part of `ProviderSpec`. This centralizes validation logic with provider definitions.
|
||||||
|
|
||||||
|
**Current implementation:**
|
||||||
|
- Each `ProviderSpec` includes optional `base_url_patterns` tuple
|
||||||
|
- `validate_provider_base_url()` function provides validation logic
|
||||||
|
- `LLMRunner._detect_provider_mismatch()` delegates to factory validation
|
||||||
|
- Patterns are co-located with provider builders, reducing maintenance burden
|
||||||
|
|
||||||
|
**Benefits:**
|
||||||
|
- Single source of truth for provider configuration
|
||||||
|
- Easier to keep patterns in sync when adding/updating providers
|
||||||
|
- Factory can be tested independently of orchestrator
|
||||||
|
- Reduced code duplication
|
||||||
|
|
||||||
|
**Remaining considerations:**
|
||||||
- **Risk**: Provider changes base URL structure (e.g., API versioning)
|
- **Risk**: Provider changes base URL structure (e.g., API versioning)
|
||||||
- **Mitigation**: Validation is non-blocking; mismatches are logged but don't prevent operation
|
- **Mitigation**: Validation is non-blocking; mismatches are logged but don't prevent operation
|
||||||
- **Future**: Consider moving patterns to `tradingagents/llm_clients/factory.py` as part of `ProviderSpec`
|
|
||||||
|
|
||||||
### Timeout Recommendations
|
### Timeout Recommendations
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,33 +1,16 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from orchestrator.config import OrchestratorConfig
|
from orchestrator.config import OrchestratorConfig
|
||||||
from orchestrator.contracts.error_taxonomy import ReasonCode
|
from orchestrator.contracts.error_taxonomy import ReasonCode
|
||||||
from orchestrator.contracts.result_contract import Signal, build_error_signal
|
from orchestrator.contracts.result_contract import Signal, build_error_signal
|
||||||
from tradingagents.agents.utils.agent_states import extract_research_provenance
|
from tradingagents.agents.utils.agent_states import extract_research_provenance
|
||||||
|
from tradingagents.llm_clients.factory import validate_provider_base_url
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Provider × base_url validation matrix
|
|
||||||
# Note: ollama/openrouter share openai's canonical provider but have different URL patterns
|
|
||||||
_PROVIDER_BASE_URL_PATTERNS = {
|
|
||||||
"anthropic": [r"api\.anthropic\.com", r"api\.minimaxi\.com/anthropic"],
|
|
||||||
"openai": [r"api\.openai\.com"],
|
|
||||||
"google": [r"generativelanguage\.googleapis\.com"],
|
|
||||||
"xai": [r"api\.x\.ai"],
|
|
||||||
"ollama": [r"localhost:\d+", r"127\.0\.0\.1:\d+", r"ollama"],
|
|
||||||
"openrouter": [r"openrouter\.ai"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Precompile regex patterns for efficiency
|
|
||||||
_COMPILED_PATTERNS = {
|
|
||||||
provider: [re.compile(pattern) for pattern in patterns]
|
|
||||||
for provider, patterns in _PROVIDER_BASE_URL_PATTERNS.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Recommended timeout thresholds by analyst count
|
# Recommended timeout thresholds by analyst count
|
||||||
_RECOMMENDED_TIMEOUTS = {
|
_RECOMMENDED_TIMEOUTS = {
|
||||||
1: {"analyst": 75.0, "research": 30.0},
|
1: {"analyst": 75.0, "research": 30.0},
|
||||||
|
|
@ -110,35 +93,19 @@ class LLMRunner:
|
||||||
return self._graph
|
return self._graph
|
||||||
|
|
||||||
def _detect_provider_mismatch(self):
|
def _detect_provider_mismatch(self):
|
||||||
"""Validate provider × base_url compatibility using pattern matrix.
|
"""Validate provider × base_url compatibility using factory's validation.
|
||||||
|
|
||||||
Uses the original provider name (not canonical) for validation since
|
Uses the original provider name (not canonical) for validation since
|
||||||
ollama/openrouter share openai's canonical provider but have different URLs.
|
ollama/openrouter have different URL patterns than openai.
|
||||||
"""
|
"""
|
||||||
trading_cfg = self._config.trading_agents_config or {}
|
trading_cfg = self._config.trading_agents_config or {}
|
||||||
provider = str(trading_cfg.get("llm_provider", "")).lower()
|
provider = trading_cfg.get("llm_provider", "")
|
||||||
base_url = str(trading_cfg.get("backend_url", "") or "").lower()
|
base_url = trading_cfg.get("backend_url", "")
|
||||||
|
|
||||||
if not provider or not base_url:
|
if not provider or not base_url:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Use original provider name for pattern matching (not canonical)
|
return validate_provider_base_url(provider, base_url)
|
||||||
# This handles ollama/openrouter which share openai's canonical provider
|
|
||||||
compiled_patterns = _COMPILED_PATTERNS.get(provider, [])
|
|
||||||
if not compiled_patterns:
|
|
||||||
# No validation rules defined for this provider
|
|
||||||
return None
|
|
||||||
|
|
||||||
for pattern in compiled_patterns:
|
|
||||||
if pattern.search(base_url):
|
|
||||||
return None # Match found, no mismatch
|
|
||||||
|
|
||||||
# No pattern matched - return raw patterns for error message
|
|
||||||
return {
|
|
||||||
"provider": provider,
|
|
||||||
"backend_url": trading_cfg.get("backend_url"),
|
|
||||||
"expected_patterns": _PROVIDER_BASE_URL_PATTERNS[provider],
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_signal(self, ticker: str, date: str) -> Signal:
|
def get_signal(self, ticker: str, date: str) -> Signal:
|
||||||
"""获取指定股票在指定日期的 LLM 信号,带缓存。"""
|
"""获取指定股票在指定日期的 LLM 信号,带缓存。"""
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
import re
|
||||||
|
|
||||||
from .base_client import BaseLLMClient
|
from .base_client import BaseLLMClient
|
||||||
from .openai_client import OpenAIClient
|
from .openai_client import OpenAIClient
|
||||||
|
|
@ -15,23 +16,54 @@ _OPENAI_COMPATIBLE = (
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ProviderSpec:
|
class ProviderSpec:
|
||||||
"""Provider registry entry for LLM client creation."""
|
"""Provider registry entry for LLM client creation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
canonical_name: Primary provider identifier
|
||||||
|
aliases: Alternative names that resolve to this provider
|
||||||
|
builder: Factory function to create the client instance
|
||||||
|
base_url_patterns: Regex patterns for valid base URLs (None = no validation)
|
||||||
|
"""
|
||||||
|
|
||||||
canonical_name: str
|
canonical_name: str
|
||||||
aliases: tuple[str, ...]
|
aliases: tuple[str, ...]
|
||||||
builder: Callable[..., BaseLLMClient]
|
builder: Callable[..., BaseLLMClient]
|
||||||
|
base_url_patterns: Optional[tuple[str, ...]] = None
|
||||||
|
|
||||||
|
|
||||||
_PROVIDER_SPECS: tuple[ProviderSpec, ...] = (
|
_PROVIDER_SPECS: tuple[ProviderSpec, ...] = (
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
canonical_name="openai",
|
canonical_name="openai",
|
||||||
aliases=("openai", "ollama", "openrouter"),
|
aliases=("openai",),
|
||||||
builder=lambda model, base_url=None, **kwargs: OpenAIClient(
|
builder=lambda model, base_url=None, **kwargs: OpenAIClient(
|
||||||
model,
|
model,
|
||||||
base_url,
|
base_url,
|
||||||
provider=kwargs.pop("provider", "openai"),
|
provider="openai",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
),
|
),
|
||||||
|
base_url_patterns=(r"api\.openai\.com",),
|
||||||
|
),
|
||||||
|
ProviderSpec(
|
||||||
|
canonical_name="ollama",
|
||||||
|
aliases=("ollama",),
|
||||||
|
builder=lambda model, base_url=None, **kwargs: OpenAIClient(
|
||||||
|
model,
|
||||||
|
base_url,
|
||||||
|
provider="ollama",
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
|
base_url_patterns=(r"localhost:\d+", r"127\.0\.0\.1:\d+", r"ollama"),
|
||||||
|
),
|
||||||
|
ProviderSpec(
|
||||||
|
canonical_name="openrouter",
|
||||||
|
aliases=("openrouter",),
|
||||||
|
builder=lambda model, base_url=None, **kwargs: OpenAIClient(
|
||||||
|
model,
|
||||||
|
base_url,
|
||||||
|
provider="openrouter",
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
|
base_url_patterns=(r"openrouter\.ai",),
|
||||||
),
|
),
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
canonical_name="xai",
|
canonical_name="xai",
|
||||||
|
|
@ -42,16 +74,19 @@ _PROVIDER_SPECS: tuple[ProviderSpec, ...] = (
|
||||||
provider="xai",
|
provider="xai",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
),
|
),
|
||||||
|
base_url_patterns=(r"api\.x\.ai",),
|
||||||
),
|
),
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
canonical_name="anthropic",
|
canonical_name="anthropic",
|
||||||
aliases=("anthropic",),
|
aliases=("anthropic",),
|
||||||
builder=lambda model, base_url=None, **kwargs: AnthropicClient(model, base_url, **kwargs),
|
builder=lambda model, base_url=None, **kwargs: AnthropicClient(model, base_url, **kwargs),
|
||||||
|
base_url_patterns=(r"api\.anthropic\.com", r"api\.minimaxi\.com/anthropic"),
|
||||||
),
|
),
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
canonical_name="google",
|
canonical_name="google",
|
||||||
aliases=("google",),
|
aliases=("google",),
|
||||||
builder=lambda model, base_url=None, **kwargs: GoogleClient(model, base_url, **kwargs),
|
builder=lambda model, base_url=None, **kwargs: GoogleClient(model, base_url, **kwargs),
|
||||||
|
base_url_patterns=(r"generativelanguage\.googleapis\.com",),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -92,7 +127,49 @@ def create_llm_client(
|
||||||
"""
|
"""
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
provider_spec = get_provider_spec(provider_lower)
|
provider_spec = get_provider_spec(provider_lower)
|
||||||
builder_kwargs = dict(kwargs)
|
return provider_spec.builder(model, base_url, **kwargs)
|
||||||
if provider_lower in ("openai", "ollama", "openrouter"):
|
|
||||||
builder_kwargs["provider"] = provider_lower
|
|
||||||
return provider_spec.builder(model, base_url, **builder_kwargs)
|
def validate_provider_base_url(provider: str, base_url: str) -> Optional[dict]:
|
||||||
|
"""Validate provider × base_url compatibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: LLM provider name (original, not canonical)
|
||||||
|
base_url: API endpoint URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None if valid, or dict with mismatch details if invalid:
|
||||||
|
{
|
||||||
|
"provider": str,
|
||||||
|
"backend_url": str,
|
||||||
|
"expected_patterns": tuple[str, ...]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if not provider or not base_url:
|
||||||
|
return None
|
||||||
|
|
||||||
|
provider_lower = provider.lower()
|
||||||
|
base_url_lower = base_url.lower()
|
||||||
|
|
||||||
|
try:
|
||||||
|
spec = get_provider_spec(provider_lower)
|
||||||
|
except ValueError:
|
||||||
|
# Unknown provider - no validation rules
|
||||||
|
return None
|
||||||
|
|
||||||
|
if spec.base_url_patterns is None:
|
||||||
|
# No validation rules defined for this provider
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Compile and test patterns
|
||||||
|
for pattern_str in spec.base_url_patterns:
|
||||||
|
pattern = re.compile(pattern_str)
|
||||||
|
if pattern.search(base_url_lower):
|
||||||
|
return None # Match found
|
||||||
|
|
||||||
|
# No pattern matched - return mismatch details
|
||||||
|
return {
|
||||||
|
"provider": provider_lower,
|
||||||
|
"backend_url": base_url,
|
||||||
|
"expected_patterns": spec.base_url_patterns,
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue