fix: process all chunk messages + harden memory score normalization
Backport from upstream TauricResearch/TradingAgents fa4d01c2:
- cli/main.py: iterate all messages in each chunk (not just the last);
dedupe across chunks with a processed-IDs set instead of tracking
only the most recent message. Prevents dropped tool-call logs when
the graph streams multiple messages per chunk.
- memory.py: guard max() against empty scores and return a float to
avoid ValueError on first-call paths before documents are added.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
a3b6efde16
commit
806263defd
40
cli/main.py
40
cli/main.py
|
|
@@ -79,7 +79,7 @@ class MessageBuffer:
|
||||||
self.current_agent = None
|
self.current_agent = None
|
||||||
self.report_sections = {}
|
self.report_sections = {}
|
||||||
self.selected_analysts = []
|
self.selected_analysts = []
|
||||||
self._last_message_id = None
|
self._processed_message_ids = set()
|
||||||
|
|
||||||
def init_for_analysis(self, selected_analysts):
|
def init_for_analysis(self, selected_analysts):
|
||||||
"""Initialize agent status and report sections based on selected analysts.
|
"""Initialize agent status and report sections based on selected analysts.
|
||||||
|
|
@@ -114,7 +114,7 @@ class MessageBuffer:
|
||||||
self.current_agent = None
|
self.current_agent = None
|
||||||
self.messages.clear()
|
self.messages.clear()
|
||||||
self.tool_calls.clear()
|
self.tool_calls.clear()
|
||||||
self._last_message_id = None
|
self._processed_message_ids.clear()
|
||||||
|
|
||||||
def get_completed_reports_count(self):
|
def get_completed_reports_count(self):
|
||||||
"""Count reports that are finalized (their finalizing agent is completed).
|
"""Count reports that are finalized (their finalizing agent is completed).
|
||||||
|
|
@@ -1021,28 +1021,24 @@ def run_analysis():
|
||||||
# Stream the analysis
|
# Stream the analysis
|
||||||
trace = []
|
trace = []
|
||||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||||
# Process messages if present (skip duplicates via message ID)
|
# Process all messages in chunk, deduplicating by message ID
|
||||||
if len(chunk["messages"]) > 0:
|
for message in chunk.get("messages", []):
|
||||||
last_message = chunk["messages"][-1]
|
msg_id = getattr(message, "id", None)
|
||||||
msg_id = getattr(last_message, "id", None)
|
if msg_id is not None:
|
||||||
|
if msg_id in message_buffer._processed_message_ids:
|
||||||
|
continue
|
||||||
|
message_buffer._processed_message_ids.add(msg_id)
|
||||||
|
|
||||||
if msg_id != message_buffer._last_message_id:
|
msg_type, content = classify_message_type(message)
|
||||||
message_buffer._last_message_id = msg_id
|
if content and content.strip():
|
||||||
|
message_buffer.add_message(msg_type, content)
|
||||||
|
|
||||||
# Add message to buffer
|
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||||
msg_type, content = classify_message_type(last_message)
|
for tool_call in message.tool_calls:
|
||||||
if content and content.strip():
|
if isinstance(tool_call, dict):
|
||||||
message_buffer.add_message(msg_type, content)
|
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
|
||||||
|
else:
|
||||||
# Handle tool calls
|
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
|
||||||
for tool_call in last_message.tool_calls:
|
|
||||||
if isinstance(tool_call, dict):
|
|
||||||
message_buffer.add_tool_call(
|
|
||||||
tool_call["name"], tool_call["args"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
|
||||||
|
|
||||||
# Update analyst statuses based on report state (runs on every chunk)
|
# Update analyst statuses based on report state (runs on every chunk)
|
||||||
update_analyst_statuses(message_buffer, chunk)
|
update_analyst_statuses(message_buffer, chunk)
|
||||||
|
|
|
||||||
|
|
@@ -78,7 +78,7 @@ class FinancialSituationMemory:
|
||||||
|
|
||||||
# Build results
|
# Build results
|
||||||
results = []
|
results = []
|
||||||
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores
|
max_score = float(max(scores)) if len(scores) > 0 and max(scores) > 0 else 1.0
|
||||||
|
|
||||||
for idx in top_indices:
|
for idx in top_indices:
|
||||||
# Normalize score to 0-1 range for consistency
|
# Normalize score to 0-1 range for consistency
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue