import os
|
|
import time
|
|
import logging
|
|
from typing import Any, Optional
|
|
|
|
from langchain_openai import ChatOpenAI
|
|
|
|
from .base_client import BaseLLMClient, normalize_content
|
|
from .validators import validate_model
|
|
|
|
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Retry policy for the transient "null value for 'choices'" TypeError that
# surfaces when the API responds without a choices payload (handled in
# NormalizedChatOpenAI.invoke below).
_NULL_CHOICES_RETRIES = 3  # total attempts, including the first call
_NULL_CHOICES_DELAY = 2  # seconds to wait between attempts
|
|
|
|
|
|
class NormalizedChatOpenAI(ChatOpenAI):
    """ChatOpenAI with normalized content output.

    The Responses API returns content as a list of typed blocks
    (reasoning, text, etc.). This normalizes to string for consistent
    downstream handling.
    """

    def invoke(self, input, config=None, **kwargs):
        """Invoke the model, retrying transient null-choices failures.

        Retries up to ``_NULL_CHOICES_RETRIES`` times when the underlying
        call raises the "null value for 'choices'" TypeError; any other
        TypeError propagates unchanged.
        """
        final_attempt = _NULL_CHOICES_RETRIES
        for attempt in range(1, final_attempt + 1):
            try:
                message = super().invoke(input, config, **kwargs)
            except TypeError as exc:
                # Only the null-choices failure mode is retryable; anything
                # else is a genuine programming error.
                if "null value for 'choices'" not in str(exc):
                    raise
                if attempt == final_attempt:
                    raise RuntimeError(
                        "API returned null choices after retries. "
                        "The request may have been blocked by content moderation. "
                        "Try rephrasing the prompt or check the provider's content policy."
                    ) from exc
                logger.warning(
                    "Received null choices from API (content filter or transient error). "
                    "Retrying in %ds (attempt %d/%d)...",
                    _NULL_CHOICES_DELAY,
                    attempt,
                    final_attempt,
                )
                time.sleep(_NULL_CHOICES_DELAY)
            else:
                return normalize_content(message)
|
|
|
|
# Kwargs forwarded verbatim from user config to ChatOpenAI; applied after
# provider defaults, so a user-supplied value (e.g. api_key) takes precedence.
_PASSTHROUGH_KWARGS = (
    "timeout", "max_retries", "reasoning_effort",
    "api_key", "callbacks", "http_client", "http_async_client",
)
|
|
|
|
# Provider base URLs and API key env vars.
# Maps provider name -> (base_url, api_key_env_var); a None env var means no
# key lookup is performed (local providers such as ollama get a placeholder).
_PROVIDER_CONFIG = {
    "xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
    "deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
    "qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
    "glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
    "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
    "ollama": ("http://localhost:11434/v1", None),
    "minimax": ("https://api.minimaxi.chat/v1", "MINIMAX_API_KEY"),
}
|
|
|
|
|
|
class OpenAIClient(BaseLLMClient):
    """Client for OpenAI, Ollama, OpenRouter, and xAI providers.

    For native OpenAI models, uses the Responses API (/v1/responses) which
    supports reasoning_effort with function tools across all model families
    (GPT-4.1, GPT-5). Third-party compatible providers (xAI, OpenRouter,
    Ollama) use standard Chat Completions.
    """

    def __init__(
        self,
        model: str,
        base_url: Optional[str] = None,
        provider: str = "openai",
        **kwargs,
    ):
        super().__init__(model, base_url, **kwargs)
        # Lowercase once so provider comparisons are case-insensitive.
        self.provider = provider.lower()

    def get_llm(self) -> Any:
        """Return configured ChatOpenAI instance."""
        self.warn_if_unknown_model()
        llm_kwargs = {"model": self.model}

        # Provider-specific base URL and auth.
        provider_entry = _PROVIDER_CONFIG.get(self.provider)
        if provider_entry is not None:
            endpoint, key_env = provider_entry
            llm_kwargs["base_url"] = endpoint
            if key_env is None:
                # Local provider (ollama): the SDK requires some key string.
                llm_kwargs["api_key"] = "ollama"
            else:
                env_value = os.environ.get(key_env)
                if env_value:
                    llm_kwargs["api_key"] = env_value
        elif self.base_url:
            llm_kwargs["base_url"] = self.base_url

        # Forward user-provided kwargs; these override provider defaults.
        llm_kwargs.update(
            (name, self.kwargs[name])
            for name in _PASSTHROUGH_KWARGS
            if name in self.kwargs
        )

        # Native OpenAI: use Responses API for consistent behavior across
        # all model families. Third-party providers use Chat Completions.
        if self.provider == "openai":
            llm_kwargs["use_responses_api"] = True

        return NormalizedChatOpenAI(**llm_kwargs)

    def validate_model(self) -> bool:
        """Validate model for the provider."""
        return validate_model(self.provider, self.model)
|