diff --git a/tests/cli/test_stats_handler.py b/tests/cli/test_stats_handler.py
index 8bad63fb..c2d4d7f8 100644
--- a/tests/cli/test_stats_handler.py
+++ b/tests/cli/test_stats_handler.py
@@ -1,8 +1,9 @@
 import threading
 import pytest
 from cli.stats_handler import StatsCallbackHandler
-from langchain_core.outputs import LLMResult, Generation
+from langchain_core.outputs import LLMResult, Generation, ChatGeneration
 from langchain_core.messages import AIMessage
+from langchain_core.messages.ai import UsageMetadata

 def test_stats_handler_initial_state():
     handler = StatsCallbackHandler()
@@ -35,11 +36,10 @@ def test_stats_handler_on_tool_start():
 def test_stats_handler_on_llm_end_with_usage():
     handler = StatsCallbackHandler()

-    # Mock usage metadata
-    usage_metadata = {"input_tokens": 10, "output_tokens": 20}
-    message = AIMessage(content="test response")
-    message.usage_metadata = usage_metadata
-    generation = Generation(message=message, text="test response")
+    # ChatGeneration wraps chat messages; Generation (plain text) has no .message attr.
+    usage_metadata = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
+    message = AIMessage(content="test response", usage_metadata=usage_metadata)
+    generation = ChatGeneration(message=message)
     response = LLMResult(generations=[[generation]])

     handler.on_llm_end(response)
@@ -83,11 +83,10 @@ def test_stats_handler_thread_safety():
         handler.on_llm_start({}, [])
         handler.on_tool_start({}, "")

-        # Mock usage metadata for on_llm_end
-        usage_metadata = {"input_tokens": 1, "output_tokens": 1}
-        message = AIMessage(content="x")
-        message.usage_metadata = usage_metadata
-        generation = Generation(message=message, text="x")
+        # ChatGeneration wraps chat messages with usage_metadata
+        usage_metadata = UsageMetadata(input_tokens=1, output_tokens=1, total_tokens=2)
+        message = AIMessage(content="x", usage_metadata=usage_metadata)
+        generation = ChatGeneration(message=message)
         response = LLMResult(generations=[[generation]])

         handler.on_llm_end(response)