TradingAgents/tradingagents/llm_clients/factory.py

184 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dataclasses import dataclass
from typing import Callable, Optional, TypedDict
import re
from .base_client import BaseLLMClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
from .google_client import GoogleClient
from .azure_client import AzureOpenAIClient
# Provider names that speak the OpenAI-compatible chat completions API.
_OPENAI_COMPATIBLE = (
    "openai",
    "xai",
    "deepseek",
    "qwen",
    "glm",
    "ollama",
    "openrouter",
)

# Per-provider cache of compiled base-URL regexes, keyed by canonical provider
# name and filled lazily by validate_provider_base_url.
_COMPILED_PATTERNS: dict[str, list[re.Pattern]] = {}
class ProviderMismatch(TypedDict):
    """Details reported when a base URL does not fit the selected provider."""

    provider: str  # provider name as passed in, lower-cased
    backend_url: str  # the base URL that was checked, unmodified
    expected_patterns: tuple[str, ...]  # regex sources, any one of which would have matched
@dataclass(frozen=True)
class ProviderSpec:
    """One immutable entry in the registry used to build LLM clients.

    Attributes:
        canonical_name: Primary identifier for the provider; this is the name
            reported by ``get_supported_providers``.
        aliases: Names accepted by ``get_provider_spec`` for this provider.
            The lookup lower-cases the requested name before comparing, so
            aliases should be lower-case. Every registered entry also lists
            its canonical name here.
        builder: Callable invoked as ``builder(model, base_url, **kwargs)``
            that returns a configured client.
        base_url_patterns: Regular expressions searched for in the lower-cased
            base URL during validation; ``None`` disables validation.
    """

    canonical_name: str
    aliases: tuple[str, ...]
    builder: Callable[..., BaseLLMClient]
    # Defaulted field, so it must remain last: dataclass fields with defaults
    # cannot precede fields without them.
    base_url_patterns: Optional[tuple[str, ...]] = None
def _openai_compatible_builder(provider_name: str) -> Callable[..., BaseLLMClient]:
    """Return a builder that creates an ``OpenAIClient`` tagged with *provider_name*.

    All providers served through the OpenAI-compatible chat completions API
    share this construction; only the ``provider`` keyword differs.
    """

    def _build(model, base_url=None, **kwargs) -> BaseLLMClient:
        return OpenAIClient(model, base_url, provider=provider_name, **kwargs)

    return _build


# Registry of supported providers. get_provider_spec scans it in order, and
# get_supported_providers reports canonical names in this order.
# The OpenAI-compatible entries mirror _OPENAI_COMPATIBLE; keep the two in sync.
_PROVIDER_SPECS: tuple[ProviderSpec, ...] = (
    ProviderSpec(
        canonical_name="openai",
        aliases=("openai",),
        builder=_openai_compatible_builder("openai"),
        base_url_patterns=(r"api\.openai\.com",),
    ),
    ProviderSpec(
        canonical_name="ollama",
        aliases=("ollama",),
        builder=_openai_compatible_builder("ollama"),
        # Local servers (any port on localhost / 127.0.0.1) or a host/path
        # containing "ollama".
        base_url_patterns=(r"localhost:\d+", r"127\.0\.0\.1:\d+", r"ollama"),
    ),
    ProviderSpec(
        canonical_name="openrouter",
        aliases=("openrouter",),
        builder=_openai_compatible_builder("openrouter"),
        base_url_patterns=(r"openrouter\.ai",),
    ),
    ProviderSpec(
        canonical_name="xai",
        aliases=("xai",),
        builder=_openai_compatible_builder("xai"),
        base_url_patterns=(r"api\.x\.ai",),
    ),
    ProviderSpec(
        canonical_name="anthropic",
        aliases=("anthropic",),
        builder=lambda model, base_url=None, **kwargs: AnthropicClient(model, base_url, **kwargs),
        base_url_patterns=(r"api\.anthropic\.com", r"api\.minimaxi\.com/anthropic"),
    ),
    ProviderSpec(
        canonical_name="google",
        aliases=("google",),
        builder=lambda model, base_url=None, **kwargs: GoogleClient(model, base_url, **kwargs),
        base_url_patterns=(r"generativelanguage\.googleapis\.com",),
    ),
    # The following providers are declared OpenAI-compatible in
    # _OPENAI_COMPATIBLE but had no registry entry, so requesting them raised
    # "Unsupported LLM provider". No URL patterns are defined for them, so
    # validate_provider_base_url skips them (base_url_patterns=None).
    ProviderSpec(
        canonical_name="deepseek",
        aliases=("deepseek",),
        builder=_openai_compatible_builder("deepseek"),
    ),
    ProviderSpec(
        canonical_name="qwen",
        aliases=("qwen",),
        builder=_openai_compatible_builder("qwen"),
    ),
    ProviderSpec(
        canonical_name="glm",
        aliases=("glm",),
        builder=_openai_compatible_builder("glm"),
    ),
    ProviderSpec(
        canonical_name="azure",
        aliases=("azure",),
        # AzureOpenAIClient is imported by this module but was never registered.
        # NOTE(review): assumes it takes the same (model, base_url, **kwargs)
        # shape as the other non-OpenAI clients — confirm against azure_client.
        builder=lambda model, base_url=None, **kwargs: AzureOpenAIClient(model, base_url, **kwargs),
    ),
)
def get_provider_spec(provider: str) -> ProviderSpec:
    """Resolve a provider name or alias to its registry entry.

    Matching is case-insensitive: *provider* is lower-cased and compared, in
    registry order, against each spec's canonical name and its aliases.

    Args:
        provider: Provider name or alias, in any letter case.

    Returns:
        The first ``ProviderSpec`` whose canonical name or aliases match.

    Raises:
        ValueError: If no registered provider matches. The message echoes
            *provider* exactly as the caller passed it.
    """
    key = provider.lower()
    for spec in _PROVIDER_SPECS:
        # Check the canonical name explicitly so a spec stays reachable by its
        # primary identifier even if its aliases tuple omits it.
        if key == spec.canonical_name or key in spec.aliases:
            return spec
    raise ValueError(f"Unsupported LLM provider: {provider}")
