fix: report LLM calls, tool calls, and token usage for claude_agent
ChatClaudeAgent is a plain Runnable rather than a BaseChatModel, so LangChain's callback system never fired on_chat_model_start / on_llm_end for it — leaving the CLI TUI stuck on "LLM: 0" and "Tokens: --" during runs. Pop callbacks out of the LLM kwargs, invoke them manually around each SDK call, and attach usage_metadata extracted from the SDK's ResultMessage (input, output, total — including cached input) to the returned AIMessage so downstream handlers pick it up. Tool callbacks now also fire through the MCP wrapper: forward the callback list into each wrapped LangChain tool's invocation config so StatsCallbackHandler sees on_tool_start/on_tool_end when the SDK loop calls a tool. Verified via direct StatsCallbackHandler round-trip on both Shape A (ChatClaudeAgent.invoke) and Shape B (run_sdk_analyst): llm_calls, tool_calls, tokens_in, and tokens_out all increment as expected. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
39182785ce
commit
3d8341c104
|
|
@ -23,7 +23,11 @@ from typing import Any, Dict, List
|
|||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
|
||||
from tradingagents.llm_clients.claude_agent_client import (
|
||||
ChatClaudeAgent,
|
||||
extract_usage,
|
||||
fire_llm_callbacks,
|
||||
)
|
||||
from tradingagents.llm_clients.mcp_tool_adapter import build_mcp_server
|
||||
|
||||
|
||||
|
|
@ -109,7 +113,8 @@ async def _run(
|
|||
lc_tools: List[Any],
|
||||
server_name: str,
|
||||
model: str,
|
||||
) -> str:
|
||||
callbacks: List[Any],
|
||||
) -> tuple[str, Dict[str, int]]:
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeAgentOptions,
|
||||
|
|
@ -119,7 +124,7 @@ async def _run(
|
|||
|
||||
_log(f"[{server_name}] building MCP server with {len(lc_tools)} tools: "
|
||||
f"{[t.name for t in lc_tools]}")
|
||||
server, allowed = build_mcp_server(server_name, lc_tools)
|
||||
server, allowed = build_mcp_server(server_name, lc_tools, callbacks=callbacks)
|
||||
_log(f"[{server_name}] allowed_tools={allowed}")
|
||||
|
||||
options = ClaudeAgentOptions(
|
||||
|
|
@ -140,6 +145,7 @@ async def _run(
|
|||
start = time.monotonic()
|
||||
|
||||
text_parts: List[str] = []
|
||||
final_usage: Dict[str, int] = {}
|
||||
msg_count = 0
|
||||
async for msg in query(prompt=user_prompt, options=options):
|
||||
msg_count += 1
|
||||
|
|
@ -149,11 +155,15 @@ async def _run(
|
|||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
text_parts.append(block.text)
|
||||
sdk_usage = getattr(msg, "usage", None)
|
||||
if isinstance(sdk_usage, dict) and sdk_usage:
|
||||
final_usage = extract_usage(sdk_usage)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
_log(f"[{server_name}] query complete after {elapsed:.1f}s, "
|
||||
f"{msg_count} messages, {sum(len(t) for t in text_parts)} chars")
|
||||
return "\n".join(text_parts).strip()
|
||||
f"{msg_count} messages, {sum(len(t) for t in text_parts)} chars, "
|
||||
f"usage={final_usage}")
|
||||
return "\n".join(text_parts).strip(), final_usage
|
||||
|
||||
|
||||
def run_sdk_analyst(
|
||||
|
|
@ -169,20 +179,23 @@ def run_sdk_analyst(
|
|||
_log(f"=== run_sdk_analyst start: server={server_name} report_field={report_field} "
|
||||
f"ticker={state.get('company_of_interest')!r} date={state.get('trade_date')!r} ===")
|
||||
try:
|
||||
report = asyncio.run(
|
||||
report, usage = asyncio.run(
|
||||
_run(
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
lc_tools=lc_tools,
|
||||
server_name=server_name,
|
||||
model=llm.model,
|
||||
callbacks=llm.callbacks,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_log(f"[{server_name}] EXCEPTION: {type(e).__name__}: {e}")
|
||||
raise
|
||||
_log(f"=== run_sdk_analyst done: {report_field}={len(report)} chars ===")
|
||||
_log(f"=== run_sdk_analyst done: {report_field}={len(report)} chars usage={usage} ===")
|
||||
message = AIMessage(content=report, usage_metadata=usage or None)
|
||||
fire_llm_callbacks(llm.callbacks, message, user_prompt)
|
||||
return {
|
||||
"messages": [AIMessage(content=report)],
|
||||
"messages": [message],
|
||||
report_field: report,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,9 +10,11 @@ Shape B.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
|
|
@ -69,6 +71,79 @@ def _coerce_input(input: Any) -> Tuple[Optional[str], str]:
|
|||
return system_prompt, user_prompt
|
||||
|
||||
|
||||
def extract_usage(sdk_usage: Any) -> Dict[str, int]:
    """Normalize an SDK ``usage`` payload into LangChain's usage_metadata shape.

    Accepts the plain dict carried by the SDK's ResultMessage (or anything
    else, including ``None``). Always returns a dict with ``input_tokens``,
    ``output_tokens``, and ``total_tokens`` keys.
    """
    if not isinstance(sdk_usage, dict):
        return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

    def first_truthy(*keys: str) -> int:
        # The SDK mirrors Anthropic's usage shape, but key names have varied
        # across versions — take the first truthy value among the candidates.
        for key in keys:
            value = sdk_usage.get(key)
            if value:
                return value
        return 0

    prompt_side = first_truthy("input_tokens", "prompt_tokens")
    completion_side = first_truthy("output_tokens", "completion_tokens")
    # Cached input still consumes budget; fold it into the input count so the
    # TUI reflects it.
    prompt_side += sdk_usage.get("cache_read_input_tokens", 0) or 0
    prompt_side += sdk_usage.get("cache_creation_input_tokens", 0) or 0
    return {
        "input_tokens": int(prompt_side),
        "output_tokens": int(completion_side),
        "total_tokens": int(prompt_side + completion_side),
    }
|
||||
|
||||
|
||||
def fire_llm_callbacks(
    callbacks: List[Any],
    message: AIMessage,
    prompt_preview: str,
) -> None:
    """Manually drive on_chat_model_start / on_llm_end on each handler.

    ChatClaudeAgent is a plain Runnable, so LangChain never fires chat-model
    callbacks for it on its own. Invoking the handlers here lets stats
    collectors (StatsCallbackHandler in the CLI, etc.) count LLM calls and
    token usage. Handler errors are swallowed deliberately: callbacks are
    best-effort and must never break the actual run.
    """
    if not callbacks:
        return

    run_id = uuid4()

    def dispatch(method_name: str, *args: Any) -> None:
        # Best-effort: retry without run_id for handlers whose signature
        # rejects the kwarg, and never let a handler exception propagate.
        for handler in callbacks:
            method = getattr(handler, method_name, None)
            if method is None:
                continue
            try:
                method(*args, run_id=run_id)
            except TypeError:
                try:
                    method(*args)
                except Exception:
                    pass
            except Exception:
                pass

    dispatch(
        "on_chat_model_start",
        {"name": "ChatClaudeAgent"},
        [[{"role": "user", "content": prompt_preview}]],
    )
    dispatch(
        "on_llm_end",
        LLMResult(generations=[[ChatGeneration(message=message)]]),
    )
|
||||
|
||||
|
||||
class ChatClaudeAgent(Runnable):
|
||||
"""LangChain-compatible Runnable that routes inference through claude-agent-sdk.
|
||||
|
||||
|
|
@ -78,11 +153,16 @@ class ChatClaudeAgent(Runnable):
|
|||
|
||||
def __init__(self, model: str, **kwargs: Any) -> None:
    """Store the model id and extract callbacks for manual firing.

    Runnable (unlike BaseChatModel) never triggers chat-model callbacks,
    so ``callbacks`` is popped out here and fired by hand around each
    invoke. All remaining kwargs are kept verbatim for the SDK options.
    """
    self.model = model
    raw_callbacks = kwargs.pop("callbacks", None)
    self.callbacks: List[Any] = list(raw_callbacks) if raw_callbacks else []
    self.kwargs = kwargs
|
||||
|
||||
def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> AIMessage:
    """Run one blocking completion and fire LLM callbacks around it.

    NOTE(review): asyncio.run raises if called from inside a running event
    loop — presumably only synchronous call sites use this path; confirm.
    """
    sys_prompt, user_prompt = _coerce_input(input)
    reply = asyncio.run(self._ainvoke(user_prompt, sys_prompt))
    # Fire on_chat_model_start/on_llm_end by hand so stats handlers see
    # the call and its usage_metadata.
    fire_llm_callbacks(self.callbacks, reply, user_prompt)
    return reply
|
||||
|
||||
async def _ainvoke(self, prompt: str, system_prompt: Optional[str]) -> AIMessage:
|
||||
from claude_agent_sdk import (
|
||||
|
|
@ -104,13 +184,22 @@ class ChatClaudeAgent(Runnable):
|
|||
options = ClaudeAgentOptions(**options_kwargs)
|
||||
|
||||
text_parts: List[str] = []
|
||||
final_usage: Dict[str, int] = {}
|
||||
async for msg in query(prompt=prompt, options=options):
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
text_parts.append(block.text)
|
||||
# The ResultMessage at the end carries cumulative usage; prefer it.
|
||||
# Fall back to AssistantMessage.usage if ResultMessage omits it.
|
||||
sdk_usage = getattr(msg, "usage", None)
|
||||
if isinstance(sdk_usage, dict) and sdk_usage:
|
||||
final_usage = extract_usage(sdk_usage)
|
||||
|
||||
return AIMessage(content="\n".join(text_parts))
|
||||
return AIMessage(
|
||||
content="\n".join(text_parts),
|
||||
usage_metadata=final_usage or None,
|
||||
)
|
||||
|
||||
def bind_tools(self, tools: Any, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError(
|
||||
|
|
|
|||
|
|
@ -6,20 +6,23 @@ dict and returns {"content": [{"type": "text", "text": str}]}.
|
|||
|
||||
Used by the SDK-native analyst runner to let Claude Code (authenticated via a
|
||||
Max/Pro subscription) call the same data tools the legacy analyst path uses.
|
||||
Callbacks passed in from the graph are forwarded into each tool invocation so
|
||||
that StatsCallbackHandler (and any other handler) sees on_tool_start/end.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
|
||||
def _wrap_lc_tool(lc_tool: Any):
|
||||
def _wrap_lc_tool(lc_tool: Any, callbacks: Optional[List[Any]]):
|
||||
"""Wrap a single LangChain BaseTool as an SDK @tool-decorated async callable."""
|
||||
schema = (
|
||||
lc_tool.args_schema.model_json_schema()
|
||||
if lc_tool.args_schema is not None
|
||||
else {"type": "object", "properties": {}}
|
||||
)
|
||||
config = {"callbacks": callbacks} if callbacks else None
|
||||
|
||||
@tool(
|
||||
name=lc_tool.name,
|
||||
|
|
@ -27,7 +30,8 @@ def _wrap_lc_tool(lc_tool: Any):
|
|||
input_schema=schema,
|
||||
)
|
||||
async def _wrapped(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = lc_tool.invoke(args)
|
||||
# Pass callbacks via config so BaseTool fires on_tool_start/on_tool_end.
|
||||
result = lc_tool.invoke(args, config=config) if config else lc_tool.invoke(args)
|
||||
return {"content": [{"type": "text", "text": str(result)}]}
|
||||
|
||||
return _wrapped
|
||||
|
|
@ -36,13 +40,18 @@ def _wrap_lc_tool(lc_tool: Any):
|
|||
def build_mcp_server(
    server_name: str,
    lc_tools: List[Any],
    callbacks: Optional[List[Any]] = None,
) -> Tuple[Any, List[str]]:
    """Build an in-process MCP server from LangChain tools.

    Returns the server instance plus the fully-qualified tool names
    (``mcp__<server>__<tool>``) suitable for passing to ``allowed_tools``.

    ``callbacks`` are threaded into every wrapped tool's LangChain config so
    that on_tool_start/on_tool_end fire on the stats handler during
    SDK-driven tool calls.
    """
    sdk_tools = [_wrap_lc_tool(lc_tool, callbacks) for lc_tool in lc_tools]
    qualified = [f"mcp__{server_name}__{lc_tool.name}" for lc_tool in lc_tools]
    server = create_sdk_mcp_server(name=server_name, version="1.0.0", tools=sdk_tools)
    return server, qualified
|
||||
|
|
|
|||
Loading…
Reference in New Issue