diff --git a/cli/chunk_ingest.py b/cli/chunk_ingest.py new file mode 100644 index 00000000..d81bc6f7 --- /dev/null +++ b/cli/chunk_ingest.py @@ -0,0 +1,19 @@ +def ingest_chunk_messages(message_buffer, chunk, classify_message_type) -> None: + """Ingest all newly seen messages from a graph stream chunk.""" + for message in chunk.get("messages", []): + msg_id = getattr(message, "id", None) + if msg_id is not None: + if msg_id in message_buffer._processed_message_ids: + continue + message_buffer._processed_message_ids.add(msg_id) + + msg_type, content = classify_message_type(message) + if content and content.strip(): + message_buffer.add_message(msg_type, content) + + if hasattr(message, "tool_calls") and message.tool_calls: + for tool_call in message.tool_calls: + if isinstance(tool_call, dict): + message_buffer.add_tool_call(tool_call["name"], tool_call["args"]) + else: + message_buffer.add_tool_call(tool_call.name, tool_call.args) diff --git a/cli/main.py b/cli/main.py index 33d110fb..2cbb70a1 100644 --- a/cli/main.py +++ b/cli/main.py @@ -28,6 +28,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG from cli.models import AnalystType from cli.utils import * +from cli.chunk_ingest import ingest_chunk_messages from cli.announcements import fetch_announcements, display_announcements from cli.stats_handler import StatsCallbackHandler @@ -926,6 +927,7 @@ def format_tool_args(args, max_length=80) -> str: return result[:max_length - 3] + "..." return result + def run_analysis(): # First get all user selections selections = get_user_selections() @@ -1053,24 +1055,7 @@ def run_analysis(): # Stream the analysis trace = [] for chunk in graph.graph.stream(init_agent_state, **args): - # Process all messages in chunk, deduplicating by message ID - for message in chunk.get("messages", []): - msg_id = getattr(message, "id", None) - if msg_id is not None: - if msg_id in message_buffer._processed_message_ids: - continue - message_buffer._processed_message_ids.add(msg_id) - - msg_type, content = classify_message_type(message) - if content and content.strip(): - message_buffer.add_message(msg_type, content) - - if hasattr(message, "tool_calls") and message.tool_calls: - for tool_call in message.tool_calls: - if isinstance(tool_call, dict): - message_buffer.add_tool_call(tool_call["name"], tool_call["args"]) - else: - message_buffer.add_tool_call(tool_call.name, tool_call.args) + ingest_chunk_messages(message_buffer, chunk, classify_message_type) # Update analyst statuses based on report state (runs on every chunk) update_analyst_statuses(message_buffer, chunk) diff --git a/tests/test_cli_chunk_message_ingest.py b/tests/test_cli_chunk_message_ingest.py new file mode 100644 index 00000000..b4efeef9 --- /dev/null +++ b/tests/test_cli_chunk_message_ingest.py @@ -0,0 +1,64 @@ +from cli.chunk_ingest import ingest_chunk_messages + + +class FakeToolCall: + def __init__(self, name, args): + self.name = name + self.args = args + + +class FakeMessage: + def __init__(self, msg_id, content, tool_calls=None): + self.id = msg_id + self.content = content + self.tool_calls = tool_calls or [] + + +class FakeMessageBuffer: + def __init__(self): + self._processed_message_ids = set() + self.messages = [] + self.tool_calls = [] + + def add_message(self, message_type, content): + self.messages.append((message_type, content)) + + def add_tool_call(self, tool_name, args): + self.tool_calls.append((tool_name, args)) + + +def fake_classifier(message): + return "Agent", message.content + + +def test_ingest_chunk_messages_records_all_tool_calls(): + message_buffer = FakeMessageBuffer() + chunk = { + "messages": [ + FakeMessage( + "m1", + "first", + [ + {"name": "tool_a", "args": {"x": 1}}, + FakeToolCall("tool_b", {"y": 2}), + ], + ), + FakeMessage("m2", "second", [FakeToolCall("tool_c", {"z": 3})]), + ] + } + + ingest_chunk_messages(message_buffer, chunk, fake_classifier) + + tool_names = [name for name, _ in message_buffer.tool_calls] + assert tool_names == ["tool_a", "tool_b", "tool_c"] + + +def test_ingest_chunk_messages_skips_duplicate_message_ids(): + message_buffer = FakeMessageBuffer() + chunk = {"messages": [FakeMessage("m1", "same", [{"name": "tool_a", "args": {}}])]} + + ingest_chunk_messages(message_buffer, chunk, fake_classifier) + ingest_chunk_messages(message_buffer, chunk, fake_classifier) + + assert len(message_buffer.messages) == 1 + assert len(message_buffer.tool_calls) == 1