# TradingAgents/tests/api/test_token_handler.py
# (file-viewer metadata from the original listing: 75 lines, 2.5 KiB, Python)

import threading
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage
from langchain_core.outputs import LLMResult, ChatGeneration
from api.callbacks.token_handler import TokenCallbackHandler
def _make_response(input_tokens: int, output_tokens: int) -> LLMResult:
    """Return an ``LLMResult`` holding one generation with token usage attached.

    The result wraps a single ``ChatGeneration`` whose ``AIMessage`` has its
    ``usage_metadata`` set to the given counts, which is the input the tests
    in this module feed to ``on_llm_end``.

    Args:
        input_tokens: Value stored under the ``"input_tokens"`` key.
        output_tokens: Value stored under the ``"output_tokens"`` key.
    """
    message = AIMessage(content="ok")
    # Usage is assigned after construction, before the message is wrapped.
    message.usage_metadata = {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
    }
    generation = ChatGeneration(message=message)
    return LLMResult(generations=[[generation]])
def test_snapshot_and_reset_returns_delta():
    """One recorded call is expected back from ``snapshot_and_reset`` as in/out counts."""
    tracker = TokenCallbackHandler()
    tracker.on_llm_end(_make_response(100, 40))

    assert tracker.snapshot_and_reset() == {"in": 100, "out": 40}
def test_snapshot_and_reset_zeroes_counters():
    """A second snapshot taken right after the first is expected to be all zeros."""
    tracker = TokenCallbackHandler()
    tracker.on_llm_end(_make_response(100, 40))

    # The first snapshot's value is irrelevant here; only its reset effect is checked.
    tracker.snapshot_and_reset()

    assert tracker.snapshot_and_reset() == {"in": 0, "out": 0}
def test_multiple_llm_calls_accumulate():
    """Counts from successive ``on_llm_end`` calls are expected to be summed."""
    tracker = TokenCallbackHandler()
    for prompt_tokens, completion_tokens in ((100, 40), (200, 60)):
        tracker.on_llm_end(_make_response(prompt_tokens, completion_tokens))

    totals = tracker.snapshot_and_reset()

    assert totals == {"in": 300, "out": 100}
def test_concurrent_on_llm_end_does_not_corrupt():
    """Twenty threads reporting usage at once must not lose any counts.

    Every worker waits on a shared barrier before calling ``on_llm_end``, so
    the calls are released together rather than running one after another as
    the threads are started. Exceptions raised inside a worker are collected
    and asserted on in the main thread; otherwise they would only be printed
    by the threading excepthook.
    """
    handler = TokenCallbackHandler()
    worker_count = 20
    start_gate = threading.Barrier(worker_count)
    worker_errors = []

    def _report_usage(response):
        try:
            # The timeout keeps the test from hanging if a worker never arrives.
            start_gate.wait(timeout=5)
            handler.on_llm_end(response)
        except Exception as exc:  # re-surfaced by the assertion below
            worker_errors.append(exc)

    threads = [
        threading.Thread(target=_report_usage, args=(_make_response(10, 5),))
        for _ in range(worker_count)
    ]
    for worker in threads:
        worker.start()
    for worker in threads:
        worker.join()

    assert not worker_errors, f"worker threads raised: {worker_errors!r}"
    result = handler.snapshot_and_reset()
    assert result == {"in": 200, "out": 100}
def test_missing_usage_metadata_does_not_crash():
    """A message without assigned usage metadata is expected to count as zero tokens."""
    tracker = TokenCallbackHandler()
    # usage_metadata is deliberately never assigned on this message.
    bare_message = AIMessage(content="no metadata")
    response = LLMResult(generations=[[ChatGeneration(message=bare_message)]])

    tracker.on_llm_end(response)  # must not raise

    assert tracker.snapshot_and_reset() == {"in": 0, "out": 0}
def test_empty_outer_generations_does_not_crash():
    """An ``LLMResult`` with no generation lists is expected to be tolerated and count nothing."""
    tracker = TokenCallbackHandler()
    no_generations = LLMResult(generations=[])

    tracker.on_llm_end(no_generations)  # must not raise (e.g. IndexError)

    assert tracker.snapshot_and_reset() == {"in": 0, "out": 0}
def test_empty_inner_generations_does_not_crash():
    """An ``LLMResult`` whose only generation list is empty is expected to be tolerated."""
    tracker = TokenCallbackHandler()
    empty_inner = LLMResult(generations=[[]])

    tracker.on_llm_end(empty_inner)  # must not raise (e.g. IndexError)

    assert tracker.snapshot_and_reset() == {"in": 0, "out": 0}