fix(cli): log all tool calls from streamed chunk messages
This commit is contained in:
parent
fa4d01c23a
commit
9ba1858948
|
|
@ -0,0 +1,19 @@
|
||||||
|
def ingest_chunk_messages(message_buffer, chunk, classify_message_type) -> None:
|
||||||
|
"""Ingest all newly seen messages from a graph stream chunk."""
|
||||||
|
for message in chunk.get("messages", []):
|
||||||
|
msg_id = getattr(message, "id", None)
|
||||||
|
if msg_id is not None:
|
||||||
|
if msg_id in message_buffer._processed_message_ids:
|
||||||
|
continue
|
||||||
|
message_buffer._processed_message_ids.add(msg_id)
|
||||||
|
|
||||||
|
msg_type, content = classify_message_type(message)
|
||||||
|
if content and content.strip():
|
||||||
|
message_buffer.add_message(msg_type, content)
|
||||||
|
|
||||||
|
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||||
|
for tool_call in message.tool_calls:
|
||||||
|
if isinstance(tool_call, dict):
|
||||||
|
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
|
||||||
|
else:
|
||||||
|
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||||
21
cli/main.py
21
cli/main.py
|
|
@ -28,6 +28,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
from cli.models import AnalystType
|
from cli.models import AnalystType
|
||||||
from cli.utils import *
|
from cli.utils import *
|
||||||
|
from cli.chunk_ingest import ingest_chunk_messages
|
||||||
from cli.announcements import fetch_announcements, display_announcements
|
from cli.announcements import fetch_announcements, display_announcements
|
||||||
from cli.stats_handler import StatsCallbackHandler
|
from cli.stats_handler import StatsCallbackHandler
|
||||||
|
|
||||||
|
|
@ -926,6 +927,7 @@ def format_tool_args(args, max_length=80) -> str:
|
||||||
return result[:max_length - 3] + "..."
|
return result[:max_length - 3] + "..."
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def run_analysis():
|
def run_analysis():
|
||||||
# First get all user selections
|
# First get all user selections
|
||||||
selections = get_user_selections()
|
selections = get_user_selections()
|
||||||
|
|
@ -1053,24 +1055,7 @@ def run_analysis():
|
||||||
# Stream the analysis
|
# Stream the analysis
|
||||||
trace = []
|
trace = []
|
||||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||||
# Process all messages in chunk, deduplicating by message ID
|
ingest_chunk_messages(message_buffer, chunk, classify_message_type)
|
||||||
for message in chunk.get("messages", []):
|
|
||||||
msg_id = getattr(message, "id", None)
|
|
||||||
if msg_id is not None:
|
|
||||||
if msg_id in message_buffer._processed_message_ids:
|
|
||||||
continue
|
|
||||||
message_buffer._processed_message_ids.add(msg_id)
|
|
||||||
|
|
||||||
msg_type, content = classify_message_type(message)
|
|
||||||
if content and content.strip():
|
|
||||||
message_buffer.add_message(msg_type, content)
|
|
||||||
|
|
||||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
|
||||||
for tool_call in message.tool_calls:
|
|
||||||
if isinstance(tool_call, dict):
|
|
||||||
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
|
|
||||||
else:
|
|
||||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
|
||||||
|
|
||||||
# Update analyst statuses based on report state (runs on every chunk)
|
# Update analyst statuses based on report state (runs on every chunk)
|
||||||
update_analyst_statuses(message_buffer, chunk)
|
update_analyst_statuses(message_buffer, chunk)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
from cli.chunk_ingest import ingest_chunk_messages
|
||||||
|
|
||||||
|
|
||||||
|
class FakeToolCall:
|
||||||
|
def __init__(self, name, args):
|
||||||
|
self.name = name
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
|
||||||
|
class FakeMessage:
|
||||||
|
def __init__(self, msg_id, content, tool_calls=None):
|
||||||
|
self.id = msg_id
|
||||||
|
self.content = content
|
||||||
|
self.tool_calls = tool_calls or []
|
||||||
|
|
||||||
|
|
||||||
|
class FakeMessageBuffer:
|
||||||
|
def __init__(self):
|
||||||
|
self._processed_message_ids = set()
|
||||||
|
self.messages = []
|
||||||
|
self.tool_calls = []
|
||||||
|
|
||||||
|
def add_message(self, message_type, content):
|
||||||
|
self.messages.append((message_type, content))
|
||||||
|
|
||||||
|
def add_tool_call(self, tool_name, args):
|
||||||
|
self.tool_calls.append((tool_name, args))
|
||||||
|
|
||||||
|
|
||||||
|
def fake_classifier(message):
|
||||||
|
return "Agent", message.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_ingest_chunk_messages_records_all_tool_calls():
|
||||||
|
message_buffer = FakeMessageBuffer()
|
||||||
|
chunk = {
|
||||||
|
"messages": [
|
||||||
|
FakeMessage(
|
||||||
|
"m1",
|
||||||
|
"first",
|
||||||
|
[
|
||||||
|
{"name": "tool_a", "args": {"x": 1}},
|
||||||
|
FakeToolCall("tool_b", {"y": 2}),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
FakeMessage("m2", "second", [FakeToolCall("tool_c", {"z": 3})]),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
ingest_chunk_messages(message_buffer, chunk, fake_classifier)
|
||||||
|
|
||||||
|
tool_names = [name for name, _ in message_buffer.tool_calls]
|
||||||
|
assert tool_names == ["tool_a", "tool_b", "tool_c"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_ingest_chunk_messages_skips_duplicate_message_ids():
|
||||||
|
message_buffer = FakeMessageBuffer()
|
||||||
|
chunk = {"messages": [FakeMessage("m1", "same", [{"name": "tool_a", "args": {}}])]}
|
||||||
|
|
||||||
|
ingest_chunk_messages(message_buffer, chunk, fake_classifier)
|
||||||
|
ingest_chunk_messages(message_buffer, chunk, fake_classifier)
|
||||||
|
|
||||||
|
assert len(message_buffer.messages) == 1
|
||||||
|
assert len(message_buffer.tool_calls) == 1
|
||||||
Loading…
Reference in New Issue