fix: report LLM calls, tool calls, and token usage for claude_agent
ChatClaudeAgent is a plain Runnable rather than a BaseChatModel, so LangChain's callback system never fired on_chat_model_start / on_llm_end for it — leaving the CLI TUI stuck on "LLM: 0" and "Tokens: --" during runs. Pop callbacks out of the LLM kwargs, invoke them manually around each SDK call, and attach usage_metadata extracted from the SDK's ResultMessage (input, output, total — including cached input) to the returned AIMessage so downstream handlers pick it up. Tool callbacks now also fire through the MCP wrapper: forward the callback list into each wrapped LangChain tool's invocation config so StatsCallbackHandler sees on_tool_start/on_tool_end when the SDK loop calls a tool. Verified via direct StatsCallbackHandler round-trip on both Shape A (ChatClaudeAgent.invoke) and Shape B (run_sdk_analyst): llm_calls, tool_calls, tokens_in, and tokens_out all increment as expected. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
39182785ce
commit
3d8341c104
|
|
@ -23,7 +23,11 @@ from typing import Any, Dict, List
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
from tradingagents.llm_clients.claude_agent_client import ChatClaudeAgent
|
from tradingagents.llm_clients.claude_agent_client import (
|
||||||
|
ChatClaudeAgent,
|
||||||
|
extract_usage,
|
||||||
|
fire_llm_callbacks,
|
||||||
|
)
|
||||||
from tradingagents.llm_clients.mcp_tool_adapter import build_mcp_server
|
from tradingagents.llm_clients.mcp_tool_adapter import build_mcp_server
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -109,7 +113,8 @@ async def _run(
|
||||||
lc_tools: List[Any],
|
lc_tools: List[Any],
|
||||||
server_name: str,
|
server_name: str,
|
||||||
model: str,
|
model: str,
|
||||||
) -> str:
|
callbacks: List[Any],
|
||||||
|
) -> tuple[str, Dict[str, int]]:
|
||||||
from claude_agent_sdk import (
|
from claude_agent_sdk import (
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
ClaudeAgentOptions,
|
ClaudeAgentOptions,
|
||||||
|
|
@ -119,7 +124,7 @@ async def _run(
|
||||||
|
|
||||||
_log(f"[{server_name}] building MCP server with {len(lc_tools)} tools: "
|
_log(f"[{server_name}] building MCP server with {len(lc_tools)} tools: "
|
||||||
f"{[t.name for t in lc_tools]}")
|
f"{[t.name for t in lc_tools]}")
|
||||||
server, allowed = build_mcp_server(server_name, lc_tools)
|
server, allowed = build_mcp_server(server_name, lc_tools, callbacks=callbacks)
|
||||||
_log(f"[{server_name}] allowed_tools={allowed}")
|
_log(f"[{server_name}] allowed_tools={allowed}")
|
||||||
|
|
||||||
options = ClaudeAgentOptions(
|
options = ClaudeAgentOptions(
|
||||||
|
|
@ -140,6 +145,7 @@ async def _run(
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
|
|
||||||
text_parts: List[str] = []
|
text_parts: List[str] = []
|
||||||
|
final_usage: Dict[str, int] = {}
|
||||||
msg_count = 0
|
msg_count = 0
|
||||||
async for msg in query(prompt=user_prompt, options=options):
|
async for msg in query(prompt=user_prompt, options=options):
|
||||||
msg_count += 1
|
msg_count += 1
|
||||||
|
|
@ -149,11 +155,15 @@ async def _run(
|
||||||
for block in msg.content:
|
for block in msg.content:
|
||||||
if isinstance(block, TextBlock):
|
if isinstance(block, TextBlock):
|
||||||
text_parts.append(block.text)
|
text_parts.append(block.text)
|
||||||
|
sdk_usage = getattr(msg, "usage", None)
|
||||||
|
if isinstance(sdk_usage, dict) and sdk_usage:
|
||||||
|
final_usage = extract_usage(sdk_usage)
|
||||||
|
|
||||||
elapsed = time.monotonic() - start
|
elapsed = time.monotonic() - start
|
||||||
_log(f"[{server_name}] query complete after {elapsed:.1f}s, "
|
_log(f"[{server_name}] query complete after {elapsed:.1f}s, "
|
||||||
f"{msg_count} messages, {sum(len(t) for t in text_parts)} chars")
|
f"{msg_count} messages, {sum(len(t) for t in text_parts)} chars, "
|
||||||
return "\n".join(text_parts).strip()
|
f"usage={final_usage}")
|
||||||
|
return "\n".join(text_parts).strip(), final_usage
|
||||||
|
|
||||||
|
|
||||||
def run_sdk_analyst(
|
def run_sdk_analyst(
|
||||||
|
|
@ -169,20 +179,23 @@ def run_sdk_analyst(
|
||||||
_log(f"=== run_sdk_analyst start: server={server_name} report_field={report_field} "
|
_log(f"=== run_sdk_analyst start: server={server_name} report_field={report_field} "
|
||||||
f"ticker={state.get('company_of_interest')!r} date={state.get('trade_date')!r} ===")
|
f"ticker={state.get('company_of_interest')!r} date={state.get('trade_date')!r} ===")
|
||||||
try:
|
try:
|
||||||
report = asyncio.run(
|
report, usage = asyncio.run(
|
||||||
_run(
|
_run(
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
user_prompt=user_prompt,
|
user_prompt=user_prompt,
|
||||||
lc_tools=lc_tools,
|
lc_tools=lc_tools,
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
model=llm.model,
|
model=llm.model,
|
||||||
|
callbacks=llm.callbacks,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_log(f"[{server_name}] EXCEPTION: {type(e).__name__}: {e}")
|
_log(f"[{server_name}] EXCEPTION: {type(e).__name__}: {e}")
|
||||||
raise
|
raise
|
||||||
_log(f"=== run_sdk_analyst done: {report_field}={len(report)} chars ===")
|
_log(f"=== run_sdk_analyst done: {report_field}={len(report)} chars usage={usage} ===")
|
||||||
|
message = AIMessage(content=report, usage_metadata=usage or None)
|
||||||
|
fire_llm_callbacks(llm.callbacks, message, user_prompt)
|
||||||
return {
|
return {
|
||||||
"messages": [AIMessage(content=report)],
|
"messages": [message],
|
||||||
report_field: report,
|
report_field: report,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,11 @@ Shape B.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
|
||||||
|
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||||
from langchain_core.prompt_values import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
from langchain_core.runnables import Runnable
|
from langchain_core.runnables import Runnable
|
||||||
|
|
||||||
|
|
@ -69,6 +71,79 @@ def _coerce_input(input: Any) -> Tuple[Optional[str], str]:
|
||||||
return system_prompt, user_prompt
|
return system_prompt, user_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def extract_usage(sdk_usage: Any) -> Dict[str, int]:
|
||||||
|
"""Normalize the SDK's `usage` dict into LangChain's usage_metadata shape.
|
||||||
|
|
||||||
|
Accepts either a plain dict (ResultMessage.usage) or None. Returns a dict
|
||||||
|
with ``input_tokens``, ``output_tokens``, ``total_tokens`` keys.
|
||||||
|
"""
|
||||||
|
if not isinstance(sdk_usage, dict):
|
||||||
|
return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||||
|
# The SDK mirrors Anthropic usage shape. Be defensive across versions.
|
||||||
|
input_tokens = (
|
||||||
|
sdk_usage.get("input_tokens")
|
||||||
|
or sdk_usage.get("prompt_tokens")
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
output_tokens = (
|
||||||
|
sdk_usage.get("output_tokens")
|
||||||
|
or sdk_usage.get("completion_tokens")
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
# Count cached input against the input budget too so the TUI reflects it.
|
||||||
|
input_tokens += sdk_usage.get("cache_read_input_tokens", 0) or 0
|
||||||
|
input_tokens += sdk_usage.get("cache_creation_input_tokens", 0) or 0
|
||||||
|
total = input_tokens + output_tokens
|
||||||
|
return {
|
||||||
|
"input_tokens": int(input_tokens),
|
||||||
|
"output_tokens": int(output_tokens),
|
||||||
|
"total_tokens": int(total),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def fire_llm_callbacks(
|
||||||
|
callbacks: List[Any],
|
||||||
|
message: AIMessage,
|
||||||
|
prompt_preview: str,
|
||||||
|
) -> None:
|
||||||
|
"""Manually fire on_chat_model_start + on_llm_end on the given handlers.
|
||||||
|
|
||||||
|
ChatClaudeAgent is a plain Runnable, so LangChain does not fire chat-model
|
||||||
|
callbacks automatically. We invoke them ourselves so stats handlers
|
||||||
|
(StatsCallbackHandler in the CLI, etc.) see LLM calls and token usage.
|
||||||
|
"""
|
||||||
|
if not callbacks:
|
||||||
|
return
|
||||||
|
run_id = uuid4()
|
||||||
|
serialized = {"name": "ChatClaudeAgent"}
|
||||||
|
messages = [[{"role": "user", "content": prompt_preview}]]
|
||||||
|
for cb in callbacks:
|
||||||
|
if hasattr(cb, "on_chat_model_start"):
|
||||||
|
try:
|
||||||
|
cb.on_chat_model_start(serialized, messages, run_id=run_id)
|
||||||
|
except TypeError:
|
||||||
|
# Some handlers don't accept run_id; best-effort.
|
||||||
|
try:
|
||||||
|
cb.on_chat_model_start(serialized, messages)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
result = LLMResult(generations=[[ChatGeneration(message=message)]])
|
||||||
|
for cb in callbacks:
|
||||||
|
if hasattr(cb, "on_llm_end"):
|
||||||
|
try:
|
||||||
|
cb.on_llm_end(result, run_id=run_id)
|
||||||
|
except TypeError:
|
||||||
|
try:
|
||||||
|
cb.on_llm_end(result)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ChatClaudeAgent(Runnable):
|
class ChatClaudeAgent(Runnable):
|
||||||
"""LangChain-compatible Runnable that routes inference through claude-agent-sdk.
|
"""LangChain-compatible Runnable that routes inference through claude-agent-sdk.
|
||||||
|
|
||||||
|
|
@ -78,11 +153,16 @@ class ChatClaudeAgent(Runnable):
|
||||||
|
|
||||||
def __init__(self, model: str, **kwargs: Any) -> None:
|
def __init__(self, model: str, **kwargs: Any) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
|
# Pull callbacks out so we can fire them manually around each invoke —
|
||||||
|
# Runnable doesn't trigger chat-model callbacks the way BaseChatModel does.
|
||||||
|
self.callbacks: List[Any] = list(kwargs.pop("callbacks", None) or [])
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> AIMessage:
|
def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> AIMessage:
|
||||||
system_prompt, prompt = _coerce_input(input)
|
system_prompt, prompt = _coerce_input(input)
|
||||||
return asyncio.run(self._ainvoke(prompt, system_prompt))
|
message = asyncio.run(self._ainvoke(prompt, system_prompt))
|
||||||
|
fire_llm_callbacks(self.callbacks, message, prompt)
|
||||||
|
return message
|
||||||
|
|
||||||
async def _ainvoke(self, prompt: str, system_prompt: Optional[str]) -> AIMessage:
|
async def _ainvoke(self, prompt: str, system_prompt: Optional[str]) -> AIMessage:
|
||||||
from claude_agent_sdk import (
|
from claude_agent_sdk import (
|
||||||
|
|
@ -104,13 +184,22 @@ class ChatClaudeAgent(Runnable):
|
||||||
options = ClaudeAgentOptions(**options_kwargs)
|
options = ClaudeAgentOptions(**options_kwargs)
|
||||||
|
|
||||||
text_parts: List[str] = []
|
text_parts: List[str] = []
|
||||||
|
final_usage: Dict[str, int] = {}
|
||||||
async for msg in query(prompt=prompt, options=options):
|
async for msg in query(prompt=prompt, options=options):
|
||||||
if isinstance(msg, AssistantMessage):
|
if isinstance(msg, AssistantMessage):
|
||||||
for block in msg.content:
|
for block in msg.content:
|
||||||
if isinstance(block, TextBlock):
|
if isinstance(block, TextBlock):
|
||||||
text_parts.append(block.text)
|
text_parts.append(block.text)
|
||||||
|
# The ResultMessage at the end carries cumulative usage; prefer it.
|
||||||
|
# Fall back to AssistantMessage.usage if ResultMessage omits it.
|
||||||
|
sdk_usage = getattr(msg, "usage", None)
|
||||||
|
if isinstance(sdk_usage, dict) and sdk_usage:
|
||||||
|
final_usage = extract_usage(sdk_usage)
|
||||||
|
|
||||||
return AIMessage(content="\n".join(text_parts))
|
return AIMessage(
|
||||||
|
content="\n".join(text_parts),
|
||||||
|
usage_metadata=final_usage or None,
|
||||||
|
)
|
||||||
|
|
||||||
def bind_tools(self, tools: Any, **kwargs: Any) -> Any:
|
def bind_tools(self, tools: Any, **kwargs: Any) -> Any:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|
|
||||||
|
|
@ -6,20 +6,23 @@ dict and returns {"content": [{"type": "text", "text": str}]}.
|
||||||
|
|
||||||
Used by the SDK-native analyst runner to let Claude Code (authenticated via a
|
Used by the SDK-native analyst runner to let Claude Code (authenticated via a
|
||||||
Max/Pro subscription) call the same data tools the legacy analyst path uses.
|
Max/Pro subscription) call the same data tools the legacy analyst path uses.
|
||||||
|
Callbacks passed in from the graph are forwarded into each tool invocation so
|
||||||
|
that StatsCallbackHandler (and any other handler) sees on_tool_start/end.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||||
|
|
||||||
|
|
||||||
def _wrap_lc_tool(lc_tool: Any):
|
def _wrap_lc_tool(lc_tool: Any, callbacks: Optional[List[Any]]):
|
||||||
"""Wrap a single LangChain BaseTool as an SDK @tool-decorated async callable."""
|
"""Wrap a single LangChain BaseTool as an SDK @tool-decorated async callable."""
|
||||||
schema = (
|
schema = (
|
||||||
lc_tool.args_schema.model_json_schema()
|
lc_tool.args_schema.model_json_schema()
|
||||||
if lc_tool.args_schema is not None
|
if lc_tool.args_schema is not None
|
||||||
else {"type": "object", "properties": {}}
|
else {"type": "object", "properties": {}}
|
||||||
)
|
)
|
||||||
|
config = {"callbacks": callbacks} if callbacks else None
|
||||||
|
|
||||||
@tool(
|
@tool(
|
||||||
name=lc_tool.name,
|
name=lc_tool.name,
|
||||||
|
|
@ -27,7 +30,8 @@ def _wrap_lc_tool(lc_tool: Any):
|
||||||
input_schema=schema,
|
input_schema=schema,
|
||||||
)
|
)
|
||||||
async def _wrapped(args: Dict[str, Any]) -> Dict[str, Any]:
|
async def _wrapped(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
result = lc_tool.invoke(args)
|
# Pass callbacks via config so BaseTool fires on_tool_start/on_tool_end.
|
||||||
|
result = lc_tool.invoke(args, config=config) if config else lc_tool.invoke(args)
|
||||||
return {"content": [{"type": "text", "text": str(result)}]}
|
return {"content": [{"type": "text", "text": str(result)}]}
|
||||||
|
|
||||||
return _wrapped
|
return _wrapped
|
||||||
|
|
@ -36,13 +40,18 @@ def _wrap_lc_tool(lc_tool: Any):
|
||||||
def build_mcp_server(
|
def build_mcp_server(
|
||||||
server_name: str,
|
server_name: str,
|
||||||
lc_tools: List[Any],
|
lc_tools: List[Any],
|
||||||
|
callbacks: Optional[List[Any]] = None,
|
||||||
) -> Tuple[Any, List[str]]:
|
) -> Tuple[Any, List[str]]:
|
||||||
"""Build an in-process MCP server from LangChain tools.
|
"""Build an in-process MCP server from LangChain tools.
|
||||||
|
|
||||||
Returns the server instance and the list of fully-qualified tool names
|
Returns the server instance and the list of fully-qualified tool names
|
||||||
(``mcp__<server>__<tool>``) suitable for passing to ``allowed_tools``.
|
(``mcp__<server>__<tool>``) suitable for passing to ``allowed_tools``.
|
||||||
|
|
||||||
|
``callbacks`` are forwarded into each tool's LangChain config so that
|
||||||
|
on_tool_start/on_tool_end fire on the stats handler during SDK-driven
|
||||||
|
tool calls.
|
||||||
"""
|
"""
|
||||||
wrapped = [_wrap_lc_tool(t) for t in lc_tools]
|
wrapped = [_wrap_lc_tool(t, callbacks) for t in lc_tools]
|
||||||
server = create_sdk_mcp_server(name=server_name, version="1.0.0", tools=wrapped)
|
server = create_sdk_mcp_server(name=server_name, version="1.0.0", tools=wrapped)
|
||||||
allowed = [f"mcp__{server_name}__{t.name}" for t in lc_tools]
|
allowed = [f"mcp__{server_name}__{t.name}" for t in lc_tools]
|
||||||
return server, allowed
|
return server, allowed
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue