TradingAgents/tradingagents/llm_clients/codex_app_server.py

338 lines
12 KiB
Python

from __future__ import annotations
import json
import queue
import subprocess
import threading
import uuid
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .codex_binary import codex_binary_error_message, resolve_codex_binary
class CodexAppServerError(RuntimeError):
"""Raised when the Codex app-server request cycle fails."""
class CodexAppServerAuthError(CodexAppServerError):
"""Raised when Codex login is missing or unusable."""
class CodexAppServerBinaryError(CodexAppServerError):
"""Raised when the Codex binary cannot be started."""
class CodexStructuredOutputError(CodexAppServerError):
"""Raised when Codex does not honor the requested structured output."""
@dataclass(slots=True)
class CodexInvocationResult:
final_text: str
notifications: list[dict[str, Any]]
class CodexAppServerSession:
"""Minimal JSON-RPC client for `codex app-server` over stdio JSONL."""
def __init__(
self,
*,
codex_binary: str | None,
request_timeout: float,
workspace_dir: str,
cleanup_threads: bool,
client_name: str = "tradingagents_codex",
client_title: str = "TradingAgents Codex Provider",
client_version: str = "0.2.3",
) -> None:
self.codex_binary = codex_binary
self.request_timeout = request_timeout
self.workspace_dir = str(Path(workspace_dir).expanduser())
self.cleanup_threads = cleanup_threads
self.client_name = client_name
self.client_title = client_title
self.client_version = client_version
self._proc: subprocess.Popen[str] | None = None
self._stdout_queue: queue.Queue[dict[str, Any] | None] = queue.Queue()
self._pending: deque[dict[str, Any]] = deque()
self._stderr_lines: deque[str] = deque(maxlen=200)
self._lock = threading.RLock()
self._request_lock = threading.RLock()
self._reader_thread: threading.Thread | None = None
self._stderr_thread: threading.Thread | None = None
def start(self) -> None:
with self._lock:
if self._proc is not None:
return
Path(self.workspace_dir).mkdir(parents=True, exist_ok=True)
binary = resolve_codex_binary(self.codex_binary)
if not binary:
raise CodexAppServerBinaryError(codex_binary_error_message(self.codex_binary))
self.codex_binary = binary
try:
self._proc = subprocess.Popen(
[binary, "app-server", "--listen", "stdio://"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding="utf-8",
cwd=self.workspace_dir,
bufsize=1,
)
except OSError as exc:
raise CodexAppServerBinaryError(
f"Failed to start Codex app-server with binary '{binary}': {exc}"
) from exc
self._start_reader_threads()
self._initialize()
def close(self) -> None:
with self._lock:
proc = self._proc
self._proc = None
if proc is None:
return
try:
if proc.stdin:
proc.stdin.close()
except OSError:
pass
try:
proc.terminate()
proc.wait(timeout=2)
except Exception:
proc.kill()
def account_read(self) -> dict[str, Any]:
return self.request("account/read", {"refreshToken": False})
def model_list(self, *, include_hidden: bool = True) -> dict[str, Any]:
return self.request("model/list", {"includeHidden": include_hidden})
def invoke(
self,
*,
prompt: str,
model: str,
output_schema: dict[str, Any],
reasoning_effort: str | None,
summary: str | None,
personality: str | None,
) -> CodexInvocationResult:
with self._request_lock:
self.start()
thread_id = None
try:
thread = self.request(
"thread/start",
{
"approvalPolicy": "never",
"cwd": self.workspace_dir,
"ephemeral": True,
"model": model,
"personality": personality,
"sandbox": "read-only",
"serviceName": "tradingagents_codex",
},
)
thread_id = thread["thread"]["id"]
started = self.request(
"turn/start",
{
"threadId": thread_id,
"input": [{"type": "text", "text": prompt}],
"model": model,
"effort": reasoning_effort,
"summary": summary,
"outputSchema": output_schema,
},
)
turn_id = started["turn"]["id"]
final_text, notifications = self._collect_turn(turn_id)
return CodexInvocationResult(final_text=final_text, notifications=notifications)
finally:
if thread_id and self.cleanup_threads:
try:
self.request("thread/unsubscribe", {"threadId": thread_id})
except CodexAppServerError:
pass
def request(self, method: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
request_id = str(uuid.uuid4())
self._write({"id": request_id, "method": method, "params": params or {}})
deferred: list[dict[str, Any]] = []
while True:
message = self._next_message(self.request_timeout)
if message.get("id") == request_id:
self._restore_deferred(deferred)
if "error" in message:
error = message["error"] or {}
code = error.get("code")
text = error.get("message", "unknown Codex app-server error")
raise CodexAppServerError(
f"{method} failed ({code}): {text}. stderr_tail={self._stderr_tail()}"
)
result = message.get("result")
if not isinstance(result, dict):
raise CodexAppServerError(f"{method} returned a non-object result: {result!r}")
return result
if "method" in message and "id" in message:
self._handle_server_request(message)
continue
deferred.append(message)
def _initialize(self) -> None:
response = self.request(
"initialize",
{
"clientInfo": {
"name": self.client_name,
"title": self.client_title,
"version": self.client_version,
}
},
)
if not response.get("userAgent"):
raise CodexAppServerError("Codex initialize response did not include userAgent.")
self._write({"method": "initialized", "params": {}})
def _collect_turn(self, turn_id: str) -> tuple[str, list[dict[str, Any]]]:
notifications: list[dict[str, Any]] = []
final_messages: list[str] = []
fallback_messages: list[str] = []
while True:
message = self._next_message(self.request_timeout)
if "method" in message and "id" in message:
self._handle_server_request(message)
continue
if "method" not in message:
self._pending.append(message)
continue
method = message["method"]
params = message.get("params", {})
notifications.append(message)
if (
method == "item/completed"
and isinstance(params, dict)
and params.get("turnId") == turn_id
):
item = params.get("item", {})
if isinstance(item, dict) and item.get("type") == "agentMessage":
text = str(item.get("text", ""))
if item.get("phase") == "final_answer":
final_messages.append(text)
else:
fallback_messages.append(text)
continue
if method == "turn/completed" and isinstance(params, dict):
turn = params.get("turn", {})
if isinstance(turn, dict) and turn.get("id") == turn_id:
status = turn.get("status")
if status == "failed":
error = turn.get("error", {})
message_text = error.get("message") if isinstance(error, dict) else None
raise CodexAppServerError(
message_text or f"Codex turn {turn_id} failed without an error message."
)
break
if final_messages:
return final_messages[-1], notifications
if fallback_messages:
return fallback_messages[-1], notifications
raise CodexStructuredOutputError("Codex turn completed without an assistant message.")
def _handle_server_request(self, message: dict[str, Any]) -> None:
try:
self._write({"id": message["id"], "result": {}})
except Exception:
pass
def _write(self, payload: dict[str, Any]) -> None:
if self._proc is None or self._proc.stdin is None:
raise CodexAppServerError("Codex app-server is not running.")
try:
self._proc.stdin.write(json.dumps(payload) + "\n")
self._proc.stdin.flush()
except OSError as exc:
raise CodexAppServerError(
f"Failed to write to Codex app-server: {exc}. stderr_tail={self._stderr_tail()}"
) from exc
def _next_message(self, timeout: float) -> dict[str, Any]:
if self._pending:
return self._pending.popleft()
try:
message = self._stdout_queue.get(timeout=timeout)
except queue.Empty as exc:
raise CodexAppServerError(
f"Timed out waiting for Codex app-server after {timeout}s. stderr_tail={self._stderr_tail()}"
) from exc
if message is None:
raise CodexAppServerError(
f"Codex app-server closed unexpectedly. stderr_tail={self._stderr_tail()}"
)
return message
def _start_reader_threads(self) -> None:
assert self._proc is not None
assert self._proc.stdout is not None
assert self._proc.stderr is not None
def _read_stdout() -> None:
stdout = self._proc.stdout
assert stdout is not None
for line in stdout:
line = line.strip()
if not line:
continue
try:
payload = json.loads(line)
except json.JSONDecodeError:
self._stderr_lines.append(f"invalid_json_stdout={line}")
continue
if isinstance(payload, dict):
self._stdout_queue.put(payload)
self._stdout_queue.put(None)
def _read_stderr() -> None:
stderr = self._proc.stderr
assert stderr is not None
for line in stderr:
self._stderr_lines.append(line.rstrip())
self._reader_thread = threading.Thread(target=_read_stdout, daemon=True)
self._stderr_thread = threading.Thread(target=_read_stderr, daemon=True)
self._reader_thread.start()
self._stderr_thread.start()
def _stderr_tail(self) -> str:
return "\n".join(list(self._stderr_lines)[-40:])
def _restore_deferred(self, deferred: list[dict[str, Any]]) -> None:
for message in reversed(deferred):
self._pending.appendleft(message)