diff --git a/tests/portfolio/test_report_store.py b/tests/portfolio/test_report_store.py index 9f042a26..3aa9dc57 100644 --- a/tests/portfolio/test_report_store.py +++ b/tests/portfolio/test_report_store.py @@ -174,3 +174,96 @@ def test_read_json_raises_on_corrupt_file(report_store, tmp_reports): corrupt.write_text("not valid json{{{", encoding="utf-8") with pytest.raises(ReportStoreError): report_store._read_json(corrupt) + + +# --------------------------------------------------------------------------- +# _sanitize +# --------------------------------------------------------------------------- + + +class _FakeMessage: + """Minimal stand-in for a LangChain HumanMessage / AIMessage.""" + + def __init__(self, type_: str, content: str) -> None: + self.type = type_ + self.content = content + + +class _FakeMessageWithDict(_FakeMessage): + """Stand-in that also exposes a .dict() method like LangChain BaseMessage.""" + + def dict(self) -> dict: + return {"type": self.type, "content": self.content, "extra": "field"} + + +def test_sanitize_primitives_passthrough(): + """Primitive values must be returned unchanged.""" + assert ReportStore._sanitize(None) is None + assert ReportStore._sanitize(True) is True + assert ReportStore._sanitize(42) == 42 + assert ReportStore._sanitize(3.14) == 3.14 + assert ReportStore._sanitize("hello") == "hello" + + +def test_sanitize_plain_dict_passthrough(): + """A plain JSON-safe dict must survive _sanitize unchanged.""" + data = {"a": 1, "b": [2, 3], "c": {"d": "e"}} + assert ReportStore._sanitize(data) == data + + +def test_sanitize_list_and_tuple(): + """Lists and tuples of primitives must be returned as lists.""" + assert ReportStore._sanitize([1, 2, 3]) == [1, 2, 3] + assert ReportStore._sanitize((1, "x")) == [1, "x"] + + +def test_sanitize_message_without_dict_method(): + """A message-like object without .dict() must be converted to type/content.""" + msg = _FakeMessage("human", "hello world") + result = ReportStore._sanitize(msg) + assert result == {"type": "human", "content": "hello world"} + + +def test_sanitize_message_with_dict_method(): + """A message-like object with .dict() must be sanitized via that dict.""" + msg = _FakeMessageWithDict("ai", "response text") + result = ReportStore._sanitize(msg) + assert result == {"type": "ai", "content": "response text", "extra": "field"} + + +def test_sanitize_nested_messages_in_state(): + """Messages nested inside a LangGraph-style state dict must be sanitized.""" + msg = _FakeMessage("human", "buy signal") + state = { + "messages": [msg], + "investment_debate_state": {"history": [msg]}, + "ticker": "AAPL", + } + result = ReportStore._sanitize(state) + assert result["ticker"] == "AAPL" + assert result["messages"] == [{"type": "human", "content": "buy signal"}] + debate = result["investment_debate_state"]["history"] + assert debate == [{"type": "human", "content": "buy signal"}] + + +def test_sanitize_arbitrary_non_serializable_falls_back_to_str(): + """An arbitrary non-serializable object must fall back to str().""" + + class _Weird: + def __str__(self) -> str: + return "weird_value" + + result = ReportStore._sanitize(_Weird()) + assert result == "weird_value" + + +def test_write_json_with_message_objects_does_not_raise(report_store, tmp_reports): + """_write_json must not raise when data contains message-like objects.""" + msg = _FakeMessage("human", "test") + data = {"messages": [msg], "ticker": "TSLA"} + path = tmp_reports / "test_output.json" + written = report_store._write_json(path, data) + assert written.exists() + loaded = json.loads(written.read_text(encoding="utf-8")) + assert loaded["ticker"] == "TSLA" + assert loaded["messages"] == [{"type": "human", "content": "test"}] diff --git a/tradingagents/portfolio/report_store.py b/tradingagents/portfolio/report_store.py index 2d641693..2976ae65 100644 --- a/tradingagents/portfolio/report_store.py +++ b/tradingagents/portfolio/report_store.py @@ -64,6 +64,35 @@ class ReportStore: """ return self._base_dir / "daily" / date / "portfolio" + @staticmethod + def _sanitize(obj: Any) -> Any: + """Recursively convert non-JSON-serializable objects to safe types. + + Handles LangChain message objects (``HumanMessage``, ``AIMessage``, + etc.) that appear in LangGraph state dicts, as well as any other + arbitrary objects that are not natively JSON-serializable. + """ + if obj is None or isinstance(obj, (bool, int, float, str)): + return obj + if isinstance(obj, dict): + return {k: ReportStore._sanitize(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [ReportStore._sanitize(item) for item in obj] + # LangChain BaseMessage objects expose .type and .content + if hasattr(obj, "type") and hasattr(obj, "content"): + try: + if hasattr(obj, "dict") and callable(obj.dict): + return ReportStore._sanitize(obj.dict()) + except Exception: + pass + return {"type": str(obj.type), "content": str(obj.content)} + # Generic fallback: try a serialization probe first + try: + json.dumps(obj) + return obj + except (TypeError, ValueError): + return str(obj) + def _write_json(self, path: Path, data: dict[str, Any]) -> Path: """Write a dict to a JSON file, creating parent directories as needed. @@ -79,7 +108,8 @@ class ReportStore: """ try: path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(data, indent=2), encoding="utf-8") + sanitized = self._sanitize(data) + path.write_text(json.dumps(sanitized, indent=2), encoding="utf-8") return path except OSError as exc: raise ReportStoreError(f"Failed to write {path}: {exc}") from exc