update
This commit is contained in:
parent
894b7bcba0
commit
a84c69e42f
|
|
@ -1,106 +1,11 @@
|
||||||
"""OpenAI-compatible LLM clients (OpenAI, xAI, OpenRouter, Ollama).
|
|
||||||
|
|
||||||
OpenAI auth resolution order:
|
|
||||||
1. ``api_key`` kwarg (explicit key always wins)
|
|
||||||
2. ``OPENAI_API_KEY`` environment variable
|
|
||||||
3. ``~/.codex/auth.json`` — OpenAI Codex CLI OAuth token (auto-refreshed)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import requests
|
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
from .base_client import BaseLLMClient, normalize_content
|
from .base_client import BaseLLMClient, normalize_content
|
||||||
from .validators import validate_model
|
from .validators import validate_model
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Codex OAuth token reader (inlined from auth/codex_token.py)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Location of the Codex CLI OAuth credential file.
_CODEX_AUTH_FILE = Path.home() / ".codex" / "auth.json"
# OAuth endpoint used to exchange a refresh token for a new access token.
_CODEX_TOKEN_URL = "https://auth.openai.com/oauth/token"
# Treat tokens as expired this many seconds early to avoid using a token
# that lapses mid-request.
_EXPIRY_BUFFER_SECS = 60
|
|
||||||
|
|
||||||
|
|
||||||
def _load_codex_auth() -> Optional[dict]:
|
|
||||||
if not _CODEX_AUTH_FILE.exists():
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return json.loads(_CODEX_AUTH_FILE.read_text())
|
|
||||||
except (json.JSONDecodeError, OSError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _codex_token_expired(auth: dict) -> bool:
|
|
||||||
expires = auth.get("expires_at") or auth.get("tokens", {}).get("expires_at")
|
|
||||||
if expires is None:
|
|
||||||
try:
|
|
||||||
import base64
|
|
||||||
token = auth["tokens"]["access_token"]
|
|
||||||
payload = token.split(".")[1]
|
|
||||||
decoded = json.loads(base64.b64decode(payload + "=="))
|
|
||||||
expires = decoded.get("exp")
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
return time.time() >= (expires - _EXPIRY_BUFFER_SECS)
|
|
||||||
|
|
||||||
|
|
||||||
def _refresh_codex_token(auth: dict) -> dict:
    """Exchange the refresh token for new credentials and persist them.

    Performs the OAuth refresh-token grant against ``_CODEX_TOKEN_URL``,
    updates ``auth["tokens"]`` in place, stamps ``last_refresh``, writes the
    result back to ``~/.codex/auth.json``, and returns the mutated dict.
    Raises ``requests.HTTPError`` on a non-2xx response.
    """
    current_refresh = auth["tokens"]["refresh_token"]
    response = requests.post(
        _CODEX_TOKEN_URL,
        json={"grant_type": "refresh_token", "refresh_token": current_refresh},
        headers={"Content-Type": "application/json"},
        timeout=15,
    )
    response.raise_for_status()
    payload = response.json()

    tokens = auth["tokens"]
    tokens["access_token"] = payload["access_token"]
    # Some providers rotate the refresh token; keep the old one otherwise.
    tokens["refresh_token"] = payload.get("refresh_token", current_refresh)
    # Convert relative "expires_in" to an absolute epoch timestamp
    # (falsy/missing expires_in is stored as-is, matching prior behavior).
    expires_in = payload.get("expires_in")
    tokens["expires_at"] = expires_in and int(time.time()) + int(expires_in)

    auth["last_refresh"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    _CODEX_AUTH_FILE.write_text(json.dumps(auth, indent=2))
    return auth
|
|
||||||
|
|
||||||
|
|
||||||
def _get_codex_token() -> Optional[str]:
|
|
||||||
"""Return a valid OpenAI token from OPENAI_API_KEY or ~/.codex/auth.json."""
|
|
||||||
explicit = os.environ.get("OPENAI_API_KEY")
|
|
||||||
if explicit:
|
|
||||||
return explicit
|
|
||||||
auth = _load_codex_auth()
|
|
||||||
if not auth or "tokens" not in auth:
|
|
||||||
return None
|
|
||||||
if _codex_token_expired(auth):
|
|
||||||
try:
|
|
||||||
auth = _refresh_codex_token(auth)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return auth["tokens"].get("access_token")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# OpenAI-compatible client
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Kwargs forwarded from user config to ChatOpenAI
_PASSTHROUGH_KWARGS = (
    "timeout", "max_retries", "reasoning_effort",
    "api_key", "callbacks", "http_client", "http_async_client",
)

# Provider base URLs and API key env vars (None = no key required)
_PROVIDER_CONFIG = {
    "xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
    "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
    "ollama": ("http://localhost:11434/v1", None),
}
|
|
||||||
|
|
||||||
|
|
||||||
class NormalizedChatOpenAI(ChatOpenAI):
|
class NormalizedChatOpenAI(ChatOpenAI):
|
||||||
"""ChatOpenAI with normalized content output.
|
"""ChatOpenAI with normalized content output.
|
||||||
|
|
@ -113,9 +18,22 @@ class NormalizedChatOpenAI(ChatOpenAI):
|
||||||
def invoke(self, input, config=None, **kwargs):
|
def invoke(self, input, config=None, **kwargs):
|
||||||
return normalize_content(super().invoke(input, config, **kwargs))
|
return normalize_content(super().invoke(input, config, **kwargs))
|
||||||
|
|
||||||
|
# Kwargs forwarded from user config to ChatOpenAI
|
||||||
|
_PASSTHROUGH_KWARGS = (
|
||||||
|
"timeout", "max_retries", "reasoning_effort",
|
||||||
|
"api_key", "callbacks", "http_client", "http_async_client",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Provider base URLs and API key env vars
|
||||||
|
_PROVIDER_CONFIG = {
|
||||||
|
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
||||||
|
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
||||||
|
"ollama": ("http://localhost:11434/v1", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class OpenAIClient(BaseLLMClient):
|
class OpenAIClient(BaseLLMClient):
|
||||||
"""Client for OpenAI, xAI, OpenRouter, and Ollama providers.
|
"""Client for OpenAI, Ollama, OpenRouter, and xAI providers.
|
||||||
|
|
||||||
For native OpenAI models, uses the Responses API (/v1/responses) which
|
For native OpenAI models, uses the Responses API (/v1/responses) which
|
||||||
supports reasoning_effort with function tools across all model families
|
supports reasoning_effort with function tools across all model families
|
||||||
|
|
@ -137,6 +55,7 @@ class OpenAIClient(BaseLLMClient):
|
||||||
"""Return configured ChatOpenAI instance."""
|
"""Return configured ChatOpenAI instance."""
|
||||||
llm_kwargs = {"model": self.model}
|
llm_kwargs = {"model": self.model}
|
||||||
|
|
||||||
|
# Provider-specific base URL and auth
|
||||||
if self.provider in _PROVIDER_CONFIG:
|
if self.provider in _PROVIDER_CONFIG:
|
||||||
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
|
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
|
||||||
llm_kwargs["base_url"] = base_url
|
llm_kwargs["base_url"] = base_url
|
||||||
|
|
@ -149,19 +68,18 @@ class OpenAIClient(BaseLLMClient):
|
||||||
elif self.base_url:
|
elif self.base_url:
|
||||||
llm_kwargs["base_url"] = self.base_url
|
llm_kwargs["base_url"] = self.base_url
|
||||||
|
|
||||||
|
# Forward user-provided kwargs
|
||||||
for key in _PASSTHROUGH_KWARGS:
|
for key in _PASSTHROUGH_KWARGS:
|
||||||
if key in self.kwargs:
|
if key in self.kwargs:
|
||||||
llm_kwargs[key] = self.kwargs[key]
|
llm_kwargs[key] = self.kwargs[key]
|
||||||
|
|
||||||
|
# Native OpenAI: use Responses API for consistent behavior across
|
||||||
|
# all model families. Third-party providers use Chat Completions.
|
||||||
if self.provider == "openai":
|
if self.provider == "openai":
|
||||||
llm_kwargs["use_responses_api"] = True
|
llm_kwargs["use_responses_api"] = True
|
||||||
if "api_key" not in llm_kwargs:
|
|
||||||
token = _get_codex_token()
|
|
||||||
if token:
|
|
||||||
llm_kwargs["api_key"] = token
|
|
||||||
|
|
||||||
return NormalizedChatOpenAI(**llm_kwargs)
|
return NormalizedChatOpenAI(**llm_kwargs)
|
||||||
|
|
||||||
    def validate_model(self) -> bool:
        """Validate model for the provider.

        Delegates to the shared ``validators.validate_model`` helper with
        this client's ``provider`` and ``model``.
        """
        return validate_model(self.provider, self.model)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue