refactor: move Copilot logic into standalone copilot_client.py
- Add tradingagents/llm_clients/copilot_client.py with all Copilot auth (gh CLI token, GraphQL URL resolution, required headers) and the CopilotClient class inline — no separate auth module needed.
- Simplify openai_client.py: remove Copilot code; inline the Codex OAuth token logic directly (was tradingagents/auth/codex_token.py).
- Remove the tradingagents/auth/ folder entirely.
- Update factory.py to route 'copilot' -> CopilotClient.
- Simplify cli/utils.py to delegate to copilot_client helpers.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
parent
24e97fb703
commit
888fdfbfb9
|
|
@ -217,3 +217,6 @@ __marimo__/
|
|||
|
||||
# Cache
|
||||
**/data_cache/
|
||||
|
||||
# Research Results
|
||||
results/*
|
||||
83
cli/utils.py
83
cli/utils.py
|
|
@ -289,43 +289,13 @@ def fetch_copilot_models() -> list[tuple[str, str]]:
|
|||
Returns a list of (display_label, model_id) tuples sorted by model ID.
|
||||
Requires authentication via ``gh auth login`` with a Copilot subscription.
|
||||
"""
|
||||
import requests
|
||||
from tradingagents.auth import get_github_token, COPILOT_HEADERS, get_copilot_api_url
|
||||
from tradingagents.llm_clients.copilot_client import list_copilot_models
|
||||
|
||||
token = get_github_token()
|
||||
if not token:
|
||||
console.print("[red]No GitHub token available. Run `gh auth login` first.[/red]")
|
||||
return []
|
||||
|
||||
try:
|
||||
console.print("[dim]Fetching available Copilot models...[/dim]")
|
||||
copilot_url = get_copilot_api_url()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
**COPILOT_HEADERS,
|
||||
}
|
||||
resp = requests.get(
|
||||
f"{copilot_url}/models",
|
||||
headers=headers,
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
models = data.get("data", data) if isinstance(data, dict) else data
|
||||
# Filter to chat-capable models (exclude embeddings)
|
||||
chat_models = [
|
||||
m for m in models
|
||||
if not m.get("id", "").startswith("text-embedding")
|
||||
]
|
||||
|
||||
return [
|
||||
(m["id"], m["id"])
|
||||
for m in sorted(chat_models, key=lambda x: x.get("id", ""))
|
||||
]
|
||||
except Exception as e:
|
||||
console.print(f"[yellow]Warning: Could not fetch Copilot models: {e}[/yellow]")
|
||||
return []
|
||||
console.print("[dim]Fetching available Copilot models...[/dim]")
|
||||
models = list_copilot_models()
|
||||
if not models:
|
||||
console.print("[yellow]Warning: Could not fetch Copilot models.[/yellow]")
|
||||
return models
|
||||
|
||||
|
||||
def select_llm_provider() -> tuple[str, str]:
|
||||
|
|
@ -374,42 +344,24 @@ def perform_copilot_oauth() -> bool:
|
|||
|
||||
Returns True if a valid token with Copilot access is available, False otherwise.
|
||||
"""
|
||||
from tradingagents.auth import get_github_token
|
||||
from tradingagents.llm_clients.copilot_client import check_copilot_auth, _get_github_token
|
||||
|
||||
token = get_github_token()
|
||||
token = _get_github_token()
|
||||
if token:
|
||||
# Verify Copilot access
|
||||
import requests
|
||||
from tradingagents.auth import COPILOT_HEADERS, get_copilot_api_url
|
||||
try:
|
||||
copilot_url = get_copilot_api_url()
|
||||
resp = requests.get(
|
||||
f"{copilot_url}/models",
|
||||
headers={"Authorization": f"Bearer {token}", **COPILOT_HEADERS},
|
||||
timeout=5,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
console.print("[green]✓ Authenticated with GitHub Copilot[/green]")
|
||||
return True
|
||||
else:
|
||||
console.print(
|
||||
f"[yellow]⚠ GitHub token found but Copilot access failed "
|
||||
f"(HTTP {resp.status_code}). Check your Copilot subscription.[/yellow]"
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
# Network error — accept the token optimistically
|
||||
console.print("[green]✓ Authenticated with GitHub CLI (Copilot access not verified)[/green]")
|
||||
if check_copilot_auth():
|
||||
console.print("[green]✓ Authenticated with GitHub Copilot[/green]")
|
||||
return True
|
||||
console.print(
|
||||
"[yellow]⚠ GitHub token found but Copilot access failed. "
|
||||
"Check your Copilot subscription.[/yellow]"
|
||||
)
|
||||
return False
|
||||
|
||||
console.print(
|
||||
"[yellow]⚠ No GitHub token found.[/yellow] "
|
||||
"You need to authenticate to use GitHub Copilot."
|
||||
)
|
||||
should_login = questionary.confirm(
|
||||
"Run `gh auth login` now?", default=True
|
||||
).ask()
|
||||
|
||||
should_login = questionary.confirm("Run `gh auth login` now?", default=True).ask()
|
||||
if not should_login:
|
||||
console.print("[red]GitHub authentication skipped. Exiting...[/red]")
|
||||
return False
|
||||
|
|
@ -419,8 +371,7 @@ def perform_copilot_oauth() -> bool:
|
|||
console.print("[red]`gh auth login` failed.[/red]")
|
||||
return False
|
||||
|
||||
token = get_github_token()
|
||||
if token:
|
||||
if _get_github_token():
|
||||
console.print("[green]✓ GitHub authentication successful![/green]")
|
||||
return True
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +0,0 @@
|
|||
from .codex_token import get_codex_token
|
||||
from .github_token import get_github_token, get_copilot_api_url, COPILOT_HEADERS
|
||||
|
||||
__all__ = ["get_codex_token", "get_github_token", "get_copilot_api_url", "COPILOT_HEADERS"]
|
||||
|
|
@ -1,109 +0,0 @@
|
|||
"""OpenAI Codex OAuth token reader with auto-refresh.
|
||||
|
||||
Reads credentials stored by the OpenAI Codex CLI at ~/.codex/auth.json.
|
||||
Checks expiry and refreshes automatically via the OpenAI token endpoint
|
||||
before returning a valid access token — the same pattern OpenClaw uses
|
||||
with its auth-profiles.json token sink.
|
||||
|
||||
Token refresh invalidates the previous refresh token, so only one tool
|
||||
should hold the Codex credentials at a time (same caveat as OpenClaw).
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
_AUTH_FILE = Path.home() / ".codex" / "auth.json"
|
||||
_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
# Refresh this many seconds before actual expiry to avoid edge-case failures.
|
||||
_EXPIRY_BUFFER_SECS = 60
|
||||
|
||||
|
||||
def _load_auth() -> Optional[dict]:
|
||||
"""Load the Codex auth file, return None if missing or malformed."""
|
||||
if not _AUTH_FILE.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(_AUTH_FILE.read_text())
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def _save_auth(data: dict) -> None:
|
||||
_AUTH_FILE.write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
def _is_expired(auth: dict) -> bool:
|
||||
"""Return True if the access token is expired (or close to expiring)."""
|
||||
expires = auth.get("expires_at") or auth.get("tokens", {}).get("expires_at")
|
||||
if expires is None:
|
||||
# Fall back to decoding the JWT exp claim.
|
||||
try:
|
||||
import base64
|
||||
token = auth["tokens"]["access_token"]
|
||||
payload = token.split(".")[1]
|
||||
decoded = json.loads(base64.b64decode(payload + "=="))
|
||||
expires = decoded.get("exp")
|
||||
except Exception:
|
||||
return False # Can't determine — assume valid.
|
||||
return time.time() >= (expires - _EXPIRY_BUFFER_SECS)
|
||||
|
||||
|
||||
def _refresh(auth: dict) -> dict:
|
||||
"""Exchange the refresh token for a new token pair and persist it."""
|
||||
refresh_token = auth["tokens"]["refresh_token"]
|
||||
resp = requests.post(
|
||||
_TOKEN_URL,
|
||||
json={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=15,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
new_tokens = resp.json()
|
||||
|
||||
# Merge new tokens back into the auth structure and persist.
|
||||
auth["tokens"].update({
|
||||
"access_token": new_tokens["access_token"],
|
||||
"refresh_token": new_tokens.get("refresh_token", refresh_token),
|
||||
"expires_at": new_tokens.get("expires_in") and
|
||||
int(time.time()) + int(new_tokens["expires_in"]),
|
||||
})
|
||||
auth["last_refresh"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
_save_auth(auth)
|
||||
return auth
|
||||
|
||||
|
||||
def get_codex_token() -> Optional[str]:
|
||||
"""Return a valid OpenAI access token from the Codex CLI auth file.
|
||||
|
||||
Resolution order:
|
||||
1. OPENAI_API_KEY environment variable (explicit key always wins)
|
||||
2. ~/.codex/auth.json — auto-refreshes if the access token is expired
|
||||
|
||||
Returns None if no credentials are found.
|
||||
"""
|
||||
import os
|
||||
explicit = os.environ.get("OPENAI_API_KEY")
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
auth = _load_auth()
|
||||
if not auth or "tokens" not in auth:
|
||||
return None
|
||||
|
||||
# Refresh if expired.
|
||||
if _is_expired(auth):
|
||||
try:
|
||||
auth = _refresh(auth)
|
||||
except Exception:
|
||||
# Refresh failed — return whatever token we have and let the
|
||||
# API call surface a clearer error.
|
||||
pass
|
||||
|
||||
return auth["tokens"].get("access_token")
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
"""GitHub token retrieval for the GitHub Copilot API.
|
||||
|
||||
Uses the ``gh`` CLI exclusively — no explicit API token or env var.
|
||||
Run ``gh auth login`` once to authenticate; this module handles the rest.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_github_token() -> Optional[str]:
|
||||
"""Return a GitHub token obtained via the GitHub CLI (``gh auth token``).
|
||||
|
||||
Returns None if the CLI is unavailable or the user is not logged in.
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["gh", "auth", "token"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
return result.stdout.strip()
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def get_copilot_api_url() -> str:
|
||||
"""Resolve the Copilot inference base URL.
|
||||
|
||||
Queries the GitHub GraphQL API for the user's Copilot endpoints.
|
||||
Falls back to the standard individual endpoint on failure.
|
||||
"""
|
||||
import requests
|
||||
|
||||
token = get_github_token()
|
||||
if not token:
|
||||
return "https://api.individual.githubcopilot.com"
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://api.github.com/graphql",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"query": "{ viewer { copilotEndpoints { api } } }"},
|
||||
timeout=5,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
api = resp.json()["data"]["viewer"]["copilotEndpoints"]["api"]
|
||||
if api:
|
||||
return api.rstrip("/")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "https://api.individual.githubcopilot.com"
|
||||
|
||||
|
||||
# Required headers for the Copilot inference API (reverse-engineered from the
|
||||
# Copilot CLI at /usr/local/lib/node_modules/@github/copilot).
|
||||
COPILOT_HEADERS = {
|
||||
"Copilot-Integration-Id": "copilot-developer-cli",
|
||||
"X-GitHub-Api-Version": "2025-05-01",
|
||||
"Openai-Intent": "conversation-agent",
|
||||
}
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
"""GitHub Copilot LLM client.
|
||||
|
||||
Authenticates via the ``gh`` CLI (``gh auth token``) and calls the Copilot
|
||||
inference API (api.individual.githubcopilot.com) using headers reverse-
|
||||
engineered from the Copilot CLI (copilot-developer-cli integration ID).
|
||||
|
||||
No env var or separate auth module needed — run ``gh auth login`` once.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from .base_client import BaseLLMClient, normalize_content
|
||||
from .validators import validate_model
|
||||
|
||||
# Required headers for the Copilot inference API (reverse-engineered from
# /usr/local/lib/node_modules/@github/copilot).
_COPILOT_HEADERS = {
    "Copilot-Integration-Id": "copilot-developer-cli",
    "X-GitHub-Api-Version": "2025-05-01",
    "Openai-Intent": "conversation-agent",
}

# Models that only support /responses, not /chat/completions on the Copilot endpoint.
_RESPONSES_ONLY_MODELS = frozenset((
    "gpt-5.4", "gpt-5.4-mini",
    "gpt-5.3-codex", "gpt-5.2-codex",
    "gpt-5.1-codex", "gpt-5.1-codex-mini", "gpt-5.1-codex-max",
))

# Keyword args forwarded verbatim from user config to ChatOpenAI
# (see CopilotClient.get_llm).
_PASSTHROUGH_KWARGS = (
    "timeout", "max_retries", "reasoning_effort",
    "api_key", "callbacks", "http_client", "http_async_client",
)
|
||||
|
||||
|
||||
def _get_github_token() -> Optional[str]:
|
||||
"""Return a GitHub token via the ``gh`` CLI."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["gh", "auth", "token"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
return result.stdout.strip()
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _get_copilot_api_url() -> str:
    """Resolve the Copilot inference base URL.

    Asks the GitHub GraphQL API for the viewer's Copilot endpoint and
    falls back to the standard individual endpoint when there is no
    token, the request fails, or the response is malformed.
    """
    fallback = "https://api.individual.githubcopilot.com"
    token = _get_github_token()
    if not token:
        return fallback
    try:
        resp = requests.post(
            "https://api.github.com/graphql",
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json={"query": "{ viewer { copilotEndpoints { api } } }"},
            timeout=5,
        )
        if resp.status_code == 200:
            api = resp.json()["data"]["viewer"]["copilotEndpoints"]["api"]
            if api:
                return api.rstrip("/")
    except Exception:
        # Best-effort endpoint discovery — any failure uses the fallback.
        pass
    return fallback
|
||||
|
||||
|
||||
def list_copilot_models() -> list[tuple[str, str]]:
    """Fetch available Copilot models from the inference API.

    Returns ``(display_label, model_id)`` tuples sorted by model ID
    (label and ID are currently the same string); embedding models are
    filtered out. Returns ``[]`` when no token is available or the
    request fails — callers surface their own warning.

    Requires ``gh auth login`` with an active Copilot subscription.
    """
    token = _get_github_token()
    if not token:
        return []
    try:
        resp = requests.get(
            f"{_get_copilot_api_url()}/models",
            headers={"Authorization": f"Bearer {token}", **_COPILOT_HEADERS},
            timeout=10,
        )
        resp.raise_for_status()
        payload = resp.json()
        # The endpoint may return either {"data": [...]} or a bare list.
        models = payload.get("data", payload) if isinstance(payload, dict) else payload
        chat_models = (
            m for m in models if not m.get("id", "").startswith("text-embedding")
        )
        ordered = sorted(chat_models, key=lambda m: m.get("id", ""))
        return [(m["id"], m["id"]) for m in ordered]
    except Exception:
        return []
|
||||
|
||||
|
||||
def check_copilot_auth() -> bool:
    """Return True if a GitHub token with Copilot access is available.

    Probes ``/models`` on the Copilot endpoint. A network failure is
    treated as success so offline use is not blocked; only an explicit
    non-200 response reports failure.
    """
    token = _get_github_token()
    if not token:
        return False
    headers = {"Authorization": f"Bearer {token}", **_COPILOT_HEADERS}
    try:
        resp = requests.get(
            f"{_get_copilot_api_url()}/models",
            headers=headers,
            timeout=5,
        )
    except Exception:
        return True  # Network error — accept optimistically
    return resp.status_code == 200
|
||||
|
||||
|
||||
class NormalizedChatOpenAI(ChatOpenAI):
    """ChatOpenAI whose ``invoke`` result has normalized content."""

    def invoke(self, input, config=None, **kwargs):
        message = super().invoke(input, config, **kwargs)
        return normalize_content(message)
|
||||
|
||||
|
||||
class CopilotClient(BaseLLMClient):
    """Client for the GitHub Copilot inference API.

    Authentication comes from the ``gh`` CLI token. Models that only
    support the Responses API (gpt-5.4 and the codex variants) are
    routed to ``/responses`` instead of ``/chat/completions``.
    """

    def get_llm(self) -> Any:
        """Return a configured ChatOpenAI instance pointed at the Copilot API."""
        token = _get_github_token()
        endpoint = _get_copilot_api_url()

        # "copilot" is a harmless placeholder: ChatOpenAI insists on a
        # non-empty api_key even though the real auth is the Bearer token.
        llm_kwargs = {
            "model": self.model,
            "base_url": endpoint,
            "api_key": token or "copilot",
            "default_headers": dict(_COPILOT_HEADERS),
        }

        # User-supplied kwargs (e.g. an explicit api_key) override the
        # auto-resolved values above.
        llm_kwargs.update(
            (key, self.kwargs[key])
            for key in _PASSTHROUGH_KWARGS
            if key in self.kwargs
        )

        if self.model in _RESPONSES_ONLY_MODELS:
            llm_kwargs["use_responses_api"] = True

        return NormalizedChatOpenAI(**llm_kwargs)

    def validate_model(self) -> bool:
        """Validate ``self.model`` against the copilot provider's model list."""
        return validate_model("copilot", self.model)
|
||||
|
|
@ -2,6 +2,7 @@ from typing import Optional
|
|||
|
||||
from .base_client import BaseLLMClient
|
||||
from .openai_client import OpenAIClient
|
||||
from .copilot_client import CopilotClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
from .google_client import GoogleClient
|
||||
|
||||
|
|
@ -34,7 +35,10 @@ def create_llm_client(
|
|||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower in ("openai", "ollama", "openrouter", "copilot"):
|
||||
if provider_lower == "copilot":
|
||||
return CopilotClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower in ("openai", "ollama", "openrouter"):
|
||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||
|
||||
if provider_lower == "xai":
|
||||
|
|
|
|||
|
|
@ -1,11 +1,105 @@
|
|||
"""OpenAI-compatible LLM clients (OpenAI, xAI, OpenRouter, Ollama).
|
||||
|
||||
OpenAI auth resolution order:
|
||||
1. ``api_key`` kwarg (explicit key always wins)
|
||||
2. ``OPENAI_API_KEY`` environment variable
|
||||
3. ``~/.codex/auth.json`` — OpenAI Codex CLI OAuth token (auto-refreshed)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from .base_client import BaseLLMClient, normalize_content
|
||||
from .validators import validate_model
|
||||
from ..auth import get_codex_token, get_github_token, get_copilot_api_url, COPILOT_HEADERS
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Codex OAuth token reader (inlined from auth/codex_token.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CODEX_AUTH_FILE = Path.home() / ".codex" / "auth.json"
|
||||
_CODEX_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
_EXPIRY_BUFFER_SECS = 60
|
||||
|
||||
|
||||
def _load_codex_auth() -> Optional[dict]:
|
||||
if not _CODEX_AUTH_FILE.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(_CODEX_AUTH_FILE.read_text())
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def _codex_token_expired(auth: dict) -> bool:
|
||||
expires = auth.get("expires_at") or auth.get("tokens", {}).get("expires_at")
|
||||
if expires is None:
|
||||
try:
|
||||
import base64
|
||||
token = auth["tokens"]["access_token"]
|
||||
payload = token.split(".")[1]
|
||||
decoded = json.loads(base64.b64decode(payload + "=="))
|
||||
expires = decoded.get("exp")
|
||||
except Exception:
|
||||
return False
|
||||
return time.time() >= (expires - _EXPIRY_BUFFER_SECS)
|
||||
|
||||
|
||||
def _refresh_codex_token(auth: dict) -> dict:
    """Exchange the refresh token for a new token pair and persist it.

    Mutates *auth* in place, writes it back to ~/.codex/auth.json, and
    returns it. Raises on HTTP failure. Note: refreshing invalidates
    the previous refresh token, so only one tool should hold it.
    """
    old_refresh = auth["tokens"]["refresh_token"]
    resp = requests.post(
        _CODEX_TOKEN_URL,
        json={"grant_type": "refresh_token", "refresh_token": old_refresh},
        headers={"Content-Type": "application/json"},
        timeout=15,
    )
    resp.raise_for_status()
    fresh = resp.json()

    # Merge the new token pair back into the auth structure and persist.
    expires_at = fresh.get("expires_in") and int(time.time()) + int(fresh["expires_in"])
    auth["tokens"].update({
        "access_token": fresh["access_token"],
        "refresh_token": fresh.get("refresh_token", old_refresh),
        "expires_at": expires_at,
    })
    auth["last_refresh"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    _CODEX_AUTH_FILE.write_text(json.dumps(auth, indent=2))
    return auth
|
||||
|
||||
|
||||
def _get_codex_token() -> Optional[str]:
    """Return a valid OpenAI token from OPENAI_API_KEY or ~/.codex/auth.json.

    Resolution order:
      1. ``OPENAI_API_KEY`` environment variable (explicit key wins).
      2. The Codex CLI auth file, auto-refreshing an expired token.

    Returns None when neither source yields credentials.
    """
    env_key = os.environ.get("OPENAI_API_KEY")
    if env_key:
        return env_key

    auth = _load_codex_auth()
    if auth is None or "tokens" not in auth:
        return None

    if _codex_token_expired(auth):
        try:
            auth = _refresh_codex_token(auth)
        except Exception:
            # Refresh failed — return the stale token and let the API
            # call surface a clearer error.
            pass

    return auth["tokens"].get("access_token")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI-compatible client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PASSTHROUGH_KWARGS = (
|
||||
"timeout", "max_retries", "reasoning_effort",
|
||||
"api_key", "callbacks", "http_client", "http_async_client",
|
||||
)
|
||||
|
||||
_PROVIDER_CONFIG = {
|
||||
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
||||
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
||||
"ollama": ("http://localhost:11434/v1", None),
|
||||
}
|
||||
|
||||
|
||||
class NormalizedChatOpenAI(ChatOpenAI):
|
||||
|
|
@ -19,44 +113,14 @@ class NormalizedChatOpenAI(ChatOpenAI):
|
|||
def invoke(self, input, config=None, **kwargs):
|
||||
return normalize_content(super().invoke(input, config, **kwargs))
|
||||
|
||||
# Kwargs forwarded from user config to ChatOpenAI
|
||||
_PASSTHROUGH_KWARGS = (
|
||||
"timeout", "max_retries", "reasoning_effort",
|
||||
"api_key", "callbacks", "http_client", "http_async_client",
|
||||
)
|
||||
|
||||
# Provider base URLs and API key env vars.
|
||||
# Copilot: uses the GitHub Copilot inference API, authenticated via ``gh``
|
||||
# CLI token with Copilot-specific headers. No env var needed.
|
||||
_PROVIDER_CONFIG = {
|
||||
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
||||
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
||||
"ollama": ("http://localhost:11434/v1", None),
|
||||
"copilot": (None, None), # base_url resolved at runtime via GraphQL
|
||||
}
|
||||
|
||||
|
||||
# Models that only support the Responses API on the Copilot endpoint.
|
||||
_COPILOT_RESPONSES_ONLY = frozenset((
|
||||
"gpt-5.4", "gpt-5.4-mini",
|
||||
"gpt-5.3-codex", "gpt-5.2-codex",
|
||||
"gpt-5.1-codex", "gpt-5.1-codex-mini", "gpt-5.1-codex-max",
|
||||
))
|
||||
|
||||
|
||||
def _copilot_needs_responses_api(model: str) -> bool:
|
||||
"""Return True if the model requires /responses instead of /chat/completions."""
|
||||
return model in _COPILOT_RESPONSES_ONLY
|
||||
|
||||
|
||||
class OpenAIClient(BaseLLMClient):
|
||||
"""Client for OpenAI, Ollama, OpenRouter, xAI, and GitHub Copilot providers.
|
||||
"""Client for OpenAI, xAI, OpenRouter, and Ollama providers.
|
||||
|
||||
For native OpenAI models, uses the Responses API (/v1/responses) which
|
||||
supports reasoning_effort with function tools across all model families
|
||||
(GPT-4.1, GPT-5). Third-party compatible providers (xAI, OpenRouter,
|
||||
Ollama) use standard Chat Completions. GitHub Copilot uses the Copilot
|
||||
inference API with special headers.
|
||||
Ollama) use standard Chat Completions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -73,19 +137,9 @@ class OpenAIClient(BaseLLMClient):
|
|||
"""Return configured ChatOpenAI instance."""
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
# Provider-specific base URL and auth
|
||||
if self.provider == "copilot":
|
||||
# GitHub Copilot: resolve base URL and inject required headers
|
||||
copilot_url = get_copilot_api_url()
|
||||
llm_kwargs["base_url"] = copilot_url
|
||||
token = get_github_token()
|
||||
if token:
|
||||
llm_kwargs["api_key"] = token
|
||||
llm_kwargs["default_headers"] = dict(COPILOT_HEADERS)
|
||||
elif self.provider in _PROVIDER_CONFIG:
|
||||
if self.provider in _PROVIDER_CONFIG:
|
||||
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
|
||||
if base_url:
|
||||
llm_kwargs["base_url"] = base_url
|
||||
llm_kwargs["base_url"] = base_url
|
||||
if api_key_env:
|
||||
api_key = os.environ.get(api_key_env)
|
||||
if api_key:
|
||||
|
|
@ -95,28 +149,19 @@ class OpenAIClient(BaseLLMClient):
|
|||
elif self.base_url:
|
||||
llm_kwargs["base_url"] = self.base_url
|
||||
|
||||
# Forward user-provided kwargs (takes precedence over auto-resolved tokens)
|
||||
for key in _PASSTHROUGH_KWARGS:
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
# Native OpenAI: use Responses API for consistent behavior across
|
||||
# all model families. Third-party providers use Chat Completions.
|
||||
if self.provider == "openai":
|
||||
llm_kwargs["use_responses_api"] = True
|
||||
# If no explicit api_key in kwargs, fall back to Codex OAuth token.
|
||||
if "api_key" not in llm_kwargs:
|
||||
codex_token = get_codex_token()
|
||||
if codex_token:
|
||||
llm_kwargs["api_key"] = codex_token
|
||||
|
||||
# Copilot: newer models (gpt-5.4, codex variants) only support the
|
||||
# Responses API (/responses), not Chat Completions (/chat/completions).
|
||||
if self.provider == "copilot" and _copilot_needs_responses_api(self.model):
|
||||
llm_kwargs["use_responses_api"] = True
|
||||
token = _get_codex_token()
|
||||
if token:
|
||||
llm_kwargs["api_key"] = token
|
||||
|
||||
return NormalizedChatOpenAI(**llm_kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
"""Validate model for the provider."""
|
||||
return validate_model(self.provider, self.model)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue