From fa4d01c23acef4882fd74dd5be75dd3c7a4bc5f7 Mon Sep 17 00:00:00 2001 From: Yijia-Xiao Date: Mon, 13 Apr 2026 07:21:33 +0000 Subject: [PATCH] fix: process all chunk messages for tool call logging, harden memory score normalization (#534, #531) --- cli/main.py | 40 +++++++++++++--------------- tradingagents/agents/utils/memory.py | 2 +- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/cli/main.py b/cli/main.py index 52e8a332..33d110fb 100644 --- a/cli/main.py +++ b/cli/main.py @@ -80,7 +80,7 @@ class MessageBuffer: self.current_agent = None self.report_sections = {} self.selected_analysts = [] - self._last_message_id = None + self._processed_message_ids = set() def init_for_analysis(self, selected_analysts): """Initialize agent status and report sections based on selected analysts. @@ -115,7 +115,7 @@ class MessageBuffer: self.current_agent = None self.messages.clear() self.tool_calls.clear() - self._last_message_id = None + self._processed_message_ids.clear() def get_completed_reports_count(self): """Count reports that are finalized (their finalizing agent is completed). @@ -1053,28 +1053,24 @@ def run_analysis(): # Stream the analysis trace = [] for chunk in graph.graph.stream(init_agent_state, **args): - # Process messages if present (skip duplicates via message ID) - if len(chunk["messages"]) > 0: - last_message = chunk["messages"][-1] - msg_id = getattr(last_message, "id", None) + # Process all messages in chunk, deduplicating by message ID + for message in chunk.get("messages", []): + msg_id = getattr(message, "id", None) + if msg_id is not None: + if msg_id in message_buffer._processed_message_ids: + continue + message_buffer._processed_message_ids.add(msg_id) - if msg_id != message_buffer._last_message_id: - message_buffer._last_message_id = msg_id + msg_type, content = classify_message_type(message) + if content and content.strip(): + message_buffer.add_message(msg_type, content) - # Add message to buffer - msg_type, content = classify_message_type(last_message) - if content and content.strip(): - message_buffer.add_message(msg_type, content) - - # Handle tool calls - if hasattr(last_message, "tool_calls") and last_message.tool_calls: - for tool_call in last_message.tool_calls: - if isinstance(tool_call, dict): - message_buffer.add_tool_call( - tool_call["name"], tool_call["args"] - ) - else: - message_buffer.add_tool_call(tool_call.name, tool_call.args) + if hasattr(message, "tool_calls") and message.tool_calls: + for tool_call in message.tool_calls: + if isinstance(tool_call, dict): + message_buffer.add_tool_call(tool_call["name"], tool_call["args"]) + else: + message_buffer.add_tool_call(tool_call.name, tool_call.args) # Update analyst statuses based on report state (runs on every chunk) update_analyst_statuses(message_buffer, chunk) diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index d278b3c3..2aefa7a3 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -78,7 +78,7 @@ class FinancialSituationMemory: # Build results results = [] - max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores + max_score = float(scores.max()) if len(scores) > 0 and scores.max() > 0 else 1.0 for idx in top_indices: # Normalize score to 0-1 range for consistency