From 0bb7ae1cd8eb1bc5d1267fe0ceb181b2b804054b Mon Sep 17 00:00:00 2001
From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com>
Date: Sat, 21 Mar 2026 22:18:51 +0000
Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=AA=20Add=20unit=20tests=20for=20Stats?=
 =?UTF-8?q?CallbackHandler?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds a comprehensive test suite for the StatsCallbackHandler in
cli/stats_handler.py, covering call counting, token usage extraction,
and thread safety.

Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>
---
 tests/cli/test_stats_handler.py | 108 ++++++++++++++++++++++++++++++++
 1 file changed, 108 insertions(+)
 create mode 100644 tests/cli/test_stats_handler.py

diff --git a/tests/cli/test_stats_handler.py b/tests/cli/test_stats_handler.py
new file mode 100644
index 00000000..8bad63fb
--- /dev/null
+++ b/tests/cli/test_stats_handler.py
@@ -0,0 +1,108 @@
+import threading
+import pytest
+from cli.stats_handler import StatsCallbackHandler
+from langchain_core.outputs import ChatGeneration, Generation, LLMResult
+from langchain_core.messages import AIMessage
+
+def test_stats_handler_initial_state():
+    handler = StatsCallbackHandler()
+    stats = handler.get_stats()
+    assert stats == {
+        "llm_calls": 0,
+        "tool_calls": 0,
+        "tokens_in": 0,
+        "tokens_out": 0,
+    }
+
+def test_stats_handler_on_llm_start():
+    handler = StatsCallbackHandler()
+    handler.on_llm_start(serialized={}, prompts=["test"])
+    assert handler.llm_calls == 1
+    assert handler.get_stats()["llm_calls"] == 1
+
+def test_stats_handler_on_chat_model_start():
+    handler = StatsCallbackHandler()
+    handler.on_chat_model_start(serialized={}, messages=[[]])
+    assert handler.llm_calls == 1
+    assert handler.get_stats()["llm_calls"] == 1
+
+def test_stats_handler_on_tool_start():
+    handler = StatsCallbackHandler()
+    handler.on_tool_start(serialized={}, input_str="test tool")
+    assert handler.tool_calls == 1
+    assert handler.get_stats()["tool_calls"] == 1
+
+def test_stats_handler_on_llm_end_with_usage():
+    handler = StatsCallbackHandler()
+
+    # UsageMetadata requires input_tokens, output_tokens AND total_tokens.
+    usage_metadata = {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}
+    message = AIMessage(content="test response")
+    message.usage_metadata = usage_metadata
+    generation = ChatGeneration(message=message)
+    response = LLMResult(generations=[[generation]])
+
+    handler.on_llm_end(response)
+
+    stats = handler.get_stats()
+    assert stats["tokens_in"] == 10
+    assert stats["tokens_out"] == 20
+
+def test_stats_handler_on_llm_end_no_usage():
+    handler = StatsCallbackHandler()
+
+    # Plain Generation has no message/usage_metadata attribute.
+    generation = Generation(text="test response")
+    response = LLMResult(generations=[[generation]])
+
+    handler.on_llm_end(response)
+
+    stats = handler.get_stats()
+    assert stats["tokens_in"] == 0
+    assert stats["tokens_out"] == 0
+
+def test_stats_handler_on_llm_end_empty_generations():
+    handler = StatsCallbackHandler()
+    response = LLMResult(generations=[[]])
+    handler.on_llm_end(response)
+
+    response_none = LLMResult(generations=[])
+    # on_llm_end indexes response.generations[0][0]; empty generations raise IndexError, which it must handle.
+    handler.on_llm_end(response_none)
+
+    assert handler.tokens_in == 0
+    assert handler.tokens_out == 0
+
+def test_stats_handler_thread_safety():
+    handler = StatsCallbackHandler()
+    num_threads = 10
+    increments_per_thread = 100
+
+    def worker():
+        for _ in range(increments_per_thread):
+            handler.on_llm_start({}, [])
+            handler.on_tool_start({}, "")
+
+            # Mock usage metadata for on_llm_end
+            usage_metadata = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
+            message = AIMessage(content="x")
+            message.usage_metadata = usage_metadata
+            generation = ChatGeneration(message=message)
+            response = LLMResult(generations=[[generation]])
+            handler.on_llm_end(response)
+
+    threads = []
+    for _ in range(num_threads):
+        t = threading.Thread(target=worker)
+        threads.append(t)
+        t.start()
+
+    for t in threads:
+        t.join()
+
+    stats = handler.get_stats()
+    expected_calls = num_threads * increments_per_thread
+    assert stats["llm_calls"] == expected_calls
+    assert stats["tool_calls"] == expected_calls
+    assert stats["tokens_in"] == expected_calls
+    assert stats["tokens_out"] == expected_calls