109 lines
3.5 KiB
Python
109 lines
3.5 KiB
Python
import threading
|
|
import pytest
|
|
from cli.stats_handler import StatsCallbackHandler
|
|
from langchain_core.outputs import LLMResult, Generation
|
|
from langchain_core.messages import AIMessage
|
|
|
|
def test_stats_handler_initial_state():
    """A freshly constructed handler reports every counter at zero."""
    expected = {
        "llm_calls": 0,
        "tool_calls": 0,
        "tokens_in": 0,
        "tokens_out": 0,
    }
    assert StatsCallbackHandler().get_stats() == expected
|
|
|
|
def test_stats_handler_on_llm_start():
    """on_llm_start bumps the LLM-call counter exactly once."""
    h = StatsCallbackHandler()
    h.on_llm_start(serialized={}, prompts=["test"])
    # Counter is visible both as an attribute and via the stats snapshot.
    assert h.llm_calls == 1
    assert h.get_stats()["llm_calls"] == 1
|
|
|
|
def test_stats_handler_on_chat_model_start():
    """Chat-model starts count toward the same llm_calls counter."""
    h = StatsCallbackHandler()
    h.on_chat_model_start(serialized={}, messages=[[]])
    assert h.llm_calls == 1
    assert h.get_stats()["llm_calls"] == 1
|
|
|
|
def test_stats_handler_on_tool_start():
    """on_tool_start bumps the tool-call counter exactly once."""
    h = StatsCallbackHandler()
    h.on_tool_start(serialized={}, input_str="test tool")
    assert h.tool_calls == 1
    assert h.get_stats()["tool_calls"] == 1
|
|
|
|
def test_stats_handler_on_llm_end_with_usage():
    """Token usage attached to the generation's message is accumulated."""
    h = StatsCallbackHandler()

    msg = AIMessage(content="test response")
    msg.usage_metadata = {"input_tokens": 10, "output_tokens": 20}
    # NOTE(review): passing message= to Generation relies on the model
    # tolerating the extra field; ChatGeneration is the class that formally
    # carries a message — confirm against the handler's expectations.
    gen = Generation(message=msg, text="test response")

    h.on_llm_end(LLMResult(generations=[[gen]]))

    counts = h.get_stats()
    assert counts["tokens_in"] == 10
    assert counts["tokens_out"] == 20
|
|
|
|
def test_stats_handler_on_llm_end_no_usage():
    """A generation carrying no message/usage_metadata leaves tokens at zero."""
    h = StatsCallbackHandler()

    bare_gen = Generation(text="test response")
    h.on_llm_end(LLMResult(generations=[[bare_gen]]))

    counts = h.get_stats()
    assert counts["tokens_in"] == 0
    assert counts["tokens_out"] == 0
|
|
|
|
def test_stats_handler_on_llm_end_empty_generations():
    """Empty generation lists are tolerated without touching token counters."""
    h = StatsCallbackHandler()

    # Inner list empty: nothing to read, handler should no-op.
    h.on_llm_end(LLMResult(generations=[[]]))

    # Outer list empty: response.generations[0][0] raises IndexError,
    # which on_llm_end is expected to swallow.
    h.on_llm_end(LLMResult(generations=[]))

    assert h.tokens_in == 0
    assert h.tokens_out == 0
|
|
|
|
def test_stats_handler_thread_safety():
    """Concurrent callbacks from many threads must not drop any counts."""
    h = StatsCallbackHandler()
    n_threads = 10
    iterations = 100

    def hammer():
        # Fire the full callback trio repeatedly from this thread.
        for _ in range(iterations):
            h.on_llm_start({}, [])
            h.on_tool_start({}, "")

            msg = AIMessage(content="x")
            msg.usage_metadata = {"input_tokens": 1, "output_tokens": 1}
            gen = Generation(message=msg, text="x")
            h.on_llm_end(LLMResult(generations=[[gen]]))

    workers = [threading.Thread(target=hammer) for _ in range(n_threads)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()

    total = n_threads * iterations
    counts = h.get_stats()
    assert counts["llm_calls"] == total
    assert counts["tool_calls"] == total
    assert counts["tokens_in"] == total
    assert counts["tokens_out"] == total
|