Merge pull request #96 from aguzererler/copilot/fix-report-saving-in-runs
Copilot/fix report saving in runs
This commit is contained in:
commit
e568cec68c
|
|
@ -174,3 +174,96 @@ def test_read_json_raises_on_corrupt_file(report_store, tmp_reports):
|
||||||
corrupt.write_text("not valid json{{{", encoding="utf-8")
|
corrupt.write_text("not valid json{{{", encoding="utf-8")
|
||||||
with pytest.raises(ReportStoreError):
|
with pytest.raises(ReportStoreError):
|
||||||
report_store._read_json(corrupt)
|
report_store._read_json(corrupt)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _sanitize
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeMessage:
|
||||||
|
"""Minimal stand-in for a LangChain HumanMessage / AIMessage."""
|
||||||
|
|
||||||
|
def __init__(self, type_: str, content: str) -> None:
|
||||||
|
self.type = type_
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeMessageWithDict(_FakeMessage):
|
||||||
|
"""Stand-in that also exposes a .dict() method like LangChain BaseMessage."""
|
||||||
|
|
||||||
|
def dict(self) -> dict:
|
||||||
|
return {"type": self.type, "content": self.content, "extra": "field"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_primitives_passthrough():
|
||||||
|
"""Primitive values must be returned unchanged."""
|
||||||
|
assert ReportStore._sanitize(None) is None
|
||||||
|
assert ReportStore._sanitize(True) is True
|
||||||
|
assert ReportStore._sanitize(42) == 42
|
||||||
|
assert ReportStore._sanitize(3.14) == 3.14
|
||||||
|
assert ReportStore._sanitize("hello") == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_plain_dict_passthrough():
|
||||||
|
"""A plain JSON-safe dict must survive _sanitize unchanged."""
|
||||||
|
data = {"a": 1, "b": [2, 3], "c": {"d": "e"}}
|
||||||
|
assert ReportStore._sanitize(data) == data
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_list_and_tuple():
|
||||||
|
"""Lists and tuples of primitives must be returned as lists."""
|
||||||
|
assert ReportStore._sanitize([1, 2, 3]) == [1, 2, 3]
|
||||||
|
assert ReportStore._sanitize((1, "x")) == [1, "x"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_message_without_dict_method():
|
||||||
|
"""A message-like object without .dict() must be converted to type/content."""
|
||||||
|
msg = _FakeMessage("human", "hello world")
|
||||||
|
result = ReportStore._sanitize(msg)
|
||||||
|
assert result == {"type": "human", "content": "hello world"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_message_with_dict_method():
|
||||||
|
"""A message-like object with .dict() must be sanitized via that dict."""
|
||||||
|
msg = _FakeMessageWithDict("ai", "response text")
|
||||||
|
result = ReportStore._sanitize(msg)
|
||||||
|
assert result == {"type": "ai", "content": "response text", "extra": "field"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_nested_messages_in_state():
|
||||||
|
"""Messages nested inside a LangGraph-style state dict must be sanitized."""
|
||||||
|
msg = _FakeMessage("human", "buy signal")
|
||||||
|
state = {
|
||||||
|
"messages": [msg],
|
||||||
|
"investment_debate_state": {"history": [msg]},
|
||||||
|
"ticker": "AAPL",
|
||||||
|
}
|
||||||
|
result = ReportStore._sanitize(state)
|
||||||
|
assert result["ticker"] == "AAPL"
|
||||||
|
assert result["messages"] == [{"type": "human", "content": "buy signal"}]
|
||||||
|
debate = result["investment_debate_state"]["history"]
|
||||||
|
assert debate == [{"type": "human", "content": "buy signal"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_arbitrary_non_serializable_falls_back_to_str():
|
||||||
|
"""An arbitrary non-serializable object must fall back to str()."""
|
||||||
|
|
||||||
|
class _Weird:
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return "weird_value"
|
||||||
|
|
||||||
|
result = ReportStore._sanitize(_Weird())
|
||||||
|
assert result == "weird_value"
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_json_with_message_objects_does_not_raise(report_store, tmp_reports):
|
||||||
|
"""_write_json must not raise when data contains message-like objects."""
|
||||||
|
msg = _FakeMessage("human", "test")
|
||||||
|
data = {"messages": [msg], "ticker": "TSLA"}
|
||||||
|
path = tmp_reports / "test_output.json"
|
||||||
|
written = report_store._write_json(path, data)
|
||||||
|
assert written.exists()
|
||||||
|
loaded = json.loads(written.read_text(encoding="utf-8"))
|
||||||
|
assert loaded["ticker"] == "TSLA"
|
||||||
|
assert loaded["messages"] == [{"type": "human", "content": "test"}]
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,35 @@ class ReportStore:
|
||||||
"""
|
"""
|
||||||
return self._base_dir / "daily" / date / "portfolio"
|
return self._base_dir / "daily" / date / "portfolio"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sanitize(obj: Any) -> Any:
|
||||||
|
"""Recursively convert non-JSON-serializable objects to safe types.
|
||||||
|
|
||||||
|
Handles LangChain message objects (``HumanMessage``, ``AIMessage``,
|
||||||
|
etc.) that appear in LangGraph state dicts, as well as any other
|
||||||
|
arbitrary objects that are not natively JSON-serializable.
|
||||||
|
"""
|
||||||
|
if obj is None or isinstance(obj, (bool, int, float, str)):
|
||||||
|
return obj
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: ReportStore._sanitize(v) for k, v in obj.items()}
|
||||||
|
if isinstance(obj, (list, tuple)):
|
||||||
|
return [ReportStore._sanitize(item) for item in obj]
|
||||||
|
# LangChain BaseMessage objects expose .type and .content
|
||||||
|
if hasattr(obj, "type") and hasattr(obj, "content"):
|
||||||
|
try:
|
||||||
|
if hasattr(obj, "dict") and callable(obj.dict):
|
||||||
|
return ReportStore._sanitize(obj.dict())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return {"type": str(obj.type), "content": str(obj.content)}
|
||||||
|
# Generic fallback: try a serialization probe first
|
||||||
|
try:
|
||||||
|
json.dumps(obj)
|
||||||
|
return obj
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return str(obj)
|
||||||
|
|
||||||
def _write_json(self, path: Path, data: dict[str, Any]) -> Path:
|
def _write_json(self, path: Path, data: dict[str, Any]) -> Path:
|
||||||
"""Write a dict to a JSON file, creating parent directories as needed.
|
"""Write a dict to a JSON file, creating parent directories as needed.
|
||||||
|
|
||||||
|
|
@ -79,7 +108,8 @@ class ReportStore:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
sanitized = self._sanitize(data)
|
||||||
|
path.write_text(json.dumps(sanitized, indent=2), encoding="utf-8")
|
||||||
return path
|
return path
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ReportStoreError(f"Failed to write {path}: {exc}") from exc
|
raise ReportStoreError(f"Failed to write {path}: {exc}") from exc
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue