fix: report LLM calls, tool calls, and token usage for claude_agent

ChatClaudeAgent is a plain Runnable rather than a BaseChatModel, so
LangChain's callback system never fired on_chat_model_start / on_llm_end
for it — leaving the CLI TUI stuck on "LLM: 0" and "Tokens: --" during
runs. Pop callbacks out of the LLM kwargs, invoke them manually around
each SDK call, and attach usage_metadata extracted from the SDK's
ResultMessage (input, output, total — including cached input) to the
returned AIMessage so downstream handlers pick it up.

Tool callbacks now also fire through the MCP wrapper: forward the
callback list into each wrapped LangChain tool's invocation config so
StatsCallbackHandler sees on_tool_start/on_tool_end when the SDK loop
calls a tool.

Verified via direct StatsCallbackHandler round-trip on both Shape A
(ChatClaudeAgent.invoke) and Shape B (run_sdk_analyst): llm_calls,
tool_calls, tokens_in, and tokens_out all increment as expected.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Michael Yang 2026-04-14 17:32:11 -04:00
parent 39182785ce
commit 3d8341c104
3 changed files with 126 additions and 15 deletions

View File

@ -23,7 +23,11 @@ from typing import Any, Dict, List
from langchain_core.messages import AIMessage, HumanMessage
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
from tradingagents.llm_clients.claude_agent_client import (
ChatClaudeAgent,
extract_usage,
fire_llm_callbacks,
)
from tradingagents.llm_clients.mcp_tool_adapter import build_mcp_server
@ -109,7 +113,8 @@ async def _run(
lc_tools: List[Any],
server_name: str,
model: str,
) -> str:
callbacks: List[Any],
) -> tuple[str, Dict[str, int]]:
from claude_agent_sdk import (
AssistantMessage,
ClaudeAgentOptions,
@ -119,7 +124,7 @@ async def _run(
_log(f"[{server_name}] building MCP server with {len(lc_tools)} tools: "
f"{[t.name for t in lc_tools]}")
server, allowed = build_mcp_server(server_name, lc_tools)
server, allowed = build_mcp_server(server_name, lc_tools, callbacks=callbacks)
_log(f"[{server_name}] allowed_tools={allowed}")
options = ClaudeAgentOptions(
@ -140,6 +145,7 @@ async def _run(
start = time.monotonic()
text_parts: List[str] = []
final_usage: Dict[str, int] = {}
msg_count = 0
async for msg in query(prompt=user_prompt, options=options):
msg_count += 1
@ -149,11 +155,15 @@ async def _run(
for block in msg.content:
if isinstance(block, TextBlock):
text_parts.append(block.text)
sdk_usage = getattr(msg, "usage", None)
if isinstance(sdk_usage, dict) and sdk_usage:
final_usage = extract_usage(sdk_usage)
elapsed = time.monotonic() - start
_log(f"[{server_name}] query complete after {elapsed:.1f}s, "
f"{msg_count} messages, {sum(len(t) for t in text_parts)} chars")
return "\n".join(text_parts).strip()
f"{msg_count} messages, {sum(len(t) for t in text_parts)} chars, "
f"usage={final_usage}")
return "\n".join(text_parts).strip(), final_usage
def run_sdk_analyst(
@ -169,20 +179,23 @@ def run_sdk_analyst(
_log(f"=== run_sdk_analyst start: server={server_name} report_field={report_field} "
f"ticker={state.get('company_of_interest')!r} date={state.get('trade_date')!r} ===")
try:
report = asyncio.run(
report, usage = asyncio.run(
_run(
system_prompt=system_prompt,
user_prompt=user_prompt,
lc_tools=lc_tools,
server_name=server_name,
model=llm.model,
callbacks=llm.callbacks,
)
)
except Exception as e:
_log(f"[{server_name}] EXCEPTION: {type(e).__name__}: {e}")
raise
_log(f"=== run_sdk_analyst done: {report_field}={len(report)} chars ===")
_log(f"=== run_sdk_analyst done: {report_field}={len(report)} chars usage={usage} ===")
message = AIMessage(content=report, usage_metadata=usage or None)
fire_llm_callbacks(llm.callbacks, message, user_prompt)
return {
"messages": [AIMessage(content=report)],
"messages": [message],
report_field: report,
}

View File

@ -10,9 +10,11 @@ Shape B.
"""
import asyncio
from typing import Any, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from uuid import uuid4
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import Runnable
@ -69,6 +71,79 @@ def _coerce_input(input: Any) -> Tuple[Optional[str], str]:
return system_prompt, user_prompt
def extract_usage(sdk_usage: Any) -> Dict[str, int]:
    """Normalize the SDK's ``usage`` dict into LangChain's usage_metadata shape.

    Accepts either a plain dict (ResultMessage.usage) or None. Returns a dict
    with ``input_tokens``, ``output_tokens``, ``total_tokens`` keys; a
    non-dict input yields an all-zero result.
    """
    if not isinstance(sdk_usage, dict):
        return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

    def _first_truthy(*keys: str) -> int:
        # The SDK mirrors Anthropic usage shape, but key names vary across
        # versions — take the first key holding a truthy count.
        for key in keys:
            count = sdk_usage.get(key)
            if count:
                return count
        return 0

    prompt_side = _first_truthy("input_tokens", "prompt_tokens")
    completion_side = _first_truthy("output_tokens", "completion_tokens")

    # Cached input (reads and cache creation) still counts against the input
    # budget so the TUI reflects real consumption.
    for cache_key in ("cache_read_input_tokens", "cache_creation_input_tokens"):
        prompt_side += sdk_usage.get(cache_key, 0) or 0

    total = prompt_side + completion_side
    return {
        "input_tokens": int(prompt_side),
        "output_tokens": int(completion_side),
        "total_tokens": int(total),
    }
def fire_llm_callbacks(
    callbacks: List[Any],
    message: AIMessage,
    prompt_preview: str,
) -> None:
    """Manually fire on_chat_model_start + on_llm_end on the given handlers.

    ChatClaudeAgent is a plain Runnable, so LangChain does not fire chat-model
    callbacks automatically. We invoke them ourselves so stats handlers
    (StatsCallbackHandler in the CLI, etc.) see LLM calls and token usage.
    """
    if not callbacks:
        return

    run_id = uuid4()

    def _best_effort(handler: Any, hook: str, *args: Any) -> None:
        # Handler signatures vary across LangChain versions: try with run_id
        # first, retry without it on TypeError, and never let a misbehaving
        # handler break the run.
        method = getattr(handler, hook, None)
        if method is None:
            return
        try:
            method(*args, run_id=run_id)
        except TypeError:
            try:
                method(*args)
            except Exception:
                pass
        except Exception:
            pass

    serialized = {"name": "ChatClaudeAgent"}
    chat_batches = [[{"role": "user", "content": prompt_preview}]]
    for handler in callbacks:
        _best_effort(handler, "on_chat_model_start", serialized, chat_batches)

    llm_result = LLMResult(generations=[[ChatGeneration(message=message)]])
    for handler in callbacks:
        _best_effort(handler, "on_llm_end", llm_result)
class ChatClaudeAgent(Runnable):
"""LangChain-compatible Runnable that routes inference through claude-agent-sdk.
@ -78,11 +153,16 @@ class ChatClaudeAgent(Runnable):
def __init__(self, model: str, **kwargs: Any) -> None:
self.model = model
# Pull callbacks out so we can fire them manually around each invoke —
# Runnable doesn't trigger chat-model callbacks the way BaseChatModel does.
self.callbacks: List[Any] = list(kwargs.pop("callbacks", None) or [])
self.kwargs = kwargs
def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> AIMessage:
system_prompt, prompt = _coerce_input(input)
return asyncio.run(self._ainvoke(prompt, system_prompt))
message = asyncio.run(self._ainvoke(prompt, system_prompt))
fire_llm_callbacks(self.callbacks, message, prompt)
return message
async def _ainvoke(self, prompt: str, system_prompt: Optional[str]) -> AIMessage:
from claude_agent_sdk import (
@ -104,13 +184,22 @@ class ChatClaudeAgent(Runnable):
options = ClaudeAgentOptions(**options_kwargs)
text_parts: List[str] = []
final_usage: Dict[str, int] = {}
async for msg in query(prompt=prompt, options=options):
if isinstance(msg, AssistantMessage):
for block in msg.content:
if isinstance(block, TextBlock):
text_parts.append(block.text)
# The ResultMessage at the end carries cumulative usage; prefer it.
# Fall back to AssistantMessage.usage if ResultMessage omits it.
sdk_usage = getattr(msg, "usage", None)
if isinstance(sdk_usage, dict) and sdk_usage:
final_usage = extract_usage(sdk_usage)
return AIMessage(content="\n".join(text_parts))
return AIMessage(
content="\n".join(text_parts),
usage_metadata=final_usage or None,
)
def bind_tools(self, tools: Any, **kwargs: Any) -> Any:
raise NotImplementedError(

View File

@ -6,20 +6,23 @@ dict and returns {"content": [{"type": "text", "text": str}]}.
Used by the SDK-native analyst runner to let Claude Code (authenticated via a
Max/Pro subscription) call the same data tools the legacy analyst path uses.
Callbacks passed in from the graph are forwarded into each tool invocation so
that StatsCallbackHandler (and any other handler) sees on_tool_start/end.
"""
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
from claude_agent_sdk import create_sdk_mcp_server, tool
def _wrap_lc_tool(lc_tool: Any):
def _wrap_lc_tool(lc_tool: Any, callbacks: Optional[List[Any]]):
"""Wrap a single LangChain BaseTool as an SDK @tool-decorated async callable."""
schema = (
lc_tool.args_schema.model_json_schema()
if lc_tool.args_schema is not None
else {"type": "object", "properties": {}}
)
config = {"callbacks": callbacks} if callbacks else None
@tool(
name=lc_tool.name,
@ -27,7 +30,8 @@ def _wrap_lc_tool(lc_tool: Any):
input_schema=schema,
)
async def _wrapped(args: Dict[str, Any]) -> Dict[str, Any]:
result = lc_tool.invoke(args)
# Pass callbacks via config so BaseTool fires on_tool_start/on_tool_end.
result = lc_tool.invoke(args, config=config) if config else lc_tool.invoke(args)
return {"content": [{"type": "text", "text": str(result)}]}
return _wrapped
@ -36,13 +40,18 @@ def _wrap_lc_tool(lc_tool: Any):
def build_mcp_server(
    server_name: str,
    lc_tools: List[Any],
    callbacks: Optional[List[Any]] = None,
) -> Tuple[Any, List[str]]:
    """Build an in-process MCP server from LangChain tools.

    Returns the server instance and the list of fully-qualified tool names
    (``mcp__<server>__<tool>``) suitable for passing to ``allowed_tools``.

    ``callbacks`` are forwarded into each tool's LangChain config so that
    on_tool_start/on_tool_end fire on the stats handler during SDK-driven
    tool calls.
    """
    sdk_tools = [_wrap_lc_tool(lc_tool, callbacks) for lc_tool in lc_tools]
    server = create_sdk_mcp_server(name=server_name, version="1.0.0", tools=sdk_tools)
    qualified_prefix = f"mcp__{server_name}__"
    allowed = [qualified_prefix + lc_tool.name for lc_tool in lc_tools]
    return server, allowed