91 lines
3.1 KiB
Python
91 lines
3.1 KiB
Python
import os
|
||
from typing import Any, Optional
|
||
|
||
from langchain_openai import ChatOpenAI
|
||
|
||
from .base_client import BaseLLMClient
|
||
from .validators import validate_model
|
||
|
||
|
||
class UnifiedChatOpenAI(ChatOpenAI):
    """ChatOpenAI variant that drops sampling params unsupported by reasoning models."""

    def __init__(self, **kwargs):
        # Reasoning models reject sampling controls such as temperature/top_p,
        # so strip them before delegating to the parent constructor.
        if self._is_reasoning_model(kwargs.get("model", "")):
            for unsupported in ("temperature", "top_p"):
                kwargs.pop(unsupported, None)
        super().__init__(**kwargs)

    @staticmethod
    def _is_reasoning_model(model: str) -> bool:
        """Return True when *model* names a reasoning model (no temperature support)."""
        name = model.lower()
        return name.startswith(("o1", "o3")) or "gpt-5" in name
|
||
|
||
|
||
class OpenAIClient(BaseLLMClient):
    """Client for OpenAI-compatible providers.

    Supported providers:
    - openai → OpenAI platform
    - ollama → Local Ollama server (no auth)
    - openrouter → OpenRouter API
    - xai → xAI / Grok API
    - ark → ByteDance Ark (OpenAI-compatible API)
    """

    # Per-provider configuration: (default base_url, API-key env var, literal
    # fallback api_key). A missing provider entry means "OpenAI platform" —
    # the library's own defaults apply.
    _PROVIDER_DEFAULTS = {
        "xai": ("https://api.x.ai/v1", "XAI_API_KEY", None),
        "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY", None),
        # Ollama doesn't require auth, but the OpenAI client insists on a key.
        "ollama": ("http://localhost:11434/v1", None, "ollama"),
        "ark": ("https://ark.ap-southeast.bytepluses.com/api/v3", "ARK_API_KEY", None),
    }

    def __init__(
        self,
        model: str,
        base_url: Optional[str] = None,
        provider: str = "openai",
        **kwargs,
    ):
        """Initialize the client.

        Args:
            model: Model identifier to request from the provider.
            base_url: Optional endpoint override; takes precedence over the
                provider's default URL.
            provider: Provider name (case-insensitive); see class docstring.
            **kwargs: Extra options forwarded to BaseLLMClient and later
                consumed by get_llm (timeout, max_retries, api_key, ...).
        """
        super().__init__(model, base_url, **kwargs)
        self.provider = provider.lower()

    def get_llm(self) -> Any:
        """Return a configured UnifiedChatOpenAI instance for this provider.

        Resolution order for each setting:
        - base_url: explicit ``self.base_url`` wins, then the provider default.
          (Bug fix: previously only the ``ark`` branch honored an explicit
          override; xai/openrouter/ollama silently ignored it.)
        - api_key: provider env var if set, else the provider's literal
          fallback (ollama), else whatever ``self.kwargs["api_key"]`` supplies.
        """
        llm_kwargs: dict = {"model": self.model}

        defaults = self._PROVIDER_DEFAULTS.get(self.provider)
        if defaults is not None:
            default_url, env_var, fallback_key = defaults
            llm_kwargs["base_url"] = self.base_url or default_url
            api_key = os.environ.get(env_var) if env_var else None
            if api_key:
                llm_kwargs["api_key"] = api_key
            elif fallback_key:
                llm_kwargs["api_key"] = fallback_key
        elif self.base_url:
            # Unknown/openai provider: only set base_url when explicitly given.
            llm_kwargs["base_url"] = self.base_url

        # Pass through recognized tuning options. An explicit "api_key" here
        # intentionally overrides any environment-derived key above.
        for key in ("timeout", "max_retries", "reasoning_effort", "api_key", "callbacks"):
            if key in self.kwargs:
                llm_kwargs[key] = self.kwargs[key]

        return UnifiedChatOpenAI(**llm_kwargs)

    def validate_model(self) -> bool:
        """Validate that ``self.model`` is acceptable for ``self.provider``."""
        return validate_model(self.provider, self.model)
|