diff --git a/tests/test_json_utils.py b/tests/test_json_utils.py index ea44d3a0..5729b300 100644 --- a/tests/test_json_utils.py +++ b/tests/test_json_utils.py @@ -119,13 +119,12 @@ def test_no_json_at_all(): extract_json("Just some text with no JSON structure at all") -def test_array_input_returns_list(): - """extract_json succeeds on JSON arrays — json.loads parses them as lists. +def test_array_input_raises_value_error(): + """extract_json rejects JSON arrays — only dicts are accepted. - The function's return-type annotation says dict, but the implementation does - not enforce this at runtime. A JSON array is valid JSON, so step 1 - (direct json.loads) succeeds and returns a list. Callers that need a dict - must validate the returned type themselves. + All callers (macro_synthesis, macro_bridge, CLI) call .get() on the result, + so returning a list would cause AttributeError downstream. The function + enforces dict-only return at runtime. """ - result = extract_json('[1, 2, 3]') - assert result == [1, 2, 3] + with pytest.raises(ValueError, match="Expected a JSON object"): + extract_json('[1, 2, 3]') diff --git a/tradingagents/agents/scanners/macro_synthesis.py b/tradingagents/agents/scanners/macro_synthesis.py index c58ef561..b29d517f 100644 --- a/tradingagents/agents/scanners/macro_synthesis.py +++ b/tradingagents/agents/scanners/macro_synthesis.py @@ -1,12 +1,12 @@ import json import logging +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + from tradingagents.agents.utils.json_utils import extract_json logger = logging.getLogger(__name__) -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder - def create_macro_synthesis(llm): def macro_synthesis_node(state): diff --git a/tradingagents/agents/utils/json_utils.py b/tradingagents/agents/utils/json_utils.py index adc28fac..fee623a6 100644 --- a/tradingagents/agents/utils/json_utils.py +++ b/tradingagents/agents/utils/json_utils.py @@ -30,9 +30,16 @@ def extract_json(text: str) -> dict[str, Any]: if not text or not text.strip(): raise ValueError("Empty input — no JSON to extract") + def _ensure_dict(obj: object) -> dict[str, Any]: + if not isinstance(obj, dict): + raise ValueError( + f"Expected a JSON object (dict), got {type(obj).__name__}" + ) + return obj + # 1. Direct parse try: - return json.loads(text) + return _ensure_dict(json.loads(text)) except json.JSONDecodeError: pass @@ -41,7 +48,7 @@ def extract_json(text: str) -> dict[str, Any]: # Try again after stripping think blocks try: - return json.loads(cleaned) + return _ensure_dict(json.loads(cleaned)) except json.JSONDecodeError: pass @@ -50,8 +57,9 @@ def extract_json(text: str) -> dict[str, Any]: fences = re.findall(fence_pattern, cleaned, re.DOTALL) for block in fences: try: - return json.loads(block.strip()) - except json.JSONDecodeError: + return _ensure_dict(json.loads(block.strip())) + except (json.JSONDecodeError, ValueError): + # JSONDecodeError = bad JSON; ValueError = parsed but not a dict continue # 4. Find first '{' to last '}' @@ -59,8 +67,9 @@ def extract_json(text: str) -> dict[str, Any]: last_brace = cleaned.rfind("}") if first_brace != -1 and last_brace > first_brace: try: - return json.loads(cleaned[first_brace : last_brace + 1]) - except json.JSONDecodeError: + return _ensure_dict(json.loads(cleaned[first_brace : last_brace + 1])) + except (json.JSONDecodeError, ValueError): + # JSONDecodeError = bad JSON; ValueError = parsed but not a dict pass raise ValueError( diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 4e0caec2..357109f9 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -256,4 +256,5 @@ def route_to_vendor(method: str, *args, **kwargs): continue error_msg = f"All vendors failed for '{method}' (tried: {', '.join(tried)})" - raise RuntimeError(error_msg) from last_error \ No newline at end of file + raise RuntimeError(error_msg) from last_error +