diff --git a/cli/main.py b/cli/main.py index f6c57782..fdf543cb 100644 --- a/cli/main.py +++ b/cli/main.py @@ -890,6 +890,34 @@ def classify_message_type(message) -> tuple[str, str | None]: return ("System", content) +def parse_tool_call(tool_call) -> tuple[str, dict | str]: + """Parse a tool call into a name and arguments dictionary. + Handles dicts, objects with name/args attributes, and string representations. + """ + import ast + + if isinstance(tool_call, dict): + tool_name = tool_call.get("name", "Unknown Tool") + args = tool_call.get("args", tool_call.get("arguments", {})) + return tool_name, args + + if isinstance(tool_call, str): + try: + tool_call_dict = ast.literal_eval(tool_call) + if not isinstance(tool_call_dict, dict): + tool_call_dict = {} + except (ValueError, SyntaxError): + tool_call_dict = {} + + tool_name = tool_call_dict.get("name", "Unknown Tool") + args = tool_call_dict.get("args", tool_call_dict.get("arguments", {})) + return tool_name, args + + # Fallback for objects with name and args attributes + tool_name = getattr(tool_call, "name", "Unknown Tool") + args = getattr(tool_call, "args", getattr(tool_call, "arguments", {})) + return tool_name, args + def format_tool_args(args, max_length=80) -> str: """Format tool arguments for terminal display.""" result = str(args) @@ -1051,12 +1079,8 @@ def run_analysis(): # Handle tool calls if hasattr(last_message, "tool_calls") and last_message.tool_calls: for tool_call in last_message.tool_calls: - if isinstance(tool_call, dict): - message_buffer.add_tool_call( - tool_call["name"], tool_call["args"] - ) - else: - message_buffer.add_tool_call(tool_call.name, tool_call.args) + tool_name, tool_args = parse_tool_call(tool_call) + message_buffer.add_tool_call(tool_name, tool_args) # Update analyst statuses based on report state (runs on every chunk) update_analyst_statuses(message_buffer, chunk) diff --git a/tests/unit/test_cli_main_tools.py b/tests/unit/test_cli_main_tools.py new file mode 100644 index 00000000..2514bd98 --- /dev/null +++ b/tests/unit/test_cli_main_tools.py @@ -0,0 +1,53 @@ +"""Tests for cli.main tool call parsing utility functions.""" +import pytest +from cli.main import parse_tool_call + +class MockToolCall: + def __init__(self, name, args): + self.name = name + self.args = args + +def test_parse_tool_call_dict_with_args(): + tool_call = {"name": "get_stock_price", "args": {"ticker": "AAPL"}} + name, args = parse_tool_call(tool_call) + assert name == "get_stock_price" + assert args == {"ticker": "AAPL"} + +def test_parse_tool_call_dict_with_arguments(): + tool_call = {"name": "get_stock_price", "arguments": {"ticker": "AAPL"}} + name, args = parse_tool_call(tool_call) + assert name == "get_stock_price" + assert args == {"ticker": "AAPL"} + +def test_parse_tool_call_string_valid_dict(): + tool_call = '{"name": "get_news", "args": {"ticker": "TSLA"}}' + name, args = parse_tool_call(tool_call) + assert name == "get_news" + assert args == {"ticker": "TSLA"} + +def test_parse_tool_call_string_value_error(): + # 'unknown_variable' is a valid expression but raises ValueError in literal_eval + tool_call = 'unknown_variable' + name, args = parse_tool_call(tool_call) + assert name == "Unknown Tool" + assert args == {} + +def test_parse_tool_call_string_syntax_error(): + # '{"name": "get_news"' is missing a closing brace, raises SyntaxError + tool_call = '{"name": "get_news"' + name, args = parse_tool_call(tool_call) + assert name == "Unknown Tool" + assert args == {} + +def test_parse_tool_call_string_not_dict(): + # A valid string but doesn't evaluate to a dict + tool_call = '"just a string"' + name, args = parse_tool_call(tool_call) + assert name == "Unknown Tool" + assert args == {} + +def test_parse_tool_call_object(): + tool_call = MockToolCall("get_sentiment", {"ticker": "GOOG"}) + name, args = parse_tool_call(tool_call) + assert name == "get_sentiment" + assert args == {"ticker": "GOOG"}