338 lines
12 KiB
Python
338 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import queue
|
|
import subprocess
|
|
import threading
|
|
import uuid
|
|
from collections import deque
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from .codex_binary import codex_binary_error_message, resolve_codex_binary
|
|
|
|
|
|
class CodexAppServerError(RuntimeError):
    """Base error for any failure while talking to the Codex app-server."""
|
|
|
|
|
|
class CodexAppServerAuthError(CodexAppServerError):
    """Codex authentication problem: no login present, or one that cannot be used."""
|
|
|
|
|
|
class CodexAppServerBinaryError(CodexAppServerError):
    """The Codex executable could not be located or launched."""
|
|
|
|
|
|
class CodexStructuredOutputError(CodexAppServerError):
    """Codex finished without producing the requested structured output."""
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class CodexInvocationResult:
|
|
final_text: str
|
|
notifications: list[dict[str, Any]]
|
|
|
|
|
|
class CodexAppServerSession:
    """Minimal JSON-RPC client for `codex app-server` over stdio JSONL.

    The ``codex app-server`` child process is spawned lazily by :meth:`start`.
    Each outgoing message is written to its stdin as one JSON object per line.
    A daemon thread parses stdout lines into ``_stdout_queue`` (posting ``None``
    once stdout ends), and a second daemon thread keeps the most recent 200
    stderr lines for error messages.

    Consumers of the stdout stream (:meth:`request`, :meth:`invoke`) are
    serialized by ``_request_lock``. Notifications that arrive while a request
    waits for its own response are parked in ``_pending`` and handed to the
    next consumer first. Lock order is always ``_request_lock`` then ``_lock``.
    """

    def __init__(
        self,
        *,
        codex_binary: str | None,
        request_timeout: float,
        workspace_dir: str,
        cleanup_threads: bool,
        client_name: str = "tradingagents_codex",
        client_title: str = "TradingAgents Codex Provider",
        client_version: str = "0.2.3",
    ) -> None:
        """Store configuration; no process is started until :meth:`start`.

        Args:
            codex_binary: Binary to run, or ``None``; passed to
                ``resolve_codex_binary`` when the session starts.
            request_timeout: Seconds to wait for each message from the server.
            workspace_dir: Working directory for the app-server process and for
                Codex threads (``~`` is expanded; created on start if missing).
            cleanup_threads: If true, ``thread/unsubscribe`` is sent after each
                :meth:`invoke`.
            client_name: ``clientInfo.name`` sent in the ``initialize`` call.
            client_title: ``clientInfo.title`` sent in the ``initialize`` call.
            client_version: ``clientInfo.version`` sent in the ``initialize`` call.
        """
        self.codex_binary = codex_binary
        self.request_timeout = request_timeout
        self.workspace_dir = str(Path(workspace_dir).expanduser())
        self.cleanup_threads = cleanup_threads
        self.client_name = client_name
        self.client_title = client_title
        self.client_version = client_version

        self._proc: subprocess.Popen[str] | None = None
        # Parsed stdout messages; ``None`` marks end of the process's stdout.
        self._stdout_queue: queue.Queue[dict[str, Any] | None] = queue.Queue()
        # Messages read from the stream but not yet consumed; read before the queue.
        self._pending: deque[dict[str, Any]] = deque()
        self._stderr_lines: deque[str] = deque(maxlen=200)
        # Guards process lifecycle (start/close).
        self._lock = threading.RLock()
        # Serializes consumers of the stdout stream (request/invoke).
        self._request_lock = threading.RLock()
        self._reader_thread: threading.Thread | None = None
        self._stderr_thread: threading.Thread | None = None

    def start(self) -> None:
        """Spawn ``codex app-server`` and run the initialize handshake.

        Does nothing if a process is already running. The workspace directory is
        created if missing, and the resolved binary is stored back into
        ``self.codex_binary``.

        Raises:
            CodexAppServerBinaryError: The binary cannot be resolved or spawned.
            CodexAppServerError: The handshake fails. The half-started process is
                shut down first, so a later ``start()`` retries from scratch.
        """
        # Fast path: don't wait behind an in-flight request if already running.
        with self._lock:
            if self._proc is not None:
                return

        # _initialize() goes through request(), which takes _request_lock; take it
        # first here so the lock order matches invoke() (request lock -> _lock).
        with self._request_lock, self._lock:
            if self._proc is not None:
                return

            Path(self.workspace_dir).mkdir(parents=True, exist_ok=True)

            binary = resolve_codex_binary(self.codex_binary)
            if not binary:
                raise CodexAppServerBinaryError(codex_binary_error_message(self.codex_binary))
            self.codex_binary = binary

            # Fresh stream state per process: a previous process's reader thread
            # may have left stale messages and its EOF sentinel behind.
            self._stdout_queue = queue.Queue()
            self._pending.clear()

            try:
                self._proc = subprocess.Popen(
                    [binary, "app-server", "--listen", "stdio://"],
                    stdin=subprocess.PIPE,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    text=True,
                    encoding="utf-8",
                    cwd=self.workspace_dir,
                    bufsize=1,
                )
            except OSError as exc:
                raise CodexAppServerBinaryError(
                    f"Failed to start Codex app-server with binary '{binary}': {exc}"
                ) from exc

            try:
                self._start_reader_threads()
                self._initialize()
            except BaseException:
                # Otherwise _proc stays set and every later start() returns early,
                # leaving callers with a server that was never initialized.
                self.close()
                raise

    def close(self) -> None:
        """Shut down the app-server process, if any (idempotent, best-effort).

        Closes stdin, asks the process to terminate, and force-kills it if it has
        not exited within 2 seconds.
        """
        with self._lock:
            proc = self._proc
            self._proc = None
            if proc is None:
                return

            try:
                if proc.stdin:
                    proc.stdin.close()
            except OSError:
                pass

            try:
                proc.terminate()
                proc.wait(timeout=2)
            except (subprocess.TimeoutExpired, OSError):
                proc.kill()
                # Reap the killed child so it does not linger as a zombie.
                proc.wait()

    def account_read(self) -> dict[str, Any]:
        """Return the ``account/read`` result (without refreshing the token)."""
        return self.request("account/read", {"refreshToken": False})

    def model_list(self, *, include_hidden: bool = True) -> dict[str, Any]:
        """Return the ``model/list`` result, optionally including hidden models."""
        return self.request("model/list", {"includeHidden": include_hidden})

    def invoke(
        self,
        *,
        prompt: str,
        model: str,
        output_schema: dict[str, Any],
        reasoning_effort: str | None,
        summary: str | None,
        personality: str | None,
    ) -> CodexInvocationResult:
        """Run one prompt as a turn on a fresh ephemeral, read-only thread.

        Starts the session if needed, opens a thread (``approvalPolicy=never``,
        ``sandbox=read-only``), starts a turn with ``output_schema`` as the
        requested output schema, and collects its notifications until the turn
        completes. When ``cleanup_threads`` is set, the thread is unsubscribed
        afterwards, ignoring errors from that cleanup.

        Raises:
            CodexAppServerError: Any request fails or times out, or the turn fails.
            CodexStructuredOutputError: The turn produced no agent message.
        """
        with self._request_lock:
            self.start()
            thread_id = None
            try:
                thread = self.request(
                    "thread/start",
                    {
                        "approvalPolicy": "never",
                        "cwd": self.workspace_dir,
                        "ephemeral": True,
                        "model": model,
                        "personality": personality,
                        "sandbox": "read-only",
                        "serviceName": "tradingagents_codex",
                    },
                )
                thread_id = thread["thread"]["id"]
                started = self.request(
                    "turn/start",
                    {
                        "threadId": thread_id,
                        "input": [{"type": "text", "text": prompt}],
                        "model": model,
                        "effort": reasoning_effort,
                        "summary": summary,
                        "outputSchema": output_schema,
                    },
                )
                turn_id = started["turn"]["id"]
                final_text, notifications = self._collect_turn(turn_id)
                return CodexInvocationResult(final_text=final_text, notifications=notifications)
            finally:
                if thread_id and self.cleanup_threads:
                    try:
                        self.request("thread/unsubscribe", {"threadId": thread_id})
                    except CodexAppServerError:
                        # Cleanup is best-effort; don't mask the original outcome.
                        pass

    def request(self, method: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
        """Send a JSON-RPC request and block until its response arrives.

        Server-initiated requests seen while waiting are answered via
        ``_handle_server_request``. Any other message is parked in ``_pending``
        (in arrival order) for the next consumer, even if this call fails.

        Returns:
            The response's ``result`` object.

        Raises:
            CodexAppServerError: On an error response, a non-object result, a
                write failure, a timeout, or the server closing its stdout.
        """
        with self._request_lock:
            request_id = str(uuid.uuid4())
            self._write({"id": request_id, "method": method, "params": params or {}})
            deferred: list[dict[str, Any]] = []

            try:
                while True:
                    message = self._next_message(self.request_timeout)

                    if message.get("id") == request_id:
                        if "error" in message:
                            error = message["error"] or {}
                            if not isinstance(error, dict):
                                # Tolerate a bare string/number error payload.
                                error = {"message": str(error)}
                            code = error.get("code")
                            text = error.get("message", "unknown Codex app-server error")
                            raise CodexAppServerError(
                                f"{method} failed ({code}): {text}. stderr_tail={self._stderr_tail()}"
                            )
                        result = message.get("result")
                        if not isinstance(result, dict):
                            raise CodexAppServerError(f"{method} returned a non-object result: {result!r}")
                        return result

                    if "method" in message and "id" in message:
                        self._handle_server_request(message)
                        continue

                    deferred.append(message)
            finally:
                # Also on timeout/error, so notifications are not silently lost.
                self._restore_deferred(deferred)

    def _initialize(self) -> None:
        """Perform the ``initialize`` / ``initialized`` handshake.

        Raises:
            CodexAppServerError: The response lacks a ``userAgent``, or the
                request itself fails.
        """
        response = self.request(
            "initialize",
            {
                "clientInfo": {
                    "name": self.client_name,
                    "title": self.client_title,
                    "version": self.client_version,
                }
            },
        )
        if not response.get("userAgent"):
            raise CodexAppServerError("Codex initialize response did not include userAgent.")
        self._write({"method": "initialized", "params": {}})

    def _collect_turn(self, turn_id: str) -> tuple[str, list[dict[str, Any]]]:
        """Consume notifications until ``turn/completed`` arrives for ``turn_id``.

        Returns:
            ``(text, notifications)``. ``text`` is the last ``agentMessage`` for
            this turn whose ``phase`` is ``final_answer``, or else the last other
            ``agentMessage`` for this turn. ``notifications`` holds every
            notification seen, including ones for other turns.

        Raises:
            CodexAppServerError: The turn completed with status ``failed``, or
                reading from the server failed or timed out.
            CodexStructuredOutputError: The turn completed with no agent message.
        """
        notifications: list[dict[str, Any]] = []
        final_messages: list[str] = []
        fallback_messages: list[str] = []

        while True:
            message = self._next_message(self.request_timeout)

            if "method" in message and "id" in message:
                self._handle_server_request(message)
                continue

            if "method" not in message:
                # A response with no outstanding request (e.g. the late reply to a
                # request that already timed out); responses are only awaited
                # inside request(), which is not active here. Drop it. Pushing it
                # back onto _pending made _next_message return it again forever.
                continue

            method = message["method"]
            params = message.get("params", {})
            notifications.append(message)

            if (
                method == "item/completed"
                and isinstance(params, dict)
                and params.get("turnId") == turn_id
            ):
                item = params.get("item", {})
                if isinstance(item, dict) and item.get("type") == "agentMessage":
                    text = str(item.get("text", ""))
                    if item.get("phase") == "final_answer":
                        final_messages.append(text)
                    else:
                        fallback_messages.append(text)
                continue

            if method == "turn/completed" and isinstance(params, dict):
                turn = params.get("turn", {})
                if isinstance(turn, dict) and turn.get("id") == turn_id:
                    status = turn.get("status")
                    if status == "failed":
                        error = turn.get("error", {})
                        message_text = error.get("message") if isinstance(error, dict) else None
                        raise CodexAppServerError(
                            message_text or f"Codex turn {turn_id} failed without an error message."
                        )
                    break

        if final_messages:
            return final_messages[-1], notifications
        if fallback_messages:
            return fallback_messages[-1], notifications
        raise CodexStructuredOutputError("Codex turn completed without an assistant message.")

    def _handle_server_request(self, message: dict[str, Any]) -> None:
        """Reply to a server-initiated request with an empty ``result`` object.

        NOTE(review): the same empty result is sent for every request type;
        confirm that is an acceptable reply for each request the server may send.
        """
        try:
            self._write({"id": message["id"], "result": {}})
        except CodexAppServerError:
            # Best-effort: if the pipe is gone, the next read reports the
            # closed stream with a proper error.
            pass

    def _write(self, payload: dict[str, Any]) -> None:
        """Serialize ``payload`` as one JSON line to the server's stdin.

        Raises:
            CodexAppServerError: No process is running, or the write fails.
        """
        # Read _proc once: close() may set it to None from another thread.
        proc = self._proc
        if proc is None or proc.stdin is None:
            raise CodexAppServerError("Codex app-server is not running.")
        try:
            proc.stdin.write(json.dumps(payload) + "\n")
            proc.stdin.flush()
        except (OSError, ValueError) as exc:
            # ValueError is what a write to an already-closed stdin raises.
            raise CodexAppServerError(
                f"Failed to write to Codex app-server: {exc}. stderr_tail={self._stderr_tail()}"
            ) from exc

    def _next_message(self, timeout: float) -> dict[str, Any]:
        """Return the next parked message, else wait up to ``timeout`` s for stdout.

        Raises:
            CodexAppServerError: On timeout, or when the stdout reader signalled EOF.
        """
        if self._pending:
            return self._pending.popleft()

        try:
            message = self._stdout_queue.get(timeout=timeout)
        except queue.Empty as exc:
            raise CodexAppServerError(
                f"Timed out waiting for Codex app-server after {timeout}s. stderr_tail={self._stderr_tail()}"
            ) from exc

        if message is None:
            raise CodexAppServerError(
                f"Codex app-server closed unexpectedly. stderr_tail={self._stderr_tail()}"
            )
        return message

    def _start_reader_threads(self) -> None:
        """Start daemon threads draining the current process's stdout and stderr."""
        proc = self._proc
        assert proc is not None
        stdout = proc.stdout
        stderr = proc.stderr
        assert stdout is not None
        assert stderr is not None
        # Bind the queue now: close() clears self._proc and start() swaps in a new
        # queue, and a lingering reader must never feed a later process's queue.
        out_queue = self._stdout_queue
        stderr_lines = self._stderr_lines

        def _read_stdout() -> None:
            try:
                for raw in stdout:
                    line = raw.strip()
                    if not line:
                        continue
                    try:
                        payload = json.loads(line)
                    except json.JSONDecodeError:
                        stderr_lines.append(f"invalid_json_stdout={line}")
                        continue
                    # Non-object JSON values are dropped.
                    if isinstance(payload, dict):
                        out_queue.put(payload)
            finally:
                # Always post the EOF sentinel, even if reading raised, so
                # waiters fail fast instead of sitting out the full timeout.
                out_queue.put(None)

        def _read_stderr() -> None:
            for raw in stderr:
                stderr_lines.append(raw.rstrip())

        self._reader_thread = threading.Thread(target=_read_stdout, daemon=True)
        self._stderr_thread = threading.Thread(target=_read_stderr, daemon=True)
        self._reader_thread.start()
        self._stderr_thread.start()

    def _stderr_tail(self) -> str:
        """Return the last 40 captured stderr lines joined by newlines."""
        return "\n".join(list(self._stderr_lines)[-40:])

    def _restore_deferred(self, deferred: list[dict[str, Any]]) -> None:
        """Put ``deferred`` back at the front of ``_pending``, preserving its order."""
        self._pending.extendleft(reversed(deferred))
|