diff --git a/cli/chunk_ingest.py b/cli/chunk_ingest.py
index d81bc6f7..8da86cde 100644
--- a/cli/chunk_ingest.py
+++ b/cli/chunk_ingest.py
@@ -1,13 +1,65 @@
+import json
+
+
+def _tool_call_signature(tool_call):
+    """Return a hashable ``(name, serialized_args)`` signature for a tool call.
+
+    Accepts either a mapping with ``name``/``args`` keys or an object exposing
+    ``name``/``args`` attributes. Missing fields are treated as ``None`` so a
+    malformed tool call cannot abort ingestion of the whole chunk.
+    """
+    if isinstance(tool_call, dict):
+        name = tool_call.get("name")
+        args = tool_call.get("args")
+    else:
+        name = getattr(tool_call, "name", None)
+        args = getattr(tool_call, "args", None)
+    try:
+        # sort_keys makes the signature independent of dict insertion order;
+        # default=str keeps non-JSON values usable instead of raising.
+        serialized_args = json.dumps(args, sort_keys=True, default=str)
+    except (TypeError, ValueError):
+        # Mixed-type keys cannot be sorted and circular structures cannot be
+        # encoded; fall back to repr() rather than failing the ingest.
+        serialized_args = repr(args)
+    return (name, serialized_args)
+
+
+def _message_fingerprint(message, msg_type, content):
+    """Build a hashable identity for a message that carries no ``id``.
+
+    The fingerprint combines the message class name, the classified type, the
+    whitespace-stripped content and the signatures of any attached tool calls.
+    """
+    tool_calls = tuple(
+        _tool_call_signature(tool_call)
+        for tool_call in getattr(message, "tool_calls", None) or []
+    )
+    return (
+        message.__class__.__name__,
+        msg_type,
+        content.strip() if isinstance(content, str) else str(content),
+        tool_calls,
+    )
+
+
 def ingest_chunk_messages(message_buffer, chunk, classify_message_type) -> None:
     """Ingest all newly seen messages from a graph stream chunk."""
     for message in chunk.get("messages", []):
         msg_id = getattr(message, "id", None)
         if msg_id is not None:
             if msg_id in message_buffer._processed_message_ids:
                 continue
             message_buffer._processed_message_ids.add(msg_id)
 
         msg_type, content = classify_message_type(message)
+        if msg_id is None:
+            # No id to key on: fall back to a content-based fingerprint.
+            fingerprint = _message_fingerprint(message, msg_type, content)
+            if fingerprint in message_buffer._processed_message_fingerprints:
+                continue
+            message_buffer._processed_message_fingerprints.add(fingerprint)
+
         if content and content.strip():
             message_buffer.add_message(msg_type, content)
 
diff --git a/cli/main.py b/cli/main.py
index 2cbb70a1..8810f8df 100644
--- a/cli/main.py
+++ b/cli/main.py
@@ -82,6 +82,7 @@ class MessageBuffer:
         self.report_sections = {}
         self.selected_analysts = []
         self._processed_message_ids = set()
+        self._processed_message_fingerprints = set()
 
     def init_for_analysis(self, selected_analysts):
         """Initialize agent status and report sections based on selected analysts.
@@ -117,6 +118,7 @@ class MessageBuffer:
         self.messages.clear()
         self.tool_calls.clear()
         self._processed_message_ids.clear()
+        self._processed_message_fingerprints.clear()
 
     def get_completed_reports_count(self):
         """Count reports that are finalized (their finalizing agent is completed).
diff --git a/tests/test_cli_chunk_message_ingest.py b/tests/test_cli_chunk_message_ingest.py
index b4efeef9..a9c4ca90 100644
--- a/tests/test_cli_chunk_message_ingest.py
+++ b/tests/test_cli_chunk_message_ingest.py
@@ -17,6 +17,7 @@ class FakeMessage:
 class FakeMessageBuffer:
     def __init__(self):
         self._processed_message_ids = set()
+        self._processed_message_fingerprints = set()
         self.messages = []
         self.tool_calls = []
 
@@ -62,3 +63,28 @@ def test_ingest_chunk_messages_skips_duplicate_message_ids():
 
     assert len(message_buffer.messages) == 1
     assert len(message_buffer.tool_calls) == 1
+
+
+def test_ingest_chunk_messages_skips_duplicate_messages_without_ids():
+    """Re-ingesting an id-less message must not record it a second time."""
+    buffer = FakeMessageBuffer()
+    repeated = FakeMessage(None, "same", [{"name": "tool_a", "args": {"x": 1}}])
+    chunk = {"messages": [repeated]}
+
+    for _ in range(2):
+        ingest_chunk_messages(buffer, chunk, fake_classifier)
+
+    assert len(buffer.messages) == 1
+    assert len(buffer.tool_calls) == 1
+
+
+def test_ingest_chunk_messages_keeps_distinct_messages_without_ids():
+    """Id-less messages that differ in content and tool calls are all kept."""
+    buffer = FakeMessageBuffer()
+    first = FakeMessage(None, "first", [{"name": "tool_a", "args": {"x": 1}}])
+    second = FakeMessage(None, "second", [{"name": "tool_b", "args": {"y": 2}}])
+
+    ingest_chunk_messages(buffer, {"messages": [first, second]}, fake_classifier)
+
+    assert [content for _, content in buffer.messages] == ["first", "second"]
+    assert [name for name, _ in buffer.tool_calls] == ["tool_a", "tool_b"]