"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""

import shutil
import sqlite3
import tempfile
import unittest
from pathlib import Path
from typing import TypedDict

from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph

from tradingagents.graph.checkpointer import (
    checkpoint_step,
    clear_checkpoint,
    get_checkpointer,
    has_checkpoint,
    thread_id,
)

# Mutable flag to simulate crash on first run
_should_crash = False


def _set_crash(enabled: bool) -> None:
    """Set the module-level crash flag that ``_node_b`` consults."""
    global _should_crash
    _should_crash = enabled


class _SimpleState(TypedDict):
    """Minimal graph state: a single integer counter."""

    count: int


def _node_a(state: _SimpleState) -> dict:
    """Return a state update that adds 1 to ``count``."""
    return {"count": state["count"] + 1}


def _node_b(state: _SimpleState) -> dict:
    """Return a state update that adds 10 to ``count``.

    Raises:
        RuntimeError: when the module-level ``_should_crash`` flag is set,
            to simulate a failure in the middle of a run.
    """
    if _should_crash:
        raise RuntimeError("simulated mid-analysis crash")
    return {"count": state["count"] + 10}


def _build_graph() -> StateGraph:
    """Build the uncompiled two-node graph: ``analyst`` -> ``trader`` -> END."""
    builder = StateGraph(_SimpleState)
    builder.add_node("analyst", _node_a)
    builder.add_node("trader", _node_b)
    builder.set_entry_point("analyst")
    builder.add_edge("analyst", "trader")
    builder.add_edge("trader", END)
    return builder


class TestCheckpointResume(unittest.TestCase):
    """Crash/resume and clear-checkpoint scenarios on the two-node graph."""

    def setUp(self):
        """Create a scratch directory and make sure the crash flag starts off."""
        self.tmpdir = tempfile.mkdtemp()
        # Remove the scratch directory even when the test fails; ignore_errors
        # keeps a cleanup problem from masking the real test outcome.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)

        # The crash flag is module-global: reset it before and after every
        # test so a failure in one test cannot leak a crashing "trader" node
        # into the next one.
        _set_crash(False)
        self.addCleanup(_set_crash, False)

        self.ticker = "TEST"
        self.date = "2026-04-20"

    def _crash_first_run(self, builder: StateGraph, cfg: dict) -> None:
        """Invoke the graph with the crash flag on and expect RuntimeError."""
        _set_crash(True)
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            with self.assertRaises(RuntimeError):
                graph.invoke({"count": 0}, config=cfg)

    def test_crash_and_resume(self):
        """Crash at 'trader' node, then resume from checkpoint."""
        builder = _build_graph()
        tid = thread_id(self.ticker, self.date)
        cfg = {"configurable": {"thread_id": tid}}

        # Run 1: crash at trader node
        self._crash_first_run(builder, cfg)

        # Checkpoint should exist at step 1 (analyst completed)
        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
        step = checkpoint_step(self.tmpdir, self.ticker, self.date)
        self.assertEqual(step, 1)

        # Run 2: resume — trader succeeds this time. Input is None so the
        # run continues from the saved state instead of starting over.
        _set_crash(False)
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            result = graph.invoke(None, config=cfg)

        # analyst added 1, trader added 10 → 11
        self.assertEqual(result["count"], 11)

    def test_clear_checkpoint_allows_fresh_start(self):
        """After clearing, the graph starts from scratch."""
        builder = _build_graph()
        tid = thread_id(self.ticker, self.date)
        cfg = {"configurable": {"thread_id": tid}}

        # Create a checkpoint by crashing
        self._crash_first_run(builder, cfg)
        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))

        # Clear it
        clear_checkpoint(self.tmpdir, self.ticker, self.date)
        self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))

        # Fresh run succeeds from scratch
        _set_crash(False)
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            result = graph.invoke({"count": 0}, config=cfg)

        self.assertEqual(result["count"], 11)


if __name__ == "__main__":
    unittest.main()