# TradingAgents/tests/agents/utils/test_agent_utils.py
#
# 182 lines
# 6.4 KiB
# Python

from unittest.mock import Mock
from langchain_core.messages import HumanMessage, RemoveMessage
from tradingagents.agents.utils.agent_utils import create_msg_delete
class TestCreateMsgDelete:
    """Test suite for the ``create_msg_delete`` factory.

    The tests pin the contract of the callable the factory returns: given a
    state dict holding a ``"messages"`` list, it returns a dict whose
    ``"messages"`` list contains one ``RemoveMessage`` per input message
    followed by a ``HumanMessage`` placeholder with content ``"Continue"``.
    """

    @staticmethod
    def _make_messages(*msg_ids):
        """Return one ``Mock(spec=HumanMessage)`` per id, in the given order.

        Not collected by pytest (name does not start with ``test``); it only
        removes the mock-construction boilerplate shared by the tests below.
        """
        messages = []
        for msg_id in msg_ids:
            message = Mock(spec=HumanMessage)
            message.id = msg_id
            messages.append(message)
        return messages

    def test_create_msg_delete_returns_callable(self):
        """The factory returns something that can be called."""
        delete_func = create_msg_delete()

        assert callable(delete_func)

    def test_delete_messages_removes_all_messages(self):
        """Every existing message gets a removal op, then the placeholder."""
        state = {"messages": self._make_messages("msg_1", "msg_2", "msg_3")}

        result = create_msg_delete()(state)

        assert "messages" in result
        messages = result["messages"]

        # One removal operation per original message.
        removal_count = sum(1 for msg in messages if isinstance(msg, RemoveMessage))
        assert removal_count == 3

        # The removals come first and nothing else is emitted besides the
        # trailing placeholder (same layout the single-message and
        # large-list tests assert).
        assert len(messages) == 4
        assert all(isinstance(msg, RemoveMessage) for msg in messages[:-1])

        # Last message is the placeholder HumanMessage.
        assert isinstance(messages[-1], HumanMessage)
        assert messages[-1].content == "Continue"

    def test_delete_messages_empty_state(self):
        """An empty message list yields only the placeholder."""
        state = {"messages": []}

        result = create_msg_delete()(state)

        assert len(result["messages"]) == 1
        placeholder = result["messages"][0]
        assert isinstance(placeholder, HumanMessage)
        assert placeholder.content == "Continue"

    def test_delete_messages_single_message(self):
        """A single message yields one removal followed by the placeholder."""
        state = {"messages": self._make_messages("single_msg")}

        result = create_msg_delete()(state)

        assert len(result["messages"]) == 2  # 1 removal + 1 placeholder
        assert isinstance(result["messages"][0], RemoveMessage)
        assert isinstance(result["messages"][1], HumanMessage)

    def test_delete_messages_preserves_message_ids(self):
        """Removal operations carry the ids of the messages they remove."""
        msg_ids = ["id_1", "id_2", "id_3", "id_4"]
        state = {"messages": self._make_messages(*msg_ids)}

        result = create_msg_delete()(state)

        removal_ids = [
            msg.id for msg in result["messages"] if isinstance(msg, RemoveMessage)
        ]
        # Every original id must be targeted by some removal operation;
        # collecting the misses gives a readable failure message.
        missing_ids = [msg_id for msg_id in msg_ids if msg_id not in removal_ids]
        assert not missing_ids

    def test_delete_messages_anthropic_compatibility(self):
        """The trailing placeholder is a ``HumanMessage`` saying "Continue".

        Per the original author's note, Anthropic requires at least one
        message in the conversation, which is why the placeholder exists.
        """
        state = {"messages": self._make_messages("test_msg")}

        result = create_msg_delete()(state)

        placeholder = result["messages"][-1]
        assert isinstance(placeholder, HumanMessage)
        assert placeholder.content == "Continue"

    def test_delete_messages_large_message_list(self):
        """100 input messages yield 100 removals plus the placeholder."""
        state = {
            "messages": self._make_messages(*(f"msg_{i}" for i in range(100)))
        }

        result = create_msg_delete()(state)

        assert len(result["messages"]) == 101  # 100 removals + 1 placeholder
        removal_count = sum(
            1 for msg in result["messages"] if isinstance(msg, RemoveMessage)
        )
        assert removal_count == 100

    def test_delete_messages_multiple_calls(self):
        """Separately created delete functions work independently."""
        mock_msg1, mock_msg2 = self._make_messages("msg_1", "msg_2")
        state1 = {"messages": [mock_msg1]}
        state2 = {"messages": [mock_msg1, mock_msg2]}

        delete_func1 = create_msg_delete()
        delete_func2 = create_msg_delete()
        result1 = delete_func1(state1)
        result2 = delete_func2(state2)

        assert len(result1["messages"]) == 2  # 1 removal + placeholder
        assert len(result2["messages"]) == 3  # 2 removals + placeholder

    def test_delete_messages_state_immutability(self):
        """Calling the delete function leaves the input state untouched."""
        (mock_msg,) = self._make_messages("test_id")
        original_state = {"messages": [mock_msg]}
        original_msg_count = len(original_state["messages"])

        # Return value is irrelevant here; only the effect on the input matters.
        create_msg_delete()(original_state)

        assert len(original_state["messages"]) == original_msg_count
        assert original_state["messages"][0] is mock_msg

    def test_delete_messages_return_structure(self):
        """The result is a dict with a list under the ``"messages"`` key."""
        state = {"messages": self._make_messages("test_msg")}

        result = create_msg_delete()(state)

        assert isinstance(result, dict)
        assert "messages" in result
        assert isinstance(result["messages"], list)