diff --git a/tests/test_checkpoint_resume.py b/tests/test_checkpoint_resume.py new file mode 100644 index 00000000..8ca46704 --- /dev/null +++ b/tests/test_checkpoint_resume.py @@ -0,0 +1,112 @@ +"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node.""" + +import sqlite3 +import tempfile +import unittest +from pathlib import Path +from typing import TypedDict + +from langgraph.checkpoint.sqlite import SqliteSaver +from langgraph.graph import END, StateGraph + +from tradingagents.graph.checkpointer import ( + checkpoint_step, + clear_checkpoint, + get_checkpointer, + has_checkpoint, + thread_id, +) + +# Mutable flag to simulate crash on first run +_should_crash = False + + +class _SimpleState(TypedDict): + count: int + + +def _node_a(state: _SimpleState) -> dict: + return {"count": state["count"] + 1} + + +def _node_b(state: _SimpleState) -> dict: + if _should_crash: + raise RuntimeError("simulated mid-analysis crash") + return {"count": state["count"] + 10} + + +def _build_graph() -> StateGraph: + builder = StateGraph(_SimpleState) + builder.add_node("analyst", _node_a) + builder.add_node("trader", _node_b) + builder.set_entry_point("analyst") + builder.add_edge("analyst", "trader") + builder.add_edge("trader", END) + return builder + + +class TestCheckpointResume(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.ticker = "TEST" + self.date = "2026-04-20" + + def test_crash_and_resume(self): + """Crash at 'trader' node, then resume from checkpoint.""" + global _should_crash + builder = _build_graph() + tid = thread_id(self.ticker, self.date) + cfg = {"configurable": {"thread_id": tid}} + + # Run 1: crash at trader node + _should_crash = True + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + with self.assertRaises(RuntimeError): + graph.invoke({"count": 0}, config=cfg) + + # Checkpoint should exist at step 1 (analyst completed) + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + step = checkpoint_step(self.tmpdir, self.ticker, self.date) + self.assertEqual(step, 1) + + # Run 2: resume — trader succeeds this time + _should_crash = False + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + result = graph.invoke(None, config=cfg) + + # analyst added 1, trader added 10 → 11 + self.assertEqual(result["count"], 11) + + def test_clear_checkpoint_allows_fresh_start(self): + """After clearing, the graph starts from scratch.""" + global _should_crash + builder = _build_graph() + tid = thread_id(self.ticker, self.date) + cfg = {"configurable": {"thread_id": tid}} + + # Create a checkpoint by crashing + _should_crash = True + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + with self.assertRaises(RuntimeError): + graph.invoke({"count": 0}, config=cfg) + + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + # Clear it + clear_checkpoint(self.tmpdir, self.ticker, self.date) + self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + # Fresh run succeeds from scratch + _should_crash = False + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + result = graph.invoke({"count": 0}, config=cfg) + + self.assertEqual(result["count"], 11) + + +if __name__ == "__main__": + unittest.main()