Merge pull request #68 from aguzererler/feat/test-tool-call-parsing-12253012184888559232
🧪 Add robust parsing and tests for LLM string tool calls
This commit is contained in:
commit
df735bf6e9
36
cli/main.py
36
cli/main.py
|
|
@ -890,6 +890,34 @@ def classify_message_type(message) -> tuple[str, str | None]:
|
||||||
return ("System", content)
|
return ("System", content)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_call(tool_call) -> tuple[str, dict | str]:
|
||||||
|
"""Parse a tool call into a name and arguments dictionary.
|
||||||
|
Handles dicts, objects with name/args attributes, and string representations.
|
||||||
|
"""
|
||||||
|
import ast
|
||||||
|
|
||||||
|
if isinstance(tool_call, dict):
|
||||||
|
tool_name = tool_call.get("name", "Unknown Tool")
|
||||||
|
args = tool_call.get("args", tool_call.get("arguments", {}))
|
||||||
|
return tool_name, args
|
||||||
|
|
||||||
|
if isinstance(tool_call, str):
|
||||||
|
try:
|
||||||
|
tool_call_dict = ast.literal_eval(tool_call)
|
||||||
|
if not isinstance(tool_call_dict, dict):
|
||||||
|
tool_call_dict = {}
|
||||||
|
except (ValueError, SyntaxError):
|
||||||
|
tool_call_dict = {}
|
||||||
|
|
||||||
|
tool_name = tool_call_dict.get("name", "Unknown Tool")
|
||||||
|
args = tool_call_dict.get("args", tool_call_dict.get("arguments", {}))
|
||||||
|
return tool_name, args
|
||||||
|
|
||||||
|
# Fallback for objects with name and args attributes
|
||||||
|
tool_name = getattr(tool_call, "name", "Unknown Tool")
|
||||||
|
args = getattr(tool_call, "args", getattr(tool_call, "arguments", {}))
|
||||||
|
return tool_name, args
|
||||||
|
|
||||||
def format_tool_args(args, max_length=80) -> str:
|
def format_tool_args(args, max_length=80) -> str:
|
||||||
"""Format tool arguments for terminal display."""
|
"""Format tool arguments for terminal display."""
|
||||||
result = str(args)
|
result = str(args)
|
||||||
|
|
@ -1051,12 +1079,8 @@ def run_analysis():
|
||||||
# Handle tool calls
|
# Handle tool calls
|
||||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||||
for tool_call in last_message.tool_calls:
|
for tool_call in last_message.tool_calls:
|
||||||
if isinstance(tool_call, dict):
|
tool_name, tool_args = parse_tool_call(tool_call)
|
||||||
message_buffer.add_tool_call(
|
message_buffer.add_tool_call(tool_name, tool_args)
|
||||||
tool_call["name"], tool_call["args"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
|
||||||
|
|
||||||
# Update analyst statuses based on report state (runs on every chunk)
|
# Update analyst statuses based on report state (runs on every chunk)
|
||||||
update_analyst_statuses(message_buffer, chunk)
|
update_analyst_statuses(message_buffer, chunk)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
"""Tests for cli.main tool call parsing utility functions."""
|
||||||
|
import pytest
|
||||||
|
from cli.main import parse_tool_call
|
||||||
|
|
||||||
|
class MockToolCall:
|
||||||
|
def __init__(self, name, args):
|
||||||
|
self.name = name
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
def test_parse_tool_call_dict_with_args():
|
||||||
|
tool_call = {"name": "get_stock_price", "args": {"ticker": "AAPL"}}
|
||||||
|
name, args = parse_tool_call(tool_call)
|
||||||
|
assert name == "get_stock_price"
|
||||||
|
assert args == {"ticker": "AAPL"}
|
||||||
|
|
||||||
|
def test_parse_tool_call_dict_with_arguments():
|
||||||
|
tool_call = {"name": "get_stock_price", "arguments": {"ticker": "AAPL"}}
|
||||||
|
name, args = parse_tool_call(tool_call)
|
||||||
|
assert name == "get_stock_price"
|
||||||
|
assert args == {"ticker": "AAPL"}
|
||||||
|
|
||||||
|
def test_parse_tool_call_string_valid_dict():
|
||||||
|
tool_call = '{"name": "get_news", "args": {"ticker": "TSLA"}}'
|
||||||
|
name, args = parse_tool_call(tool_call)
|
||||||
|
assert name == "get_news"
|
||||||
|
assert args == {"ticker": "TSLA"}
|
||||||
|
|
||||||
|
def test_parse_tool_call_string_value_error():
|
||||||
|
# 'unknown_variable' is a valid expression but raises ValueError in literal_eval
|
||||||
|
tool_call = 'unknown_variable'
|
||||||
|
name, args = parse_tool_call(tool_call)
|
||||||
|
assert name == "Unknown Tool"
|
||||||
|
assert args == {}
|
||||||
|
|
||||||
|
def test_parse_tool_call_string_syntax_error():
|
||||||
|
# '{"name": "get_news"' is missing a closing brace, raises SyntaxError
|
||||||
|
tool_call = '{"name": "get_news"'
|
||||||
|
name, args = parse_tool_call(tool_call)
|
||||||
|
assert name == "Unknown Tool"
|
||||||
|
assert args == {}
|
||||||
|
|
||||||
|
def test_parse_tool_call_string_not_dict():
|
||||||
|
# A valid string but doesn't evaluate to a dict
|
||||||
|
tool_call = '"just a string"'
|
||||||
|
name, args = parse_tool_call(tool_call)
|
||||||
|
assert name == "Unknown Tool"
|
||||||
|
assert args == {}
|
||||||
|
|
||||||
|
def test_parse_tool_call_object():
|
||||||
|
tool_call = MockToolCall("get_sentiment", {"ticker": "GOOG"})
|
||||||
|
name, args = parse_tool_call(tool_call)
|
||||||
|
assert name == "get_sentiment"
|
||||||
|
assert args == {"ticker": "GOOG"}
|
||||||
Loading…
Reference in New Issue