🧪 Add unit tests for StatsCallbackHandler

Adds a comprehensive test suite for the StatsCallbackHandler in cli/stats_handler.py,
covering call counting, token usage extraction, and thread safety.

Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>
This commit is contained in:
google-labs-jules[bot] 2026-03-21 22:18:51 +00:00
parent a7b8c996f2
commit 0bb7ae1cd8
1 changed file with 108 additions and 0 deletions

View File

@ -0,0 +1,108 @@
import threading
import pytest
from cli.stats_handler import StatsCallbackHandler
from langchain_core.outputs import LLMResult, Generation
from langchain_core.messages import AIMessage
def test_stats_handler_initial_state():
    """A freshly constructed handler reports every counter at zero."""
    handler = StatsCallbackHandler()
    expected = {
        "llm_calls": 0,
        "tool_calls": 0,
        "tokens_in": 0,
        "tokens_out": 0,
    }
    assert handler.get_stats() == expected
def test_stats_handler_on_llm_start():
    """A single on_llm_start callback increments the LLM-call counter to one."""
    h = StatsCallbackHandler()
    h.on_llm_start(serialized={}, prompts=["test"])
    # Check both the raw attribute and the snapshot returned by get_stats().
    assert (h.llm_calls, h.get_stats()["llm_calls"]) == (1, 1)
def test_stats_handler_on_chat_model_start():
    """Chat-model starts count as LLM calls, same as plain LLM starts."""
    h = StatsCallbackHandler()
    h.on_chat_model_start(serialized={}, messages=[[]])
    # Check both the raw attribute and the snapshot returned by get_stats().
    assert (h.llm_calls, h.get_stats()["llm_calls"]) == (1, 1)
def test_stats_handler_on_tool_start():
    """A single on_tool_start callback increments the tool-call counter to one."""
    h = StatsCallbackHandler()
    h.on_tool_start(serialized={}, input_str="test tool")
    # Check both the raw attribute and the snapshot returned by get_stats().
    assert (h.tool_calls, h.get_stats()["tool_calls"]) == (1, 1)
def test_stats_handler_on_llm_end_with_usage():
    """on_llm_end accumulates token counts from the message's usage_metadata."""
    h = StatsCallbackHandler()
    reply = AIMessage(content="test response")
    # Attach usage metadata after construction, where the handler looks for it.
    reply.usage_metadata = {"input_tokens": 10, "output_tokens": 20}
    # NOTE(review): a `message` kwarg is passed to Generation here; in
    # langchain_core the message-bearing variant is ChatGeneration — confirm
    # this construction is what the handler actually receives in production.
    gen = Generation(message=reply, text="test response")
    h.on_llm_end(LLMResult(generations=[[gen]]))
    snapshot = h.get_stats()
    assert snapshot["tokens_in"] == 10
    assert snapshot["tokens_out"] == 20
def test_stats_handler_on_llm_end_no_usage():
    """A generation with no message/usage_metadata leaves token counts at zero."""
    h = StatsCallbackHandler()
    h.on_llm_end(LLMResult(generations=[[Generation(text="test response")]]))
    snapshot = h.get_stats()
    assert snapshot["tokens_in"] == 0
    assert snapshot["tokens_out"] == 0
def test_stats_handler_on_llm_end_empty_generations():
    """Empty generation lists neither raise nor alter the token counters."""
    h = StatsCallbackHandler()
    # Inner list empty: a batch exists but holds no generations.
    h.on_llm_end(LLMResult(generations=[[]]))
    # Outer list empty: generations[0][0] raises IndexError, which the
    # handler is expected to swallow internally.
    h.on_llm_end(LLMResult(generations=[]))
    assert (h.tokens_in, h.tokens_out) == (0, 0)
def test_stats_handler_thread_safety():
    """Concurrent callbacks from many threads must not lose counter updates."""
    h = StatsCallbackHandler()
    num_threads = 10
    increments_per_thread = 100

    def hammer():
        # Each iteration fires one LLM start, one tool start, and one LLM end
        # carrying exactly one input token and one output token.
        for _ in range(increments_per_thread):
            h.on_llm_start({}, [])
            h.on_tool_start({}, "")
            reply = AIMessage(content="x")
            reply.usage_metadata = {"input_tokens": 1, "output_tokens": 1}
            result = LLMResult(generations=[[Generation(message=reply, text="x")]])
            h.on_llm_end(result)

    workers = [threading.Thread(target=hammer) for _ in range(num_threads)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()

    total = num_threads * increments_per_thread
    snapshot = h.get_stats()
    for key in ("llm_calls", "tool_calls", "tokens_in", "tokens_out"):
        assert snapshot[key] == total