Compare commits

...

1 Commit

Author SHA1 Message Date
gujishh 0e242d33d2 fix(cli): log all tool calls from streamed chunk messages 2026-04-12 14:42:38 +09:00
3 changed files with 88 additions and 24 deletions

19
cli/chunk_ingest.py Normal file
View File

@ -0,0 +1,19 @@
def ingest_chunk_messages(message_buffer, chunk, classify_message_type) -> None:
    """Record every not-yet-seen message (and its tool calls) from a stream chunk.

    Messages carrying an ``id`` already present in
    ``message_buffer._processed_message_ids`` are skipped entirely; messages
    without an ``id`` are always processed. Non-blank classified content is
    appended via ``add_message``, and each tool call — whether a dict with
    ``"name"``/``"args"`` keys or an object with ``.name``/``.args``
    attributes — is appended via ``add_tool_call``.
    """
    for msg in chunk.get("messages", []):
        identifier = getattr(msg, "id", None)
        if identifier is not None:
            if identifier in message_buffer._processed_message_ids:
                continue  # already ingested from an earlier chunk
            message_buffer._processed_message_ids.add(identifier)

        kind, text = classify_message_type(msg)
        if text and text.strip():
            message_buffer.add_message(kind, text)

        # Tool calls may arrive as plain dicts or as attribute-style objects.
        for call in getattr(msg, "tool_calls", None) or []:
            if isinstance(call, dict):
                name, args = call["name"], call["args"]
            else:
                name, args = call.name, call.args
            message_buffer.add_tool_call(name, args)

View File

@ -27,6 +27,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from cli.models import AnalystType
from cli.utils import *
from cli.chunk_ingest import ingest_chunk_messages
from cli.announcements import fetch_announcements, display_announcements
from cli.stats_handler import StatsCallbackHandler
@ -79,7 +80,7 @@ class MessageBuffer:
self.current_agent = None
self.report_sections = {}
self.selected_analysts = []
self._last_message_id = None
self._processed_message_ids = set()
def init_for_analysis(self, selected_analysts):
"""Initialize agent status and report sections based on selected analysts.
@ -114,7 +115,7 @@ class MessageBuffer:
self.current_agent = None
self.messages.clear()
self.tool_calls.clear()
self._last_message_id = None
self._processed_message_ids.clear()
def get_completed_reports_count(self):
"""Count reports that are finalized (their finalizing agent is completed).
@ -925,6 +926,7 @@ def format_tool_args(args, max_length=80) -> str:
return result[:max_length - 3] + "..."
return result
def run_analysis():
# First get all user selections
selections = get_user_selections()
@ -1052,28 +1054,7 @@ def run_analysis():
# Stream the analysis
trace = []
for chunk in graph.graph.stream(init_agent_state, **args):
# Process messages if present (skip duplicates via message ID)
if len(chunk["messages"]) > 0:
last_message = chunk["messages"][-1]
msg_id = getattr(last_message, "id", None)
if msg_id != message_buffer._last_message_id:
message_buffer._last_message_id = msg_id
# Add message to buffer
msg_type, content = classify_message_type(last_message)
if content and content.strip():
message_buffer.add_message(msg_type, content)
# Handle tool calls
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict):
message_buffer.add_tool_call(
tool_call["name"], tool_call["args"]
)
else:
message_buffer.add_tool_call(tool_call.name, tool_call.args)
ingest_chunk_messages(message_buffer, chunk, classify_message_type)
# Update analyst statuses based on report state (runs on every chunk)
update_analyst_statuses(message_buffer, chunk)

View File

@ -0,0 +1,64 @@
from cli.chunk_ingest import ingest_chunk_messages
class FakeToolCall:
    """Attribute-style tool call exposing ``name`` and ``args``."""

    def __init__(self, name, args):
        self.name, self.args = name, args
class FakeMessage:
    """Minimal stand-in for a streamed graph message."""

    def __init__(self, msg_id, content, tool_calls=None):
        self.id, self.content = msg_id, content
        # Falsy tool_calls (including None) normalize to an empty list.
        self.tool_calls = tool_calls or []
class FakeMessageBuffer:
    """Captures ingested messages and tool calls for test assertions."""

    def __init__(self):
        self._processed_message_ids = set()
        self.messages, self.tool_calls = [], []

    def add_message(self, message_type, content):
        """Record a (type, content) pair in arrival order."""
        self.messages.append((message_type, content))

    def add_tool_call(self, tool_name, args):
        """Record a (tool_name, args) pair in arrival order."""
        self.tool_calls.append((tool_name, args))
def fake_classifier(message):
    """Classify any message as an "Agent" message with its raw content."""
    return ("Agent", message.content)
def test_ingest_chunk_messages_records_all_tool_calls():
    """Dict-style and object-style tool calls are both logged, in order."""
    buffer = FakeMessageBuffer()
    first = FakeMessage(
        "m1",
        "first",
        [{"name": "tool_a", "args": {"x": 1}}, FakeToolCall("tool_b", {"y": 2})],
    )
    second = FakeMessage("m2", "second", [FakeToolCall("tool_c", {"z": 3})])

    ingest_chunk_messages(buffer, {"messages": [first, second]}, fake_classifier)

    assert [name for name, _ in buffer.tool_calls] == ["tool_a", "tool_b", "tool_c"]
def test_ingest_chunk_messages_skips_duplicate_message_ids():
    """A message id seen in an earlier chunk is not ingested again."""
    buffer = FakeMessageBuffer()
    chunk = {"messages": [FakeMessage("m1", "same", [{"name": "tool_a", "args": {}}])]}

    for _ in range(2):
        ingest_chunk_messages(buffer, chunk, fake_classifier)

    assert len(buffer.messages) == 1
    assert len(buffer.tool_calls) == 1