This commit is contained in:
Jiaxu Liu 2026-03-23 14:03:27 +00:00
parent 894b7bcba0
commit a84c69e42f
1 changed file with 20 additions and 102 deletions

View File

@ -1,106 +1,11 @@
"""OpenAI-compatible LLM clients (OpenAI, xAI, OpenRouter, Ollama).
OpenAI auth resolution order:
1. ``api_key`` kwarg (explicit key always wins)
2. ``OPENAI_API_KEY`` environment variable
3. ``~/.codex/auth.json`` OpenAI Codex CLI OAuth token (auto-refreshed)
"""
import json
import os
import time
from pathlib import Path
from typing import Any, Optional
import requests
from langchain_openai import ChatOpenAI
from .base_client import BaseLLMClient, normalize_content
from .validators import validate_model
# ---------------------------------------------------------------------------
# Codex OAuth token reader (inlined from auth/codex_token.py)
# ---------------------------------------------------------------------------
_CODEX_AUTH_FILE = Path.home() / ".codex" / "auth.json"
_CODEX_TOKEN_URL = "https://auth.openai.com/oauth/token"
_EXPIRY_BUFFER_SECS = 60
def _load_codex_auth() -> Optional[dict]:
    """Read and parse ``~/.codex/auth.json``; return ``None`` if absent or invalid."""
    # EAFP: attempt the read directly; a missing file surfaces as
    # FileNotFoundError, which OSError covers along with permission errors.
    try:
        raw = _CODEX_AUTH_FILE.read_text()
    except OSError:
        return None
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return None
def _codex_token_expired(auth: dict) -> bool:
    """Return ``True`` when the stored Codex access token is (nearly) expired.

    Prefers an explicit ``expires_at`` timestamp (top-level or under
    ``tokens``); failing that, decodes the ``exp`` claim from the JWT access
    token. When no expiry can be determined the token is optimistically
    treated as still valid (best effort, matching the caller's fallback).

    Args:
        auth: Parsed contents of ``~/.codex/auth.json``.
    """
    expires = auth.get("expires_at") or auth.get("tokens", {}).get("expires_at")
    if expires is None:
        try:
            import base64
            token = auth["tokens"]["access_token"]
            payload = token.split(".")[1]
            # JWT segments are base64url-encoded (RFC 7515).  The previous
            # plain b64decode silently *discarded* '-'/'_' characters
            # (validate=False default), corrupting the payload, and the
            # fixed '==' padding was wrong for len % 4 == 3.  Decode with
            # the urlsafe alphabet and exact padding instead.
            padded = payload + "=" * (-len(payload) % 4)
            decoded = json.loads(base64.urlsafe_b64decode(padded))
            expires = decoded.get("exp")
        except Exception:
            # Opaque/malformed token: assume not expired.
            return False
        if expires is None:
            # Previously this fell through and raised TypeError on
            # ``None - _EXPIRY_BUFFER_SECS``; treat "no exp claim" as valid.
            return False
    return time.time() >= (expires - _EXPIRY_BUFFER_SECS)
