99 lines
3.1 KiB
Python
99 lines
3.1 KiB
Python
from dataclasses import dataclass
|
|
from typing import Callable, Optional
|
|
|
|
from .base_client import BaseLLMClient
|
|
from .openai_client import OpenAIClient
|
|
from .anthropic_client import AnthropicClient
|
|
from .google_client import GoogleClient
|
|
|
|
|
|
@dataclass(frozen=True)
class ProviderSpec:
    """Immutable record describing one entry of the LLM provider registry.

    Attributes:
        canonical_name: Primary name under which the provider is listed.
        aliases: Names that resolve to this entry; ``get_provider_spec``
            compares the lower-cased requested provider against them.
        builder: Factory invoked as ``builder(model, base_url, **kwargs)``
            that returns the client for this provider.
    """

    canonical_name: str
    aliases: tuple[str, ...]
    builder: Callable[..., BaseLLMClient]
|
|
|
|
|
|
def _build_openai_compatible(model, base_url=None, **options):
    """Build an OpenAIClient, using ``provider`` from options or "openai"."""
    # Remove "provider" from the pass-through options before forwarding them
    # so it is supplied exactly once.
    provider_name = options.pop("provider", "openai")
    return OpenAIClient(model, base_url, provider=provider_name, **options)


def _build_xai(model, base_url=None, **options):
    """Build an OpenAIClient pinned to the "xai" provider."""
    return OpenAIClient(model, base_url, provider="xai", **options)


def _build_anthropic(model, base_url=None, **options):
    """Build an AnthropicClient, forwarding all options unchanged."""
    return AnthropicClient(model, base_url, **options)


def _build_google(model, base_url=None, **options):
    """Build a GoogleClient, forwarding all options unchanged."""
    return GoogleClient(model, base_url, **options)


# Registry of known providers; lookup walks it in order and returns the first
# entry whose aliases contain the requested name.
_PROVIDER_SPECS: tuple[ProviderSpec, ...] = (
    ProviderSpec(
        canonical_name="openai",
        aliases=("openai", "ollama", "openrouter"),
        builder=_build_openai_compatible,
    ),
    ProviderSpec(
        canonical_name="xai",
        aliases=("xai",),
        builder=_build_xai,
    ),
    ProviderSpec(
        canonical_name="anthropic",
        aliases=("anthropic",),
        builder=_build_anthropic,
    ),
    ProviderSpec(
        canonical_name="google",
        aliases=("google",),
        builder=_build_google,
    ),
)
|
|
|
|
|
|
def get_provider_spec(provider: str) -> ProviderSpec:
    """Look up the registry entry matching a provider name or alias.

    The lookup is case-insensitive: ``provider`` is lower-cased before being
    compared against each entry's aliases, and the first match wins.

    Args:
        provider: Provider name or alias.

    Returns:
        The first ProviderSpec whose aliases contain the lower-cased name.

    Raises:
        ValueError: If no registry entry lists the name among its aliases.
    """
    wanted = provider.lower()
    match = next(
        (entry for entry in _PROVIDER_SPECS if wanted in entry.aliases),
        None,
    )
    if match is None:
        raise ValueError(f"Unsupported LLM provider: {provider}")
    return match
|
|
|
|
|
|
def get_supported_providers() -> tuple[str, ...]:
    """List the canonical name of every registry entry, in registry order.

    Aliases are not included; only each entry's ``canonical_name``.
    """
    names = [entry.canonical_name for entry in _PROVIDER_SPECS]
    return tuple(names)
|
|
|
|
|
|
def create_llm_client(
    provider: str,
    model: str,
    base_url: Optional[str] = None,
    **kwargs,
) -> BaseLLMClient:
    """Create an LLM client for the specified provider.

    The provider name is matched case-insensitively against the registry.
    For every alias of the OpenAI-compatible entry (currently openai, ollama,
    openrouter) the lower-cased name the caller asked for is forwarded to the
    builder as ``provider``.

    Args:
        provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
        model: Model name/identifier
        base_url: Optional base URL for API endpoint
        **kwargs: Additional provider-specific arguments
            - http_client: Custom httpx.Client for SSL proxy or certificate customization
            - http_async_client: Custom httpx.AsyncClient for async operations
            - timeout: Request timeout in seconds
            - max_retries: Maximum retry attempts
            - api_key: API key for the provider
            - callbacks: LangChain callbacks

    Returns:
        Configured BaseLLMClient instance

    Raises:
        ValueError: If provider is not supported
    """
    provider_lower = provider.lower()
    provider_spec = get_provider_spec(provider_lower)

    # Copy so the caller's kwargs mapping is never mutated.
    builder_kwargs = dict(kwargs)

    # Decide from the resolved registry entry instead of repeating its alias
    # list here: any alias that maps to the OpenAI-compatible entry must tell
    # the client which concrete provider was requested, and a second
    # hard-coded tuple would silently drift when the registry gains an alias.
    if provider_spec.canonical_name == "openai":
        builder_kwargs["provider"] = provider_lower

    return provider_spec.builder(model, base_url, **builder_kwargs)
|