# TradingAgents/tradingagents/llm_clients/codex_chat_model.py
#
# File-viewer metadata captured with this copy: 408 lines, 16 KiB, Python.

from __future__ import annotations
import json
import threading
import uuid
from typing import Any, Callable, Sequence
from pydantic import ConfigDict, Field, PrivateAttr
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from .codex_app_server import CodexAppServerSession, CodexStructuredOutputError
from .codex_message_codec import (
format_messages_for_codex,
normalize_input_messages,
strip_json_fence,
)
from .codex_preflight import run_codex_preflight
from .codex_schema import (
build_plain_response_schema,
build_tool_response_schema,
normalize_tools_for_codex,
)
class CodexChatModel(BaseChatModel):
    """LangChain chat model that talks to `codex app-server` over stdio.

    Every request asks Codex for JSON that matches an output schema. Without
    bound tools the parser expects an object with a string ``answer``; with
    bound tools it expects an envelope carrying ``mode`` (``"final"`` or
    ``"tool_calls"``), ``content`` and ``tool_calls``. The JSON is validated
    and converted into an ``AIMessage``. A response that fails validation is
    retried up to ``codex_max_retries`` more times, with the validation error
    appended to the prompt.
    """

    # Model identifier forwarded to preflight and to every session invoke.
    model: str
    # Forwarded to the session factory and preflight as ``codex_binary``.
    codex_binary: str | None = None
    # Forwarded verbatim to ``session.invoke`` as ``reasoning_effort``,
    # ``summary`` and ``personality`` respectively.
    codex_reasoning_effort: str | None = None
    codex_summary: str | None = None
    codex_personality: str | None = None
    # Forwarded to the session factory and preflight as ``workspace_dir``.
    codex_workspace_dir: str
    # Forwarded as ``request_timeout`` (presumably seconds -- TODO confirm
    # against CodexAppServerSession).
    codex_request_timeout: float = 120.0
    # Number of *additional* attempts after the first one in ``_generate``.
    codex_max_retries: int = 2
    # Forwarded to the session factory and preflight as ``cleanup_threads``.
    codex_cleanup_threads: bool = True
    # Injection hooks (e.g. for tests); excluded from serialization and repr.
    session_factory: Callable[..., CodexAppServerSession] | None = Field(
        default=None, exclude=True, repr=False
    )
    preflight_runner: Callable[..., Any] | None = Field(
        default=None, exclude=True, repr=False
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Lazily created session, guarded by ``_session_lock``.
    _session: CodexAppServerSession | None = PrivateAttr(default=None)
    _session_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
    # ``_preflight_done`` is only read/written while holding ``_preflight_lock``.
    _preflight_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
    _preflight_done: bool = PrivateAttr(default=False)

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this model type."""
        return "codex"

    @property
    def _identifying_params(self) -> dict[str, Any]:
        """Return the parameters that identify this model configuration."""
        return {
            "model": self.model,
            "codex_binary": self.codex_binary,
            "codex_reasoning_effort": self.codex_reasoning_effort,
            "codex_summary": self.codex_summary,
            "codex_personality": self.codex_personality,
        }

    def preflight(self) -> None:
        """Run the preflight check once per instance.

        Uses ``preflight_runner`` when set, otherwise ``run_codex_preflight``.
        The done-flag is only set after the runner returns, so a runner that
        raises will be invoked again on the next call.
        """
        with self._preflight_lock:
            if self._preflight_done:
                return
            runner = self.preflight_runner or run_codex_preflight
            runner(
                codex_binary=self.codex_binary,
                model=self.model,
                request_timeout=self.codex_request_timeout,
                workspace_dir=self.codex_workspace_dir,
                cleanup_threads=self.codex_cleanup_threads,
                session_factory=self.session_factory or CodexAppServerSession,
            )
            self._preflight_done = True

    def bind_tools(
        self,
        tools: Sequence[dict[str, Any] | type | Callable | Any],
        *,
        tool_choice: str | bool | dict[str, Any] | None = None,
        **kwargs: Any,
    ):
        """Return a runnable with ``tools`` (normalized for Codex) bound.

        The normalized tools and ``tool_choice`` are bound as keyword
        arguments and read back from ``kwargs`` in ``_generate``.
        """
        normalized_tools = normalize_tools_for_codex(tools)
        return self.bind(tools=normalized_tools, tool_choice=tool_choice, **kwargs)

    def close(self) -> None:
        """Close the cached session, if any, and forget it.

        The cached reference is cleared *before* ``close()`` is called so a
        failing ``close()`` cannot leave a dead session cached for reuse.
        """
        with self._session_lock:
            session, self._session = self._session, None
            if session is not None:
                session.close()

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager=None,
        **kwargs: Any,
    ) -> ChatResult:
        """Invoke Codex and return a validated chat result.

        Args:
            messages: Conversation to send; normalized before formatting.
            stop: Accepted for interface compatibility; not used here.
            run_manager: Optional callback manager. When given, every
                ``item/agentMessage/delta`` notification with a non-empty
                string ``delta`` is forwarded via ``on_llm_new_token``.
            **kwargs: May carry ``tools`` and ``tool_choice`` (as bound by
                ``bind_tools``).

        Returns:
            A ``ChatResult`` with a single generation.

        Raises:
            CodexStructuredOutputError: If ``tool_choice`` is unsupported, or
                if no attempt produced output that passes validation.
        """
        self.preflight()
        normalized_messages = normalize_input_messages(messages)

        tool_binding = self._resolve_tool_binding(
            kwargs.get("tools") or [],
            kwargs.get("tool_choice"),
        )
        tools = tool_binding["tools"]
        effective_tool_choice = tool_binding["tool_choice"]
        output_schema = tool_binding["output_schema"]
        tool_arguments_as_json_string = tool_binding["tool_arguments_as_json_string"]

        # Always make at least one attempt, even if codex_max_retries was
        # configured with a negative value.
        total_attempts = max(self.codex_max_retries, 0) + 1

        raw_response: str | None = None
        last_error: Exception | None = None
        for attempt in range(total_attempts):
            retry_message = None
            if attempt:
                # Feed the previous validation failure back to the model.
                previous_error = str(last_error) if last_error is not None else "unknown schema mismatch"
                retry_message = (
                    "The previous response did not satisfy TradingAgents validation: "
                    f"{previous_error}. Return only valid JSON that exactly matches the requested "
                    "schema and tool argument requirements."
                )

            prompt = format_messages_for_codex(
                normalized_messages,
                tool_names=[tool["function"]["name"] for tool in tools],
                tool_schemas=tools,
                tool_choice=effective_tool_choice,
                tool_arguments_as_json_string=tool_arguments_as_json_string,
                retry_message=retry_message,
            )
            result = self._session_or_create().invoke(
                prompt=prompt,
                model=self.model,
                output_schema=output_schema,
                reasoning_effort=self.codex_reasoning_effort,
                summary=self.codex_summary,
                personality=self.codex_personality,
            )
            raw_response = result.final_text

            if run_manager is not None:
                for notification in result.notifications:
                    if notification.get("method") != "item/agentMessage/delta":
                        continue
                    params = notification.get("params", {})
                    if isinstance(params, dict):
                        delta = params.get("delta")
                        if isinstance(delta, str) and delta:
                            run_manager.on_llm_new_token(delta)

            try:
                ai_message = (
                    self._parse_tool_response(
                        raw_response,
                        tools,
                        tool_arguments_as_json_string=tool_arguments_as_json_string,
                    )
                    if tools
                    else self._parse_plain_response(raw_response)
                )
                return ChatResult(generations=[ChatGeneration(message=ai_message)])
            except (json.JSONDecodeError, CodexStructuredOutputError, ValueError) as exc:
                last_error = exc
                continue

        raise CodexStructuredOutputError(
            "Codex returned malformed structured output after "
            f"{total_attempts} attempt(s): {last_error}. "
            f"Last response: {raw_response!r}"
        ) from last_error

    def _parse_plain_response(self, raw_response: str) -> AIMessage:
        """Parse a tool-less response: a JSON object with a string ``answer``.

        Raises:
            json.JSONDecodeError: If the (fence-stripped) text is not JSON.
            CodexStructuredOutputError: If the payload has the wrong shape.
        """
        payload = json.loads(strip_json_fence(raw_response))
        if not isinstance(payload, dict) or not isinstance(payload.get("answer"), str):
            raise CodexStructuredOutputError(
                f"Expected plain response JSON with string `answer`, got: {payload!r}"
            )
        return AIMessage(content=payload["answer"])

    def _parse_tool_response(
        self,
        raw_response: str,
        tools: Sequence[dict[str, Any]],
        *,
        tool_arguments_as_json_string: bool,
    ) -> AIMessage:
        """Parse a tool-mode response envelope into an ``AIMessage``.

        ``mode == "final"`` yields a plain message and must not carry tool
        calls; ``mode == "tool_calls"`` requires a non-empty ``tool_calls``
        list whose entries name a bound tool and pass argument validation.
        Each accepted call gets a freshly generated ``call_<hex>`` id.

        Raises:
            json.JSONDecodeError: If the (fence-stripped) text is not JSON.
            CodexStructuredOutputError: On any structural or validation error.
        """
        payload = json.loads(strip_json_fence(raw_response))
        if not isinstance(payload, dict):
            raise CodexStructuredOutputError(f"Expected JSON object, got: {payload!r}")

        mode = payload.get("mode")
        content = payload.get("content", "")
        if not isinstance(content, str):
            raise CodexStructuredOutputError("Structured response `content` must be a string.")

        if mode == "final":
            tool_calls = payload.get("tool_calls", [])
            if tool_calls not in ([], None):
                raise CodexStructuredOutputError(
                    f"`mode=final` must not include tool calls, got: {tool_calls!r}"
                )
            return AIMessage(content=content)

        if mode != "tool_calls":
            raise CodexStructuredOutputError(f"Unknown structured response mode: {mode!r}")

        raw_tool_calls = payload.get("tool_calls")
        if not isinstance(raw_tool_calls, list) or not raw_tool_calls:
            raise CodexStructuredOutputError("`mode=tool_calls` requires a non-empty tool_calls array.")

        # Map each bound tool name to its parameter schema for validation.
        tool_parameters = {
            tool.get("function", {}).get("name"): tool.get("function", {}).get("parameters", {})
            for tool in tools
        }

        tool_calls: list[dict[str, Any]] = []
        for item in raw_tool_calls:
            if not isinstance(item, dict):
                raise CodexStructuredOutputError(f"Tool call entries must be objects, got: {item!r}")
            name = item.get("name")
            arguments = self._extract_tool_arguments(
                item,
                tool_arguments_as_json_string=tool_arguments_as_json_string,
            )
            if not isinstance(name, str) or not isinstance(arguments, dict):
                raise CodexStructuredOutputError(
                    f"Tool call entries must include string name and object arguments, got: {item!r}"
                )
            if name not in tool_parameters:
                raise CodexStructuredOutputError(
                    f"Tool call name '{name}' is not in the bound tool set."
                )
            self._validate_tool_arguments(name, arguments, tool_parameters[name])
            tool_calls.append(
                {
                    "name": name,
                    "args": arguments,
                    "id": f"call_{uuid.uuid4().hex}",
                }
            )
        return AIMessage(content=content, tool_calls=tool_calls)

    def _extract_tool_arguments(
        self,
        item: dict[str, Any],
        *,
        tool_arguments_as_json_string: bool,
    ) -> dict[str, Any]:
        """Return the argument object of one raw tool-call entry.

        When ``tool_arguments_as_json_string`` is true the arguments are read
        from the string field ``arguments_json`` and JSON-decoded; otherwise
        they are read directly from the object field ``arguments``.

        Raises:
            CodexStructuredOutputError: If the field is missing, has the wrong
                type, is not valid JSON, or does not decode to an object.
        """
        if tool_arguments_as_json_string:
            raw_arguments = item.get("arguments_json")
            if not isinstance(raw_arguments, str):
                raise CodexStructuredOutputError(
                    f"Tool call entries must include string arguments_json, got: {item!r}"
                )
            try:
                parsed = json.loads(raw_arguments)
            except json.JSONDecodeError as exc:
                raise CodexStructuredOutputError(
                    f"Tool call arguments_json must contain valid JSON, got: {raw_arguments!r}"
                ) from exc
            if not isinstance(parsed, dict):
                raise CodexStructuredOutputError(
                    f"Tool call arguments_json must decode to an object, got: {parsed!r}"
                )
            return parsed

        arguments = item.get("arguments")
        if not isinstance(arguments, dict):
            raise CodexStructuredOutputError(
                f"Tool call entries must include object arguments, got: {item!r}"
            )
        return arguments

    def _validate_tool_arguments(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        schema: dict[str, Any] | None,
    ) -> None:
        """Check ``arguments`` against the tool's parameter ``schema``.

        Only two things are checked: every name listed in ``required`` is
        present, and -- when ``additionalProperties`` is exactly ``False`` --
        no argument falls outside ``properties``. A non-dict schema disables
        validation.

        Raises:
            CodexStructuredOutputError: If the schema's ``properties`` is not
                a dict, a required argument is missing, or an unexpected
                argument is present.
        """
        if not isinstance(schema, dict):
            return

        properties = schema.get("properties")
        if properties is not None and not isinstance(properties, dict):
            raise CodexStructuredOutputError(
                f"Tool schema for '{tool_name}' has invalid properties metadata."
            )

        required = schema.get("required") or []
        if isinstance(required, list):
            # str() keeps the join safe even if the schema lists a non-string.
            missing = [str(name) for name in required if name not in arguments]
            if missing:
                raise CodexStructuredOutputError(
                    f"Tool call '{tool_name}' is missing required arguments: {', '.join(missing)}"
                )

        # A closed schema with empty or absent `properties` (a no-argument
        # tool) must still reject extra arguments, so do not gate this check
        # on `properties` being non-empty.
        if schema.get("additionalProperties") is False:
            allowed = properties or {}
            unexpected = [name for name in arguments if name not in allowed]
            if unexpected:
                raise CodexStructuredOutputError(
                    f"Tool call '{tool_name}' included unexpected arguments: {', '.join(unexpected)}"
                )

    def _session_or_create(self) -> CodexAppServerSession:
        """Return the cached session, creating and starting it on first use.

        The session is cached only after ``start()`` succeeds. If ``start()``
        raises, the new session is closed on a best-effort basis and the
        error propagates, so the next call builds a fresh session instead of
        reusing one that never started.
        """
        with self._session_lock:
            if self._session is None:
                factory = self.session_factory or CodexAppServerSession
                session = factory(
                    codex_binary=self.codex_binary,
                    request_timeout=self.codex_request_timeout,
                    workspace_dir=self.codex_workspace_dir,
                    cleanup_threads=self.codex_cleanup_threads,
                )
                try:
                    session.start()
                except BaseException:
                    # Deliberate best-effort cleanup: the start failure is the
                    # error worth reporting, not a secondary close failure.
                    try:
                        session.close()
                    except Exception:
                        pass
                    raise
                self._session = session
            return self._session

    def _resolve_tool_binding(
        self,
        tools: Sequence[dict[str, Any]],
        tool_choice: Any,
    ) -> dict[str, Any]:
        """Translate ``tools`` + ``tool_choice`` into request settings.

        Returns a dict with keys ``tools`` (the tools to expose),
        ``tool_choice`` (normalized choice passed to the prompt formatter),
        ``output_schema`` and ``tool_arguments_as_json_string``. The last is
        true only when more than one tool is exposed; a single named tool
        always uses object arguments.

        Raises:
            CodexStructuredOutputError: If ``tool_choice`` has an unsupported
                form or names a tool that is not bound.
        """
        tool_list = list(tools)
        if not tool_list:
            return {
                "tools": [],
                "tool_choice": None,
                "output_schema": build_plain_response_schema(),
                "tool_arguments_as_json_string": False,
            }

        if tool_choice in (None, "auto"):
            # The model may answer directly or call tools.
            return {
                "tools": tool_list,
                "tool_choice": None if tool_choice is None else "auto",
                "output_schema": build_tool_response_schema(tool_list, allow_final=True),
                "tool_arguments_as_json_string": len(tool_list) > 1,
            }

        if tool_choice in (False, "none"):
            # Tools are bound but disabled for this call: plain answer only.
            return {
                "tools": [],
                "tool_choice": "none",
                "output_schema": build_plain_response_schema(),
                "tool_arguments_as_json_string": False,
            }

        if tool_choice in (True, "any", "required"):
            # The model must call at least one tool.
            normalized_choice = "required" if tool_choice in (True, "required") else "any"
            return {
                "tools": tool_list,
                "tool_choice": normalized_choice,
                "output_schema": build_tool_response_schema(tool_list, allow_final=False),
                "tool_arguments_as_json_string": len(tool_list) > 1,
            }

        # Anything else must name one specific tool.
        selected_tool_name = self._extract_named_tool_choice(tool_choice)
        if selected_tool_name is None:
            raise CodexStructuredOutputError(
                f"Unsupported Codex tool_choice value: {tool_choice!r}"
            )

        selected_tools = [
            tool
            for tool in tool_list
            if tool.get("function", {}).get("name") == selected_tool_name
        ]
        if not selected_tools:
            available = ", ".join(
                tool.get("function", {}).get("name", "<unknown>")
                for tool in tool_list
            )
            raise CodexStructuredOutputError(
                f"Requested tool_choice '{selected_tool_name}' is not in the bound tool set. "
                f"Available tools: {available}"
            )

        return {
            "tools": selected_tools,
            "tool_choice": selected_tool_name,
            "output_schema": build_tool_response_schema(selected_tools, allow_final=False),
            "tool_arguments_as_json_string": False,
        }

    def _extract_named_tool_choice(self, tool_choice: Any) -> str | None:
        """Return the tool name encoded in ``tool_choice``, or ``None``.

        Accepts a bare string (returned as-is), a dict with a nested
        ``{"function": {"name": ...}}``, or a dict with a top-level ``name``.
        For dicts, only non-empty string names are accepted.
        """
        if isinstance(tool_choice, str):
            return tool_choice
        if not isinstance(tool_choice, dict):
            return None

        function = tool_choice.get("function")
        if isinstance(function, dict):
            name = function.get("name")
            if isinstance(name, str) and name:
                return name

        name = tool_choice.get("name")
        if isinstance(name, str) and name:
            return name
        return None