77 lines
2.3 KiB
Python
77 lines
2.3 KiB
Python
import threading
|
|
from typing import Any, Dict, List
|
|
|
|
from langchain_core.callbacks import BaseCallbackHandler
|
|
from langchain_core.messages import AIMessage
|
|
from langchain_core.outputs import LLMResult
|
|
|
|
|
|
class StatsCallbackHandler(BaseCallbackHandler):
|
|
"""Callback handler that tracks LLM calls, tool calls, and token usage."""
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self._lock = threading.Lock()
|
|
self.llm_calls = 0
|
|
self.tool_calls = 0
|
|
self.tokens_in = 0
|
|
self.tokens_out = 0
|
|
|
|
def on_llm_start(
|
|
self,
|
|
serialized: Dict[str, Any],
|
|
prompts: List[str],
|
|
**kwargs: Any,
|
|
) -> None:
|
|
"""Increment LLM call counter when an LLM starts."""
|
|
with self._lock:
|
|
self.llm_calls += 1
|
|
|
|
def on_chat_model_start(
|
|
self,
|
|
serialized: Dict[str, Any],
|
|
messages: List[List[Any]],
|
|
**kwargs: Any,
|
|
) -> None:
|
|
"""Increment LLM call counter when a chat model starts."""
|
|
with self._lock:
|
|
self.llm_calls += 1
|
|
|
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
|
"""Extract token usage from LLM response."""
|
|
try:
|
|
generation = response.generations[0][0]
|
|
except (IndexError, TypeError):
|
|
return
|
|
|
|
usage_metadata = None
|
|
if hasattr(generation, "message"):
|
|
message = generation.message
|
|
if isinstance(message, AIMessage) and hasattr(message, "usage_metadata"):
|
|
usage_metadata = message.usage_metadata
|
|
|
|
if usage_metadata:
|
|
with self._lock:
|
|
self.tokens_in += usage_metadata.get("input_tokens", 0)
|
|
self.tokens_out += usage_metadata.get("output_tokens", 0)
|
|
|
|
def on_tool_start(
|
|
self,
|
|
serialized: Dict[str, Any],
|
|
input_str: str,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
"""Increment tool call counter when a tool starts."""
|
|
with self._lock:
|
|
self.tool_calls += 1
|
|
|
|
def get_stats(self) -> Dict[str, Any]:
|
|
"""Return current statistics."""
|
|
with self._lock:
|
|
return {
|
|
"llm_calls": self.llm_calls,
|
|
"tool_calls": self.tool_calls,
|
|
"tokens_in": self.tokens_in,
|
|
"tokens_out": self.tokens_out,
|
|
}
|