feat(027-checkpoint-resume-contrib): verify crash mid-analysis resumes from last node
This commit is contained in:
parent
1aa2acfdb1
commit
3c29d9baab
|
|
@ -0,0 +1,112 @@
|
|||
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
|
||||
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
from langgraph.graph import END, StateGraph
|
||||
|
||||
from tradingagents.graph.checkpointer import (
|
||||
checkpoint_step,
|
||||
clear_checkpoint,
|
||||
get_checkpointer,
|
||||
has_checkpoint,
|
||||
thread_id,
|
||||
)
|
||||
|
||||
# Mutable flag to simulate crash on first run
|
||||
_should_crash = False
|
||||
|
||||
|
||||
class _SimpleState(TypedDict):
|
||||
count: int
|
||||
|
||||
|
||||
def _node_a(state: _SimpleState) -> dict:
|
||||
return {"count": state["count"] + 1}
|
||||
|
||||
|
||||
def _node_b(state: _SimpleState) -> dict:
|
||||
if _should_crash:
|
||||
raise RuntimeError("simulated mid-analysis crash")
|
||||
return {"count": state["count"] + 10}
|
||||
|
||||
|
||||
def _build_graph() -> StateGraph:
|
||||
builder = StateGraph(_SimpleState)
|
||||
builder.add_node("analyst", _node_a)
|
||||
builder.add_node("trader", _node_b)
|
||||
builder.set_entry_point("analyst")
|
||||
builder.add_edge("analyst", "trader")
|
||||
builder.add_edge("trader", END)
|
||||
return builder
|
||||
|
||||
|
||||
class TestCheckpointResume(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
self.ticker = "TEST"
|
||||
self.date = "2026-04-20"
|
||||
|
||||
def test_crash_and_resume(self):
|
||||
"""Crash at 'trader' node, then resume from checkpoint."""
|
||||
global _should_crash
|
||||
builder = _build_graph()
|
||||
tid = thread_id(self.ticker, self.date)
|
||||
cfg = {"configurable": {"thread_id": tid}}
|
||||
|
||||
# Run 1: crash at trader node
|
||||
_should_crash = True
|
||||
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||
graph = builder.compile(checkpointer=saver)
|
||||
with self.assertRaises(RuntimeError):
|
||||
graph.invoke({"count": 0}, config=cfg)
|
||||
|
||||
# Checkpoint should exist at step 1 (analyst completed)
|
||||
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||
step = checkpoint_step(self.tmpdir, self.ticker, self.date)
|
||||
self.assertEqual(step, 1)
|
||||
|
||||
# Run 2: resume — trader succeeds this time
|
||||
_should_crash = False
|
||||
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||
graph = builder.compile(checkpointer=saver)
|
||||
result = graph.invoke(None, config=cfg)
|
||||
|
||||
# analyst added 1, trader added 10 → 11
|
||||
self.assertEqual(result["count"], 11)
|
||||
|
||||
def test_clear_checkpoint_allows_fresh_start(self):
|
||||
"""After clearing, the graph starts from scratch."""
|
||||
global _should_crash
|
||||
builder = _build_graph()
|
||||
tid = thread_id(self.ticker, self.date)
|
||||
cfg = {"configurable": {"thread_id": tid}}
|
||||
|
||||
# Create a checkpoint by crashing
|
||||
_should_crash = True
|
||||
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||
graph = builder.compile(checkpointer=saver)
|
||||
with self.assertRaises(RuntimeError):
|
||||
graph.invoke({"count": 0}, config=cfg)
|
||||
|
||||
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||
|
||||
# Clear it
|
||||
clear_checkpoint(self.tmpdir, self.ticker, self.date)
|
||||
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||
|
||||
# Fresh run succeeds from scratch
|
||||
_should_crash = False
|
||||
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
||||
graph = builder.compile(checkpointer=saver)
|
||||
result = graph.invoke({"count": 0}, config=cfg)
|
||||
|
||||
self.assertEqual(result["count"], 11)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue