update
This commit is contained in:
parent
894b7bcba0
commit
a84c69e42f
|
|
@ -1,106 +1,11 @@
|
|||
"""OpenAI-compatible LLM clients (OpenAI, xAI, OpenRouter, Ollama).
|
||||
|
||||
OpenAI auth resolution order:
|
||||
1. ``api_key`` kwarg (explicit key always wins)
|
||||
2. ``OPENAI_API_KEY`` environment variable
|
||||
3. ``~/.codex/auth.json`` — OpenAI Codex CLI OAuth token (auto-refreshed)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from .base_client import BaseLLMClient, normalize_content
|
||||
from .validators import validate_model
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Codex OAuth token reader (inlined from auth/codex_token.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CODEX_AUTH_FILE = Path.home() / ".codex" / "auth.json"
|
||||
_CODEX_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
_EXPIRY_BUFFER_SECS = 60
|
||||
|
||||
|
||||
def _load_codex_auth() -> Optional[dict]:
|
||||
if not _CODEX_AUTH_FILE.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(_CODEX_AUTH_FILE.read_text())
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def _codex_token_expired(auth: dict) -> bool:
|
||||
expires = auth.get("expires_at") or auth.get("tokens", {}).get("expires_at")
|
||||
if expires is None:
|
||||
try:
|
||||
import base64
|
||||
token = auth["tokens"]["access_token"]
|
||||
payload = token.split(".")[1]
|
||||
decoded = json.loads(base64.b64decode(payload + "=="))
|
||||
expires = decoded.get("exp")
|
||||
except Exception:
|
||||
return False
|
||||
return time.time() >= (expires - _EXPIRY_BUFFER_SECS)
|
||||
|
||||
|
||||
def _refresh_codex_token(auth: dict) -> dict:
|
||||
refresh_token = auth["tokens"]["refresh_token"]
|
||||
resp = requests.post(
|
||||
_CODEX_TOKEN_URL,
|
||||
json={"grant_type": "refresh_token", "refresh_token": refresh_token},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=15,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
new_tokens = resp.json()
|
||||
auth["tokens"].update({
|
||||
"access_token": new_tokens["access_token"],
|
||||
"refresh_token": new_tokens.get("refresh_token", refresh_token),
|
||||
"expires_at": new_tokens.get("expires_in") and int(time.time()) + int(new_tokens["expires_in"]),
|
||||
})
|
||||
auth["last_refresh"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
_CODEX_AUTH_FILE.write_text(json.dumps(auth, indent=2))
|
||||
return auth
|
||||
|
||||
|
||||
def _get_codex_token() -> Optional[str]:
|
||||
"""Return a valid OpenAI token from OPENAI_API_KEY or ~/.codex/auth.json."""
|
||||
explicit = os.environ.get("OPENAI_API_KEY")
|
||||
if explicit:
|
||||
return explicit
|
||||
auth = _load_codex_auth()
|
||||
if not auth or "tokens" not in auth:
|
||||
return None
|
||||
if _codex_token_expired(auth):
|
||||
try:
|
||||
auth = _refresh_codex_token(auth)
|
||||
except Exception:
|
||||
pass
|
||||
return auth["tokens"].get("access_token")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI-compatible client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PASSTHROUGH_KWARGS = (
|
||||
"timeout", "max_retries", "reasoning_effort",
|
||||
"api_key", "callbacks", "http_client", "http_async_client",
|
||||
)
|
||||
|
||||
_PROVIDER_CONFIG = {
|
||||
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
||||
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
||||
"ollama": ("http://localhost:11434/v1", None),
|
||||
}
|
||||
|
||||
|
||||
class NormalizedChatOpenAI(ChatOpenAI):
|
||||
"""ChatOpenAI with normalized content output.
|
||||
|
|
@ -113,9 +18,22 @@ class NormalizedChatOpenAI(ChatOpenAI):
|
|||
def invoke(self, input, config=None, **kwargs):
|
||||
return normalize_content(super().invoke(input, config, **kwargs))
|
||||
|
||||
# Kwargs forwarded from user config to ChatOpenAI
|
||||
_PASSTHROUGH_KWARGS = (
|
||||
"timeout", "max_retries", "reasoning_effort",
|
||||
"api_key", "callbacks", "http_client", "http_async_client",
|
||||
)
|
||||
|
||||
# Provider base URLs and API key env vars
|
||||
_PROVIDER_CONFIG = {
|
||||
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
||||
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
||||
"ollama": ("http://localhost:11434/v1", None),
|
||||
}
|
||||
|
||||
|
||||
class OpenAIClient(BaseLLMClient):
|
||||
"""Client for OpenAI, xAI, OpenRouter, and Ollama providers.
|
||||
"""Client for OpenAI, Ollama, OpenRouter, and xAI providers.
|
||||
|
||||
For native OpenAI models, uses the Responses API (/v1/responses) which
|
||||
supports reasoning_effort with function tools across all model families
|
||||
|
|
@ -137,6 +55,7 @@ class OpenAIClient(BaseLLMClient):
|
|||
"""Return configured ChatOpenAI instance."""
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
# Provider-specific base URL and auth
|
||||
if self.provider in _PROVIDER_CONFIG:
|
||||
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
|
||||
llm_kwargs["base_url"] = base_url
|
||||
|
|
@ -149,19 +68,18 @@ class OpenAIClient(BaseLLMClient):
|
|||
elif self.base_url:
|
||||
llm_kwargs["base_url"] = self.base_url
|
||||
|
||||
# Forward user-provided kwargs
|
||||
for key in _PASSTHROUGH_KWARGS:
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
# Native OpenAI: use Responses API for consistent behavior across
|
||||
# all model families. Third-party providers use Chat Completions.
|
||||
if self.provider == "openai":
|
||||
llm_kwargs["use_responses_api"] = True
|
||||
if "api_key" not in llm_kwargs:
|
||||
token = _get_codex_token()
|
||||
if token:
|
||||
llm_kwargs["api_key"] = token
|
||||
|
||||
return NormalizedChatOpenAI(**llm_kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
return validate_model(self.provider, self.model)
|
||||
|
||||
"""Validate model for the provider."""
|
||||
return validate_model(self.provider, self.model)
|
||||
Loading…
Reference in New Issue