From 3d8341c10462e4503071af7a64cb9903e1174014 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 14 Apr 2026 17:32:11 -0400 Subject: [PATCH] fix: report LLM calls, tool calls, and token usage for claude_agent MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ChatClaudeAgent is a plain Runnable rather than a BaseChatModel, so LangChain's callback system never fired on_chat_model_start / on_llm_end for it — leaving the CLI TUI stuck on "LLM: 0" and "Tokens: --" during runs. Pop callbacks out of the LLM kwargs, invoke them manually around each SDK call, and attach usage_metadata extracted from the SDK's ResultMessage (input, output, total — including cached input) to the returned AIMessage so downstream handlers pick it up. Tool callbacks now also fire through the MCP wrapper: forward the callback list into each wrapped LangChain tool's invocation config so StatsCallbackHandler sees on_tool_start/on_tool_end when the SDK loop calls a tool. Verified via direct StatsCallbackHandler round-trip on both Shape A (ChatClaudeAgent.invoke) and Shape B (run_sdk_analyst): llm_calls, tool_calls, tokens_in, and tokens_out all increment as expected. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../agents/analysts/_claude_agent_runner.py | 29 ++++-- .../llm_clients/claude_agent_client.py | 95 ++++++++++++++++++- tradingagents/llm_clients/mcp_tool_adapter.py | 17 +++- 3 files changed, 126 insertions(+), 15 deletions(-) diff --git a/tradingagents/agents/analysts/_claude_agent_runner.py b/tradingagents/agents/analysts/_claude_agent_runner.py index 9223548f..3e513d7c 100644 --- a/tradingagents/agents/analysts/_claude_agent_runner.py +++ b/tradingagents/agents/analysts/_claude_agent_runner.py @@ -23,7 +23,11 @@ from typing import Any, Dict, List from langchain_core.messages import AIMessage, HumanMessage -from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent +from tradingagents.llm_clients.claude_agent_client import ( + ChatClaudeAgent, + extract_usage, + fire_llm_callbacks, +) from tradingagents.llm_clients.mcp_tool_adapter import build_mcp_server @@ -109,7 +113,8 @@ async def _run( lc_tools: List[Any], server_name: str, model: str, -) -> str: + callbacks: List[Any], +) -> tuple[str, Dict[str, int]]: from claude_agent_sdk import ( AssistantMessage, ClaudeAgentOptions, @@ -119,7 +124,7 @@ async def _run( _log(f"[{server_name}] building MCP server with {len(lc_tools)} tools: " f"{[t.name for t in lc_tools]}") - server, allowed = build_mcp_server(server_name, lc_tools) + server, allowed = build_mcp_server(server_name, lc_tools, callbacks=callbacks) _log(f"[{server_name}] allowed_tools={allowed}") options = ClaudeAgentOptions( @@ -140,6 +145,7 @@ async def _run( start = time.monotonic() text_parts: List[str] = [] + final_usage: Dict[str, int] = {} msg_count = 0 async for msg in query(prompt=user_prompt, options=options): msg_count += 1 @@ -149,11 +155,15 @@ async def _run( for block in msg.content: if isinstance(block, TextBlock): text_parts.append(block.text) + sdk_usage = getattr(msg, "usage", None) + if isinstance(sdk_usage, dict) and sdk_usage: + final_usage = extract_usage(sdk_usage) elapsed = time.monotonic() - start _log(f"[{server_name}] query complete after {elapsed:.1f}s, " - f"{msg_count} messages, {sum(len(t) for t in text_parts)} chars") - return "\n".join(text_parts).strip() + f"{msg_count} messages, {sum(len(t) for t in text_parts)} chars, " + f"usage={final_usage}") + return "\n".join(text_parts).strip(), final_usage def run_sdk_analyst( @@ -169,20 +179,23 @@ def run_sdk_analyst( _log(f"=== run_sdk_analyst start: server={server_name} report_field={report_field} " f"ticker={state.get('company_of_interest')!r} date={state.get('trade_date')!r} ===") try: - report = asyncio.run( + report, usage = asyncio.run( _run( system_prompt=system_prompt, user_prompt=user_prompt, lc_tools=lc_tools, server_name=server_name, model=llm.model, + callbacks=llm.callbacks, ) ) except Exception as e: _log(f"[{server_name}] EXCEPTION: {type(e).__name__}: {e}") raise - _log(f"=== run_sdk_analyst done: {report_field}={len(report)} chars ===") + _log(f"=== run_sdk_analyst done: {report_field}={len(report)} chars usage={usage} ===") + message = AIMessage(content=report, usage_metadata=usage or None) + fire_llm_callbacks(llm.callbacks, message, user_prompt) return { - "messages": [AIMessage(content=report)], + "messages": [message], report_field: report, } diff --git a/tradingagents/llm_clients/claude_agent_client.py b/tradingagents/llm_clients/claude_agent_client.py index ccdcf83b..dd0a6e9c 100644 --- a/tradingagents/llm_clients/claude_agent_client.py +++ b/tradingagents/llm_clients/claude_agent_client.py @@ -10,9 +10,11 @@ Shape B. """ import asyncio -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple +from uuid import uuid4 from langchain_core.messages import AIMessage, BaseMessage, SystemMessage +from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.prompt_values import PromptValue from langchain_core.runnables import Runnable @@ -69,6 +71,79 @@ def _coerce_input(input: Any) -> Tuple[Optional[str], str]: return system_prompt, user_prompt +def extract_usage(sdk_usage: Any) -> Dict[str, int]: + """Normalize the SDK's `usage` dict into LangChain's usage_metadata shape. + + Accepts either a plain dict (ResultMessage.usage) or None. Returns a dict + with ``input_tokens``, ``output_tokens``, ``total_tokens`` keys. + """ + if not isinstance(sdk_usage, dict): + return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + # The SDK mirrors Anthropic usage shape. Be defensive across versions. + input_tokens = ( + sdk_usage.get("input_tokens") + or sdk_usage.get("prompt_tokens") + or 0 + ) + output_tokens = ( + sdk_usage.get("output_tokens") + or sdk_usage.get("completion_tokens") + or 0 + ) + # Count cached input against the input budget too so the TUI reflects it. + input_tokens += sdk_usage.get("cache_read_input_tokens", 0) or 0 + input_tokens += sdk_usage.get("cache_creation_input_tokens", 0) or 0 + total = input_tokens + output_tokens + return { + "input_tokens": int(input_tokens), + "output_tokens": int(output_tokens), + "total_tokens": int(total), + } + + +def fire_llm_callbacks( + callbacks: List[Any], + message: AIMessage, + prompt_preview: str, +) -> None: + """Manually fire on_chat_model_start + on_llm_end on the given handlers. + + ChatClaudeAgent is a plain Runnable, so LangChain does not fire chat-model + callbacks automatically. We invoke them ourselves so stats handlers + (StatsCallbackHandler in the CLI, etc.) see LLM calls and token usage. + """ + if not callbacks: + return + run_id = uuid4() + serialized = {"name": "ChatClaudeAgent"} + messages = [[{"role": "user", "content": prompt_preview}]] + for cb in callbacks: + if hasattr(cb, "on_chat_model_start"): + try: + cb.on_chat_model_start(serialized, messages, run_id=run_id) + except TypeError: + # Some handlers don't accept run_id; best-effort. + try: + cb.on_chat_model_start(serialized, messages) + except Exception: + pass + except Exception: + pass + + result = LLMResult(generations=[[ChatGeneration(message=message)]]) + for cb in callbacks: + if hasattr(cb, "on_llm_end"): + try: + cb.on_llm_end(result, run_id=run_id) + except TypeError: + try: + cb.on_llm_end(result) + except Exception: + pass + except Exception: + pass + + class ChatClaudeAgent(Runnable): """LangChain-compatible Runnable that routes inference through claude-agent-sdk. @@ -78,11 +153,16 @@ class ChatClaudeAgent(Runnable): def __init__(self, model: str, **kwargs: Any) -> None: self.model = model + # Pull callbacks out so we can fire them manually around each invoke — + # Runnable doesn't trigger chat-model callbacks the way BaseChatModel does. + self.callbacks: List[Any] = list(kwargs.pop("callbacks", None) or []) self.kwargs = kwargs def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> AIMessage: system_prompt, prompt = _coerce_input(input) - return asyncio.run(self._ainvoke(prompt, system_prompt)) + message = asyncio.run(self._ainvoke(prompt, system_prompt)) + fire_llm_callbacks(self.callbacks, message, prompt) + return message async def _ainvoke(self, prompt: str, system_prompt: Optional[str]) -> AIMessage: from claude_agent_sdk import ( @@ -104,13 +184,22 @@ class ChatClaudeAgent(Runnable): options = ClaudeAgentOptions(**options_kwargs) text_parts: List[str] = [] + final_usage: Dict[str, int] = {} async for msg in query(prompt=prompt, options=options): if isinstance(msg, AssistantMessage): for block in msg.content: if isinstance(block, TextBlock): text_parts.append(block.text) + # The ResultMessage at the end carries cumulative usage; prefer it. + # Fall back to AssistantMessage.usage if ResultMessage omits it. + sdk_usage = getattr(msg, "usage", None) + if isinstance(sdk_usage, dict) and sdk_usage: + final_usage = extract_usage(sdk_usage) - return AIMessage(content="\n".join(text_parts)) + return AIMessage( + content="\n".join(text_parts), + usage_metadata=final_usage or None, + ) def bind_tools(self, tools: Any, **kwargs: Any) -> Any: raise NotImplementedError( diff --git a/tradingagents/llm_clients/mcp_tool_adapter.py b/tradingagents/llm_clients/mcp_tool_adapter.py index bd74f1fd..33665456 100644 --- a/tradingagents/llm_clients/mcp_tool_adapter.py +++ b/tradingagents/llm_clients/mcp_tool_adapter.py @@ -6,20 +6,23 @@ dict and returns {"content": [{"type": "text", "text": str}]}. Used by the SDK-native analyst runner to let Claude Code (authenticated via a Max/Pro subscription) call the same data tools the legacy analyst path uses. +Callbacks passed in from the graph are forwarded into each tool invocation so +that StatsCallbackHandler (and any other handler) sees on_tool_start/end. """ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from claude_agent_sdk import create_sdk_mcp_server, tool -def _wrap_lc_tool(lc_tool: Any): +def _wrap_lc_tool(lc_tool: Any, callbacks: Optional[List[Any]]): """Wrap a single LangChain BaseTool as an SDK @tool-decorated async callable.""" schema = ( lc_tool.args_schema.model_json_schema() if lc_tool.args_schema is not None else {"type": "object", "properties": {}} ) + config = {"callbacks": callbacks} if callbacks else None @tool( name=lc_tool.name, @@ -27,7 +30,8 @@ def _wrap_lc_tool(lc_tool: Any): input_schema=schema, ) async def _wrapped(args: Dict[str, Any]) -> Dict[str, Any]: - result = lc_tool.invoke(args) + # Pass callbacks via config so BaseTool fires on_tool_start/on_tool_end. + result = lc_tool.invoke(args, config=config) if config else lc_tool.invoke(args) return {"content": [{"type": "text", "text": str(result)}]} return _wrapped @@ -36,13 +40,18 @@ def _wrap_lc_tool(lc_tool: Any): def build_mcp_server( server_name: str, lc_tools: List[Any], + callbacks: Optional[List[Any]] = None, ) -> Tuple[Any, List[str]]: """Build an in-process MCP server from LangChain tools. Returns the server instance and the list of fully-qualified tool names (``mcp____``) suitable for passing to ``allowed_tools``. + + ``callbacks`` are forwarded into each tool's LangChain config so that + on_tool_start/on_tool_end fire on the stats handler during SDK-driven + tool calls. """ - wrapped = [_wrap_lc_tool(t) for t in lc_tools] + wrapped = [_wrap_lc_tool(t, callbacks) for t in lc_tools] server = create_sdk_mcp_server(name=server_name, version="1.0.0", tools=wrapped) allowed = [f"mcp__{server_name}__{t.name}" for t in lc_tools] return server, allowed