def _refresh_codex_token(auth: dict) -> dict:
    """Exchange the stored refresh token for fresh credentials.

    Persists the updated token set back to ``~/.codex/auth.json`` and returns
    the mutated ``auth`` mapping.  Propagates ``requests`` HTTP errors via
    ``raise_for_status``.
    """
    tokens = auth["tokens"]
    old_refresh = tokens["refresh_token"]
    response = requests.post(
        _CODEX_TOKEN_URL,
        json={"grant_type": "refresh_token", "refresh_token": old_refresh},
        headers={"Content-Type": "application/json"},
        timeout=15,
    )
    response.raise_for_status()
    fresh = response.json()
    tokens["access_token"] = fresh["access_token"]
    # The endpoint may rotate the refresh token; keep the old one otherwise.
    tokens["refresh_token"] = fresh.get("refresh_token", old_refresh)
    expires_in = fresh.get("expires_in")
    # Mirror the original `expires_in and ...` short-circuit: a falsy
    # expires_in (None/0) is stored as-is rather than computed.
    if expires_in:
        tokens["expires_at"] = int(time.time()) + int(expires_in)
    else:
        tokens["expires_at"] = expires_in
    auth["last_refresh"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    _CODEX_AUTH_FILE.write_text(json.dumps(auth, indent=2))
    return auth
def _get_codex_token() -> Optional[str]:
    """Return a valid OpenAI token from OPENAI_API_KEY or ~/.codex/auth.json."""
    # Explicit environment variable always wins.
    env_token = os.environ.get("OPENAI_API_KEY")
    if env_token:
        return env_token
    auth = _load_codex_auth()
    if not auth or "tokens" not in auth:
        return None
    if _codex_token_expired(auth):
        # Best effort: if the refresh fails, fall back to the stored
        # (possibly stale) token and let the server decide.
        try:
            auth = _refresh_codex_token(auth)
        except Exception:
            pass
    return auth["tokens"].get("access_token")
# ---------------------------------------------------------------------------
# OpenAI-compatible client
# ---------------------------------------------------------------------------
# User-config kwargs forwarded verbatim to the ChatOpenAI constructor.
_PASSTHROUGH_KWARGS = (
    "timeout", "max_retries", "reasoning_effort",
    "api_key", "callbacks", "http_client", "http_async_client",
)
# provider name -> (base URL, API-key env var); Ollama is local, no key.
_PROVIDER_CONFIG = {
    "xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
    "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
    "ollama": ("http://localhost:11434/v1", None),
}
class NormalizedChatOpenAI(ChatOpenAI):
"""ChatOpenAI with normalized content output.
@ -113,9 +18,22 @@ class NormalizedChatOpenAI(ChatOpenAI):
def invoke(self, input, config=None, **kwargs):
    """Run the underlying ChatOpenAI invoke, then normalize the message content."""
    raw_message = super().invoke(input, config, **kwargs)
    return normalize_content(raw_message)
# Kwargs forwarded from user config to ChatOpenAI
_PASSTHROUGH_KWARGS = (
"timeout", "max_retries", "reasoning_effort",
"api_key", "callbacks", "http_client", "http_async_client",
)
# Provider base URLs and API key env vars
_PROVIDER_CONFIG = {
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
"ollama": ("http://localhost:11434/v1", None),
}
class OpenAIClient(BaseLLMClient):
"""Client for OpenAI, xAI, OpenRouter, and Ollama providers.
"""Client for OpenAI, Ollama, OpenRouter, and xAI providers.
For native OpenAI models, uses the Responses API (/v1/responses) which
supports reasoning_effort with function tools across all model families
@ -137,6 +55,7 @@ class OpenAIClient(BaseLLMClient):
"""Return configured ChatOpenAI instance."""
llm_kwargs = {"model": self.model}
# Provider-specific base URL and auth
if self.provider in _PROVIDER_CONFIG:
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
llm_kwargs["base_url"] = base_url
@ -149,19 +68,18 @@ class OpenAIClient(BaseLLMClient):
elif self.base_url:
llm_kwargs["base_url"] = self.base_url
# Forward user-provided kwargs
for key in _PASSTHROUGH_KWARGS:
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
# Native OpenAI: use Responses API for consistent behavior across
# all model families. Third-party providers use Chat Completions.
if self.provider == "openai":
llm_kwargs["use_responses_api"] = True
if "api_key" not in llm_kwargs:
token = _get_codex_token()
if token:
llm_kwargs["api_key"] = token
return NormalizedChatOpenAI(**llm_kwargs)
def validate_model(self) -> bool:
    """Check that ``self.model`` is a recognized model for ``self.provider``."""
    # Delegate to the shared validator so every client uses one rule set.
    is_valid = validate_model(self.provider, self.model)
    return is_valid