Merge pull request #78 from aguzererler/test-stats-handler-2915727675949458763
This commit is contained in:
commit
0b2bd82422
|
|
@ -0,0 +1,108 @@
|
|||
import threading
|
||||
import pytest
|
||||
from cli.stats_handler import StatsCallbackHandler
|
||||
from langchain_core.outputs import LLMResult, Generation
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
def test_stats_handler_initial_state():
    """A freshly constructed handler reports every counter as zero."""
    handler = StatsCallbackHandler()
    expected = {
        "llm_calls": 0,
        "tool_calls": 0,
        "tokens_in": 0,
        "tokens_out": 0,
    }
    assert handler.get_stats() == expected
|
||||
|
||||
def test_stats_handler_on_llm_start():
    """One on_llm_start callback increments llm_calls exactly once."""
    handler = StatsCallbackHandler()

    handler.on_llm_start(serialized={}, prompts=["test"])

    # Both the raw attribute and the stats snapshot must agree.
    assert handler.llm_calls == 1
    assert handler.get_stats()["llm_calls"] == 1
|
||||
|
||||
def test_stats_handler_on_chat_model_start():
    """Chat-model starts count toward llm_calls just like plain LLM starts."""
    handler = StatsCallbackHandler()

    handler.on_chat_model_start(serialized={}, messages=[[]])

    snapshot = handler.get_stats()
    assert handler.llm_calls == 1
    assert snapshot["llm_calls"] == 1
|
||||
|
||||
def test_stats_handler_on_tool_start():
    """One on_tool_start callback increments tool_calls exactly once."""
    handler = StatsCallbackHandler()

    handler.on_tool_start(serialized={}, input_str="test tool")

    snapshot = handler.get_stats()
    assert handler.tool_calls == 1
    assert snapshot["tool_calls"] == 1
|
||||
|
||||
def test_stats_handler_on_llm_end_with_usage():
    """on_llm_end accumulates token counts from the message's usage_metadata."""
    handler = StatsCallbackHandler()

    # Build a response whose first generation carries usage metadata.
    # NOTE(review): Generation(message=...) mirrors the handler's access
    # pattern; ChatGeneration is the conventional carrier — confirm.
    reply = AIMessage(content="test response")
    reply.usage_metadata = {"input_tokens": 10, "output_tokens": 20}
    result = LLMResult(
        generations=[[Generation(message=reply, text="test response")]]
    )

    handler.on_llm_end(result)

    snapshot = handler.get_stats()
    assert snapshot["tokens_in"] == 10
    assert snapshot["tokens_out"] == 20
|
||||
|
||||
def test_stats_handler_on_llm_end_no_usage():
    """A generation without usage metadata leaves token counters untouched."""
    handler = StatsCallbackHandler()

    # Plain Generation: no message attribute, hence no usage_metadata.
    result = LLMResult(generations=[[Generation(text="test response")]])
    handler.on_llm_end(result)

    snapshot = handler.get_stats()
    assert snapshot["tokens_in"] == 0
    assert snapshot["tokens_out"] == 0
|
||||
|
||||
def test_stats_handler_on_llm_end_empty_generations():
    """on_llm_end tolerates empty generation lists without raising."""
    handler = StatsCallbackHandler()

    # Empty inner list: response.generations[0][0] raises IndexError,
    # which the handler is expected to swallow.
    handler.on_llm_end(LLMResult(generations=[[]]))

    # Empty outer list: same IndexError path, same expectation.
    handler.on_llm_end(LLMResult(generations=[]))

    assert handler.tokens_in == 0
    assert handler.tokens_out == 0
|
||||
|
||||
def test_stats_handler_thread_safety():
    """Concurrent callbacks from many threads must not lose any counts."""
    handler = StatsCallbackHandler()
    n_workers = 10
    n_iterations = 100

    def hammer():
        # Each iteration fires one of every counted callback, so every
        # counter should end at n_workers * n_iterations.
        for _ in range(n_iterations):
            handler.on_llm_start({}, [])
            handler.on_tool_start({}, "")

            reply = AIMessage(content="x")
            reply.usage_metadata = {"input_tokens": 1, "output_tokens": 1}
            result = LLMResult(
                generations=[[Generation(message=reply, text="x")]]
            )
            handler.on_llm_end(result)

    workers = [threading.Thread(target=hammer) for _ in range(n_workers)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    total = n_workers * n_iterations
    snapshot = handler.get_stats()
    assert snapshot["llm_calls"] == total
    assert snapshot["tool_calls"] == total
    assert snapshot["tokens_in"] == total
    assert snapshot["tokens_out"] == total
|
||||
Loading…
Reference in New Issue