237 lines
8.4 KiB
Python
237 lines
8.4 KiB
Python
import json
|
|
from typing import Any, Iterable, Mapping, Sequence
|
|
|
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
|
|
|
|
|
|
class CodexMessageCodecError(ValueError):
|
|
"""Raised when TradingAgents inputs cannot be normalized for Codex."""
|
|
|
|
|
|
def normalize_input_messages(
|
|
value: str | Sequence[BaseMessage | Mapping[str, Any]],
|
|
) -> list[BaseMessage]:
|
|
"""Normalize TradingAgents model inputs into LangChain messages."""
|
|
if isinstance(value, str):
|
|
return [HumanMessage(content=value)]
|
|
|
|
normalized: list[BaseMessage] = []
|
|
for item in value:
|
|
if isinstance(item, BaseMessage):
|
|
normalized.append(item)
|
|
continue
|
|
|
|
if not isinstance(item, Mapping):
|
|
raise CodexMessageCodecError(
|
|
f"Unsupported message input type: {type(item).__name__}"
|
|
)
|
|
|
|
normalized.append(_message_from_dict(item))
|
|
|
|
return normalized
|
|
|
|
|
|
def format_messages_for_codex(
|
|
messages: Sequence[BaseMessage],
|
|
*,
|
|
tool_names: Iterable[str] = (),
|
|
tool_schemas: Sequence[Mapping[str, Any]] = (),
|
|
tool_choice: str | None = None,
|
|
tool_arguments_as_json_string: bool = False,
|
|
retry_message: str | None = None,
|
|
) -> str:
|
|
"""Render a chat transcript into a single text prompt for Codex."""
|
|
tool_list = list(tool_names)
|
|
lines = [
|
|
"You are answering on behalf of TradingAgents.",
|
|
"The conversation transcript is provided below.",
|
|
"Treat tool outputs as authoritative execution results from the host application.",
|
|
]
|
|
if tool_list:
|
|
lines.append(
|
|
"If external data is still needed, respond with tool calls using only these tools: "
|
|
+ ", ".join(tool_list)
|
|
+ "."
|
|
)
|
|
else:
|
|
lines.append("No host tools are available for this turn.")
|
|
if tool_choice == "none":
|
|
lines.append("Do not request tool calls for this turn.")
|
|
elif tool_choice in {"any", "required"}:
|
|
lines.append("You must respond with one or more tool calls for this turn.")
|
|
elif tool_choice and tool_choice != "auto":
|
|
lines.append(f"You must call the tool named `{tool_choice}` for this turn.")
|
|
elif tool_choice == "auto":
|
|
lines.append("Use tool calls only if they are necessary to answer correctly.")
|
|
if tool_arguments_as_json_string:
|
|
lines.append(
|
|
"When returning tool calls, encode each tool argument object as a JSON string in `arguments_json`."
|
|
)
|
|
schema_lines = _format_tool_schema_lines(tool_schemas)
|
|
if schema_lines:
|
|
lines.append("Tool argument requirements:")
|
|
lines.extend(schema_lines)
|
|
lines.append("Respond only with JSON that matches the requested output schema.")
|
|
if retry_message:
|
|
lines.append(retry_message)
|
|
|
|
transcript: list[str] = []
|
|
for message in messages:
|
|
transcript.append(_format_message(message))
|
|
|
|
return "\n\n".join(lines + ["Conversation transcript:", *transcript])
|
|
|
|
|
|
def strip_json_fence(text: str) -> str:
|
|
stripped = text.strip()
|
|
if stripped.startswith("```"):
|
|
parts = stripped.split("```")
|
|
if len(parts) >= 3:
|
|
candidate = parts[1]
|
|
if candidate.lstrip().startswith("json"):
|
|
candidate = candidate.lstrip()[4:]
|
|
return candidate.strip()
|
|
return stripped
|
|
|
|
|
|
def _message_from_dict(message: Mapping[str, Any]) -> BaseMessage:
|
|
role = str(message.get("role", "")).lower()
|
|
content = _content_to_text(message.get("content", ""))
|
|
|
|
if role == "system":
|
|
return SystemMessage(content=content)
|
|
if role == "user":
|
|
return HumanMessage(content=content)
|
|
if role == "tool":
|
|
tool_call_id = str(message.get("tool_call_id") or message.get("toolCallId") or "")
|
|
if not tool_call_id:
|
|
raise CodexMessageCodecError("Tool messages require tool_call_id.")
|
|
return ToolMessage(content=content, tool_call_id=tool_call_id)
|
|
if role == "assistant":
|
|
raw_tool_calls = message.get("tool_calls") or message.get("toolCalls") or []
|
|
tool_calls = _normalize_tool_calls(raw_tool_calls)
|
|
return AIMessage(content=content, tool_calls=tool_calls)
|
|
|
|
raise CodexMessageCodecError(f"Unsupported message role: {role!r}")
|
|
|
|
|
|
def _normalize_tool_calls(raw_tool_calls: Any) -> list[dict[str, Any]]:
|
|
normalized: list[dict[str, Any]] = []
|
|
if not raw_tool_calls:
|
|
return normalized
|
|
|
|
if not isinstance(raw_tool_calls, Sequence):
|
|
raise CodexMessageCodecError("assistant.tool_calls must be a sequence")
|
|
|
|
for item in raw_tool_calls:
|
|
if not isinstance(item, Mapping):
|
|
raise CodexMessageCodecError("assistant.tool_calls items must be objects")
|
|
|
|
if "function" in item:
|
|
function = item.get("function")
|
|
if not isinstance(function, Mapping):
|
|
raise CodexMessageCodecError("assistant.tool_calls.function must be an object")
|
|
raw_args = function.get("arguments", {})
|
|
if isinstance(raw_args, str):
|
|
try:
|
|
args = json.loads(raw_args)
|
|
except json.JSONDecodeError as exc:
|
|
raise CodexMessageCodecError(
|
|
f"assistant tool arguments must be valid JSON: {raw_args!r}"
|
|
) from exc
|
|
else:
|
|
args = raw_args
|
|
if not isinstance(args, Mapping):
|
|
raise CodexMessageCodecError("assistant tool arguments must decode to an object")
|
|
normalized.append(
|
|
{
|
|
"name": str(function.get("name", "")),
|
|
"args": dict(args),
|
|
"id": str(item.get("id") or ""),
|
|
}
|
|
)
|
|
continue
|
|
|
|
args = item.get("args", {})
|
|
if not isinstance(args, Mapping):
|
|
raise CodexMessageCodecError("assistant tool args must be an object")
|
|
normalized.append(
|
|
{
|
|
"name": str(item.get("name", "")),
|
|
"args": dict(args),
|
|
"id": str(item.get("id") or ""),
|
|
}
|
|
)
|
|
|
|
return normalized
|
|
|
|
|
|
def _format_message(message: BaseMessage) -> str:
|
|
role = type(message).__name__.replace("Message", "") or "Message"
|
|
body = _content_to_text(message.content)
|
|
|
|
if isinstance(message, AIMessage) and message.tool_calls:
|
|
tool_call_json = json.dumps(
|
|
[
|
|
{
|
|
"id": tool_call.get("id"),
|
|
"name": tool_call.get("name"),
|
|
"args": tool_call.get("args", {}),
|
|
}
|
|
for tool_call in message.tool_calls
|
|
],
|
|
ensure_ascii=False,
|
|
indent=2,
|
|
sort_keys=True,
|
|
)
|
|
return f"[{role}]\n{body}\nTool calls:\n{tool_call_json}".strip()
|
|
|
|
if isinstance(message, ToolMessage):
|
|
return f"[Tool:{message.tool_call_id}]\n{body}".strip()
|
|
|
|
return f"[{role}]\n{body}".strip()
|
|
|
|
|
|
def _content_to_text(content: Any) -> str:
|
|
if isinstance(content, str):
|
|
return content
|
|
if isinstance(content, list):
|
|
parts: list[str] = []
|
|
for item in content:
|
|
if isinstance(item, str):
|
|
parts.append(item)
|
|
elif isinstance(item, Mapping):
|
|
text = item.get("text")
|
|
if isinstance(text, str):
|
|
parts.append(text)
|
|
else:
|
|
parts.append(json.dumps(dict(item), ensure_ascii=False))
|
|
else:
|
|
parts.append(str(item))
|
|
return "\n".join(part for part in parts if part)
|
|
if content is None:
|
|
return ""
|
|
return str(content)
|
|
|
|
|
|
def _format_tool_schema_lines(tool_schemas: Sequence[Mapping[str, Any]]) -> list[str]:
|
|
lines: list[str] = []
|
|
for tool_schema in tool_schemas:
|
|
function = tool_schema.get("function")
|
|
if not isinstance(function, Mapping):
|
|
continue
|
|
name = function.get("name")
|
|
parameters = function.get("parameters") or {}
|
|
if not isinstance(name, str) or not isinstance(parameters, Mapping):
|
|
continue
|
|
required = parameters.get("required") or []
|
|
properties = parameters.get("properties") or {}
|
|
summary = {
|
|
"required": required if isinstance(required, list) else [],
|
|
"properties": properties if isinstance(properties, Mapping) else {},
|
|
}
|
|
lines.append(
|
|
f"- {name}: {json.dumps(summary, ensure_ascii=False, sort_keys=True)}"
|
|
)
|
|
return lines
|