feat(027-checkpoint-resume-contrib): verify crash mid-analysis resumes from last node

This commit is contained in:
Clayton Brown 2026-04-20 20:32:23 +10:00
parent 1aa2acfdb1
commit 3c29d9baab
1 changed files with 112 additions and 0 deletions

View File

@ -0,0 +1,112 @@
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
import sqlite3
import tempfile
import unittest
from pathlib import Path
from typing import TypedDict
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph
from tradingagents.graph.checkpointer import (
checkpoint_step,
clear_checkpoint,
get_checkpointer,
has_checkpoint,
thread_id,
)
# Mutable flag to simulate crash on first run
_should_crash = False
class _SimpleState(TypedDict):
count: int
def _node_a(state: _SimpleState) -> dict:
return {"count": state["count"] + 1}
def _node_b(state: _SimpleState) -> dict:
if _should_crash:
raise RuntimeError("simulated mid-analysis crash")
return {"count": state["count"] + 10}
def _build_graph() -> StateGraph:
builder = StateGraph(_SimpleState)
builder.add_node("analyst", _node_a)
builder.add_node("trader", _node_b)
builder.set_entry_point("analyst")
builder.add_edge("analyst", "trader")
builder.add_edge("trader", END)
return builder
class TestCheckpointResume(unittest.TestCase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.ticker = "TEST"
self.date = "2026-04-20"
def test_crash_and_resume(self):
"""Crash at 'trader' node, then resume from checkpoint."""
global _should_crash
builder = _build_graph()
tid = thread_id(self.ticker, self.date)
cfg = {"configurable": {"thread_id": tid}}
# Run 1: crash at trader node
_should_crash = True
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
with self.assertRaises(RuntimeError):
graph.invoke({"count": 0}, config=cfg)
# Checkpoint should exist at step 1 (analyst completed)
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
step = checkpoint_step(self.tmpdir, self.ticker, self.date)
self.assertEqual(step, 1)
# Run 2: resume — trader succeeds this time
_should_crash = False
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
result = graph.invoke(None, config=cfg)
# analyst added 1, trader added 10 → 11
self.assertEqual(result["count"], 11)
def test_clear_checkpoint_allows_fresh_start(self):
"""After clearing, the graph starts from scratch."""
global _should_crash
builder = _build_graph()
tid = thread_id(self.ticker, self.date)
cfg = {"configurable": {"thread_id": tid}}
# Create a checkpoint by crashing
_should_crash = True
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
with self.assertRaises(RuntimeError):
graph.invoke({"count": 0}, config=cfg)
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
# Clear it
clear_checkpoint(self.tmpdir, self.ticker, self.date)
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))
# Fresh run succeeds from scratch
_should_crash = False
with get_checkpointer(self.tmpdir, self.ticker) as saver:
graph = builder.compile(checkpointer=saver)
result = graph.invoke({"count": 0}, config=cfg)
self.assertEqual(result["count"], 11)
if __name__ == "__main__":
unittest.main()