"""Registry of LLM providers: client factories and base-URL validation rules."""
from dataclasses import dataclass
|
||
from typing import Callable, Optional, TypedDict
|
||
import re
|
||
|
||
from .base_client import BaseLLMClient
|
||
from .openai_client import OpenAIClient
|
||
from .anthropic_client import AnthropicClient
|
||
from .google_client import GoogleClient
|
||
from .azure_client import AzureOpenAIClient
|
||
|
||
# Providers that use the OpenAI-compatible chat completions API
# NOTE(review): "deepseek", "qwen", and "glm" are listed here but have no
# entry in _PROVIDER_SPECS — presumably registered elsewhere or pending;
# confirm before relying on this tuple for dispatch.
_OPENAI_COMPATIBLE = (
    "openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter",
)

# Compiled pattern cache for validation performance.
# Populated lazily by validate_provider_base_url, keyed by canonical provider name.
_COMPILED_PATTERNS: dict[str, list[re.Pattern]] = {}
||
|
||
|
||
# Details returned when a provider's configured base URL fails validation:
# the (lowercased) provider name, the offending URL, and the regex patterns
# that a valid URL for this provider would have matched.
ProviderMismatch = TypedDict(
    "ProviderMismatch",
    {
        "provider": str,
        "backend_url": str,
        "expected_patterns": tuple[str, ...],
    },
)
||
|
||
@dataclass(frozen=True)
class ProviderSpec:
    """Provider registry entry for LLM client creation.

    Immutable (frozen) so entries can safely be shared module-wide.

    Attributes:
        canonical_name: Primary provider identifier
        aliases: Alternative names that resolve to this provider
        builder: Factory function to create the client instance
        base_url_patterns: Regex patterns for valid base URLs (None = no validation)
    """

    # Primary identifier; also used as the compiled-pattern cache key.
    canonical_name: str
    # Lowercase names matched by get_provider_spec (includes the canonical name).
    aliases: tuple[str, ...]
    # Called as builder(model, base_url, **kwargs) -> BaseLLMClient.
    builder: Callable[..., BaseLLMClient]
    # Regex fragments searched (not full-matched) against the lowercased URL.
    base_url_patterns: Optional[tuple[str, ...]] = None
|
||
|
||
def _openai_compatible_builder(provider: str) -> Callable[..., BaseLLMClient]:
    """Return a builder that creates an OpenAIClient pinned to *provider*.

    Replaces four duplicated per-provider lambdas; *provider* is captured as a
    function argument (not a free closure variable), so each builder is bound
    to its own provider name.
    """

    def _build(model, base_url=None, **kwargs):
        return OpenAIClient(model, base_url, provider=provider, **kwargs)

    return _build


# Registry of all known providers. Order matters only for the order returned
# by get_supported_providers(); lookup is by alias membership.
_PROVIDER_SPECS: tuple[ProviderSpec, ...] = (
    ProviderSpec(
        canonical_name="openai",
        aliases=("openai",),
        builder=_openai_compatible_builder("openai"),
        base_url_patterns=(r"api\.openai\.com",),
    ),
    ProviderSpec(
        canonical_name="ollama",
        aliases=("ollama",),
        builder=_openai_compatible_builder("ollama"),
        # Local deployments: any localhost/loopback port, or a URL naming ollama.
        base_url_patterns=(r"localhost:\d+", r"127\.0\.0\.1:\d+", r"ollama"),
    ),
    ProviderSpec(
        canonical_name="openrouter",
        aliases=("openrouter",),
        builder=_openai_compatible_builder("openrouter"),
        base_url_patterns=(r"openrouter\.ai",),
    ),
    ProviderSpec(
        canonical_name="xai",
        aliases=("xai",),
        builder=_openai_compatible_builder("xai"),
        base_url_patterns=(r"api\.x\.ai",),
    ),
    ProviderSpec(
        canonical_name="anthropic",
        aliases=("anthropic",),
        builder=lambda model, base_url=None, **kwargs: AnthropicClient(model, base_url, **kwargs),
        # MiniMax exposes an Anthropic-compatible endpoint, hence the second pattern.
        base_url_patterns=(r"api\.anthropic\.com", r"api\.minimaxi\.com/anthropic"),
    ),
    ProviderSpec(
        canonical_name="google",
        aliases=("google",),
        builder=lambda model, base_url=None, **kwargs: GoogleClient(model, base_url, **kwargs),
        base_url_patterns=(r"generativelanguage\.googleapis\.com",),
    ),
)
||
|
||
|
||
def get_provider_spec(provider: str) -> ProviderSpec:
    """Look up the registry entry whose aliases contain *provider* (case-insensitive).

    Raises:
        ValueError: If no registered provider matches.
    """
    needle = provider.lower()
    spec = next((s for s in _PROVIDER_SPECS if needle in s.aliases), None)
    if spec is None:
        raise ValueError(f"Unsupported LLM provider: {provider}")
    return spec
|
||
|
||
|
||
def get_supported_providers() -> tuple[str, ...]:
    """List the canonical names of every registered provider, in registry order."""
    names = [entry.canonical_name for entry in _PROVIDER_SPECS]
    return tuple(names)
||
|
||
|
||
def create_llm_client(
    provider: str,
    model: str,
    base_url: Optional[str] = None,
    **kwargs,
) -> BaseLLMClient:
    """Build a configured LLM client for *provider*.

    Args:
        provider: LLM provider name (case-insensitive; aliases accepted)
        model: Model name/identifier
        base_url: Optional base URL for API endpoint
        **kwargs: Additional provider-specific arguments

    Returns:
        Configured BaseLLMClient instance

    Raises:
        ValueError: If provider is not supported
    """
    # get_provider_spec lowercases internally, so no normalization is needed here.
    spec = get_provider_spec(provider)
    return spec.builder(model, base_url, **kwargs)
||
|
||
|
||
def validate_provider_base_url(provider: str, base_url: str) -> Optional[ProviderMismatch]:
    """Check that *base_url* looks right for *provider*.

    Args:
        provider: LLM provider name (original, not canonical)
        base_url: API endpoint URL

    Returns:
        None if valid, unknown, or unvalidatable; a ProviderMismatch dict otherwise.
    """
    # Nothing to validate without both inputs.
    if not (provider and base_url):
        return None

    provider_key = provider.lower()

    try:
        spec = get_provider_spec(provider_key)
    except ValueError:
        # Unknown provider - no validation rules
        return None

    expected = spec.base_url_patterns
    if expected is None:
        # No validation rules defined for this provider
        return None

    # Compile each provider's patterns at most once and reuse them.
    compiled = _COMPILED_PATTERNS.get(spec.canonical_name)
    if compiled is None:
        compiled = [re.compile(p) for p in expected]
        _COMPILED_PATTERNS[spec.canonical_name] = compiled

    # Patterns are searched (substring semantics) against the lowercased URL.
    url = base_url.lower()
    if any(pattern.search(url) for pattern in compiled):
        return None  # Match found

    # No pattern matched - return mismatch details
    return {
        "provider": provider_key,
        "backend_url": base_url,
        "expected_patterns": expected,
    }
|