diff --git a/cli/chunk_ingest.py b/cli/chunk_ingest.py
new file mode 100644
index 00000000..8da86cde
--- /dev/null
+++ b/cli/chunk_ingest.py
@@ -0,0 +1,47 @@
+import json
+
+
+def _tool_call_signature(tool_call):
+    if isinstance(tool_call, dict):
+        name = tool_call["name"]
+        args = tool_call["args"]
+    else:
+        name = tool_call.name
+        args = tool_call.args
+    return (name, json.dumps(args, sort_keys=True, default=str))
+
+
+def _message_fingerprint(message, msg_type, content):
+    tool_calls = tuple(_tool_call_signature(tool_call) for tool_call in getattr(message, "tool_calls", []) or [])
+    return (
+        message.__class__.__name__,
+        msg_type,
+        content.strip() if isinstance(content, str) else str(content),
+        tool_calls,
+    )
+
+
+def ingest_chunk_messages(message_buffer, chunk, classify_message_type) -> None:
+    """Ingest all newly seen messages from a graph stream chunk."""
+    for message in chunk.get("messages", []):
+        msg_type, content = classify_message_type(message)
+        msg_id = getattr(message, "id", None)
+        if msg_id is not None:
+            if msg_id in message_buffer._processed_message_ids:
+                continue
+            message_buffer._processed_message_ids.add(msg_id)
+        else:
+            fingerprint = _message_fingerprint(message, msg_type, content)
+            if fingerprint in message_buffer._processed_message_fingerprints:
+                continue
+            message_buffer._processed_message_fingerprints.add(fingerprint)
+
+        if content and content.strip():
+            message_buffer.add_message(msg_type, content)
+
+        if hasattr(message, "tool_calls") and message.tool_calls:
+            for tool_call in message.tool_calls:
+                if isinstance(tool_call, dict):
+                    message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
+                else:
+                    message_buffer.add_tool_call(tool_call.name, tool_call.args)
diff --git a/cli/main.py b/cli/main.py
index 33d110fb..8810f8df 100644
--- a/cli/main.py
+++ b/cli/main.py
@@ -28,6 +28,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
 from tradingagents.default_config import DEFAULT_CONFIG
 from cli.models import AnalystType
 from cli.utils import *
+from cli.chunk_ingest import ingest_chunk_messages
 from cli.announcements import fetch_announcements, display_announcements
 from cli.stats_handler import StatsCallbackHandler
 
@@ -81,6 +82,7 @@ class MessageBuffer:
         self.report_sections = {}
         self.selected_analysts = []
         self._processed_message_ids = set()
+        self._processed_message_fingerprints = set()
 
     def init_for_analysis(self, selected_analysts):
         """Initialize agent status and report sections based on selected analysts.
@@ -116,6 +118,7 @@ class MessageBuffer:
         self.messages.clear()
         self.tool_calls.clear()
         self._processed_message_ids.clear()
+        self._processed_message_fingerprints.clear()
 
     def get_completed_reports_count(self):
         """Count reports that are finalized (their finalizing agent is completed).
@@ -926,6 +929,7 @@ def format_tool_args(args, max_length=80) -> str:
         return result[:max_length - 3] + "..."
     return result
 
+
 def run_analysis():
     # First get all user selections
     selections = get_user_selections()
@@ -1053,24 +1057,7 @@ def run_analysis():
     # Stream the analysis
     trace = []
     for chunk in graph.graph.stream(init_agent_state, **args):
-        # Process all messages in chunk, deduplicating by message ID
-        for message in chunk.get("messages", []):
-            msg_id = getattr(message, "id", None)
-            if msg_id is not None:
-                if msg_id in message_buffer._processed_message_ids:
-                    continue
-                message_buffer._processed_message_ids.add(msg_id)
-
-            msg_type, content = classify_message_type(message)
-            if content and content.strip():
-                message_buffer.add_message(msg_type, content)
-
-            if hasattr(message, "tool_calls") and message.tool_calls:
-                for tool_call in message.tool_calls:
-                    if isinstance(tool_call, dict):
-                        message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
-                    else:
-                        message_buffer.add_tool_call(tool_call.name, tool_call.args)
+        ingest_chunk_messages(message_buffer, chunk, classify_message_type)
 
     # Update analyst statuses based on report state (runs on every chunk)
update_analyst_statuses(message_buffer, chunk) diff --git a/tests/test_cli_chunk_message_ingest.py b/tests/test_cli_chunk_message_ingest.py new file mode 100644 index 00000000..a9c4ca90 --- /dev/null +++ b/tests/test_cli_chunk_message_ingest.py @@ -0,0 +1,91 @@ +from cli.chunk_ingest import ingest_chunk_messages + + +class FakeToolCall: + def __init__(self, name, args): + self.name = name + self.args = args + + +class FakeMessage: + def __init__(self, msg_id, content, tool_calls=None): + self.id = msg_id + self.content = content + self.tool_calls = tool_calls or [] + + +class FakeMessageBuffer: + def __init__(self): + self._processed_message_ids = set() + self._processed_message_fingerprints = set() + self.messages = [] + self.tool_calls = [] + + def add_message(self, message_type, content): + self.messages.append((message_type, content)) + + def add_tool_call(self, tool_name, args): + self.tool_calls.append((tool_name, args)) + + +def fake_classifier(message): + return "Agent", message.content + + +def test_ingest_chunk_messages_records_all_tool_calls(): + message_buffer = FakeMessageBuffer() + chunk = { + "messages": [ + FakeMessage( + "m1", + "first", + [ + {"name": "tool_a", "args": {"x": 1}}, + FakeToolCall("tool_b", {"y": 2}), + ], + ), + FakeMessage("m2", "second", [FakeToolCall("tool_c", {"z": 3})]), + ] + } + + ingest_chunk_messages(message_buffer, chunk, fake_classifier) + + tool_names = [name for name, _ in message_buffer.tool_calls] + assert tool_names == ["tool_a", "tool_b", "tool_c"] + + +def test_ingest_chunk_messages_skips_duplicate_message_ids(): + message_buffer = FakeMessageBuffer() + chunk = {"messages": [FakeMessage("m1", "same", [{"name": "tool_a", "args": {}}])]} + + ingest_chunk_messages(message_buffer, chunk, fake_classifier) + ingest_chunk_messages(message_buffer, chunk, fake_classifier) + + assert len(message_buffer.messages) == 1 + assert len(message_buffer.tool_calls) == 1 + + +def 
test_ingest_chunk_messages_skips_duplicate_messages_without_ids(): + message_buffer = FakeMessageBuffer() + chunk = {"messages": [FakeMessage(None, "same", [{"name": "tool_a", "args": {"x": 1}}])]} + + ingest_chunk_messages(message_buffer, chunk, fake_classifier) + ingest_chunk_messages(message_buffer, chunk, fake_classifier) + + assert len(message_buffer.messages) == 1 + assert len(message_buffer.tool_calls) == 1 + + +def test_ingest_chunk_messages_keeps_distinct_messages_without_ids(): + message_buffer = FakeMessageBuffer() + chunk = { + "messages": [ + FakeMessage(None, "first", [{"name": "tool_a", "args": {"x": 1}}]), + FakeMessage(None, "second", [{"name": "tool_b", "args": {"y": 2}}]), + ] + } + + ingest_chunk_messages(message_buffer, chunk, fake_classifier) + + assert [content for _, content in message_buffer.messages] == ["first", "second"] + assert [name for name, _ in message_buffer.tool_calls] == ["tool_a", "tool_b"]