75 lines
2.5 KiB
Python
75 lines
2.5 KiB
Python
import threading
|
|
from unittest.mock import MagicMock
|
|
from langchain_core.messages import AIMessage
|
|
from langchain_core.outputs import LLMResult, ChatGeneration
|
|
from api.callbacks.token_handler import TokenCallbackHandler
|
|
|
|
|
|
def _make_response(input_tokens: int, output_tokens: int) -> LLMResult:
    """Return an ``LLMResult`` holding one chat generation with token usage.

    The result wraps a single ``AIMessage`` whose ``usage_metadata`` carries
    the given counts; the tests in this module pass it to ``on_llm_end``.

    Args:
        input_tokens: Value stored under the ``"input_tokens"`` key.
        output_tokens: Value stored under the ``"output_tokens"`` key.
    """
    message = AIMessage(content="ok")
    # Set after construction, exactly as the tests have always done, so the
    # message is built the same way regardless of constructor validation.
    message.usage_metadata = {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
    }
    return LLMResult(generations=[[ChatGeneration(message=message)]])
|
|
|
|
|
|
def test_snapshot_and_reset_returns_delta():
    """One recorded LLM call shows up unchanged in the next snapshot."""
    tracker = TokenCallbackHandler()
    tracker.on_llm_end(_make_response(100, 40))

    assert tracker.snapshot_and_reset() == {"in": 100, "out": 40}
|
|
|
|
|
|
def test_snapshot_and_reset_zeroes_counters():
    """A second snapshot taken right after the first must report zeros."""
    tracker = TokenCallbackHandler()
    tracker.on_llm_end(_make_response(100, 40))

    # The first snapshot's value is irrelevant here; only its reset effect is
    # under test.
    tracker.snapshot_and_reset()

    after_reset = tracker.snapshot_and_reset()
    assert after_reset == {"in": 0, "out": 0}
|
|
|
|
|
|
def test_multiple_llm_calls_accumulate():
    """Counts from consecutive LLM calls are summed until the next snapshot."""
    tracker = TokenCallbackHandler()

    # Two calls: 100 + 200 input tokens, 40 + 60 output tokens.
    for prompt_tokens, completion_tokens in ((100, 40), (200, 60)):
        tracker.on_llm_end(_make_response(prompt_tokens, completion_tokens))

    totals = tracker.snapshot_and_reset()
    assert totals == {"in": 300, "out": 100}
|
|
|
|
|
|
def test_concurrent_on_llm_end_does_not_corrupt():
    """Twenty threads each report one call; the totals must add up exactly."""
    tracker = TokenCallbackHandler()
    worker_count = 20

    # Each worker gets its own response object carrying 10 input / 5 output
    # tokens.
    workers = [
        threading.Thread(
            target=tracker.on_llm_end,
            args=(_make_response(10, 5),),
        )
        for _ in range(worker_count)
    ]

    # Launch every worker before waiting on any of them so their calls can
    # overlap.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    totals = tracker.snapshot_and_reset()
    assert totals == {"in": 10 * worker_count, "out": 5 * worker_count}
|
|
|
|
|
|
def test_missing_usage_metadata_does_not_crash():
    """A message that never had usage metadata set is tolerated and counts as zero."""
    tracker = TokenCallbackHandler()

    # usage_metadata is deliberately never assigned on this message.
    bare_message = AIMessage(content="no metadata")
    llm_result = LLMResult(generations=[[ChatGeneration(message=bare_message)]])

    tracker.on_llm_end(llm_result)  # must complete without raising

    assert tracker.snapshot_and_reset() == {"in": 0, "out": 0}
|
|
|
|
|
|
def test_empty_outer_generations_does_not_crash():
    """An ``LLMResult`` with no generation lists at all is handled gracefully."""
    tracker = TokenCallbackHandler()
    no_generations = LLMResult(generations=[])

    # Must not raise (e.g. an IndexError from indexing into the empty list).
    tracker.on_llm_end(no_generations)

    assert tracker.snapshot_and_reset() == {"in": 0, "out": 0}
|
|
|
|
|
|
def test_empty_inner_generations_does_not_crash():
    """An ``LLMResult`` whose only generation list is empty is handled gracefully."""
    tracker = TokenCallbackHandler()
    empty_inner = LLMResult(generations=[[]])

    # Must not raise (e.g. an IndexError from indexing into the empty inner
    # list).
    tracker.on_llm_end(empty_inner)

    assert tracker.snapshot_and_reset() == {"in": 0, "out": 0}
|