44 lines
1.3 KiB
Python
44 lines
1.3 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Any, Optional
|
|
import warnings
|
|
|
|
|
|
class BaseLLMClient(ABC):
    """Abstract base class for LLM clients.

    Holds the model identifier, an optional base URL and any extra keyword
    arguments, and offers a shared helper that warns about unknown models.
    Concrete subclasses must implement :meth:`get_llm` and
    :meth:`validate_model`.
    """

    def __init__(
        self,
        model: str,
        base_url: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Store the client configuration on the instance.

        Args:
            model: Model identifier; also interpolated into the warning
                emitted by :meth:`warn_if_unknown_model`.
            base_url: Optional base URL; ``None`` when not supplied.
            **kwargs: Extra keyword arguments, kept unchanged in
                ``self.kwargs``.
        """
        self.model = model
        self.base_url = base_url
        self.kwargs = kwargs

    def get_provider_name(self) -> str:
        """Return the provider name used in warning messages.

        A truthy ``provider`` attribute on the instance takes precedence and
        is returned as a string. Otherwise the name is derived from the class
        name: a trailing ``"Client"`` is removed and the rest is lower-cased.

        Returns:
            A non-empty provider label.
        """
        provider = getattr(self, "provider", None)
        if provider:
            return str(provider)

        class_name = self.__class__.__name__
        # A class named exactly "Client" would strip down to an empty string
        # and produce "provider ''" in the warning; keep the full class name
        # in that case so the label is never blank.
        stem = class_name.removesuffix("Client") or class_name
        return stem.lower()

    def warn_if_unknown_model(self) -> None:
        """Warn when the model is outside the known list for the provider.

        Does nothing when :meth:`validate_model` returns a truthy value.
        Otherwise emits a ``RuntimeWarning`` naming the model and provider;
        no exception is raised by this method itself.
        """
        if self.validate_model():
            return

        warnings.warn(
            (
                f"Model '{self.model}' is not in the known model list for "
                f"provider '{self.get_provider_name()}'. Continuing anyway."
            ),
            RuntimeWarning,
            # Attribute the warning to the caller of this method rather than
            # to this helper.
            stacklevel=2,
        )

    @abstractmethod
    def get_llm(self) -> Any:
        """Return the configured LLM instance."""

    @abstractmethod
    def validate_model(self) -> bool:
        """Validate that the model is supported by this client."""
|