def get_supported_providers() -> tuple[str, ...]:
    """List the canonical provider names known to the registry.

    Returns:
        Canonical names in registry order; aliases are not included.
    """
    names = [entry.canonical_name for entry in _PROVIDER_SPECS]
    return tuple(names)
def create_llm_client(
    provider: str,
    model: str,
    base_url: Optional[str] = None,
    **kwargs,
) -> BaseLLMClient:
    """Create an LLM client for the specified provider.

    Args:
        provider: Provider name or alias; matched case-insensitively.
        model: Model name/identifier passed to the client builder.
        base_url: Optional base URL for the API endpoint, forwarded to the
            builder unchanged (including ``None``).
        **kwargs: Additional provider-specific arguments forwarded to the
            builder.

    Returns:
        Configured BaseLLMClient instance produced by the provider's
        registry entry.

    Raises:
        ValueError: If provider is not supported.
    """
    # get_provider_spec lower-cases internally. Passing the caller's original
    # spelling keeps it intact in the "Unsupported LLM provider" message.
    spec = get_provider_spec(provider)
    return spec.builder(model, base_url, **kwargs)
def validate_provider_base_url(provider: str, base_url: str) -> Optional[ProviderMismatch]:
    """Check whether *base_url* looks like an endpoint for *provider*.

    Args:
        provider: Provider name or alias as supplied by the caller.
        base_url: API endpoint URL to check.

    Returns:
        ``None`` in any of these cases:

        - either argument is empty;
        - the provider is unknown;
        - the provider defines no URL patterns;
        - any of its patterns is found in the lower-cased URL.

        Otherwise, a ``ProviderMismatch`` dict with these keys:

        - ``provider``: the lower-cased input name (not the canonical name);
        - ``backend_url``: the URL as given;
        - ``expected_patterns``: the spec's raw pattern strings.
    """
    # Nothing to compare when either side is missing.
    if not provider or not base_url:
        return None

    normalized_provider = provider.lower()
    haystack = base_url.lower()

    try:
        spec = get_provider_spec(normalized_provider)
    except ValueError:
        # Unrecognised providers carry no rules, so they always pass.
        return None

    raw_patterns = spec.base_url_patterns
    if raw_patterns is None:
        # This provider opts out of URL validation.
        return None

    # Compile each provider's patterns on first use and reuse them afterwards.
    cache_key = spec.canonical_name
    if cache_key not in _COMPILED_PATTERNS:
        _COMPILED_PATTERNS[cache_key] = [re.compile(p) for p in raw_patterns]
    compiled = _COMPILED_PATTERNS[cache_key]

    if any(rx.search(haystack) for rx in compiled):
        return None

    return {
        "provider": normalized_provider,
        "backend_url": base_url,
        "expected_patterns": raw_patterns,
    }