148 lines
4.9 KiB
Python
148 lines
4.9 KiB
Python
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
|
|
|
|
import sqlite3
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
from typing import TypedDict
|
|
|
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
|
from langgraph.graph import END, StateGraph
|
|
|
|
from tradingagents.graph.checkpointer import (
|
|
checkpoint_step,
|
|
clear_checkpoint,
|
|
get_checkpointer,
|
|
has_checkpoint,
|
|
thread_id,
|
|
)
|
|
|
|
# Mutable flag to simulate crash on first run
|
|
_should_crash = False
|
|
|
|
|
|
class _SimpleState(TypedDict):
|
|
count: int
|
|
|
|
|
|
def _node_a(state: _SimpleState) -> dict:
|
|
return {"count": state["count"] + 1}
|
|
|
|
|
|
def _node_b(state: _SimpleState) -> dict:
|
|
if _should_crash:
|
|
raise RuntimeError("simulated mid-analysis crash")
|
|
return {"count": state["count"] + 10}
|
|
|
|
|
|
def _build_graph() -> StateGraph:
|
|
builder = StateGraph(_SimpleState)
|
|
builder.add_node("analyst", _node_a)
|
|
builder.add_node("trader", _node_b)
|
|
builder.set_entry_point("analyst")
|
|
builder.add_edge("analyst", "trader")
|
|
builder.add_edge("trader", END)
|
|
return builder
|
|
|
|
|
|
class TestCheckpointResume(unittest.TestCase):
|
|
def setUp(self):
|
|
self.tmpdir = tempfile.mkdtemp()
|
|
self.ticker = "TEST"
|
|
self.date = "2026-04-20"
|
|
|
|
def test_crash_and_resume(self):
|
|
"""Crash at 'trader' node, then resume from checkpoint."""
|
|
global _should_crash
|
|
builder = _build_graph()
|
|
tid = thread_id(self.ticker, self.date)
|
|
cfg = {"configurable": {"thread_id": tid}}
|
|
|
|
# Run 1: crash at trader node
|
|
_should_crash = True
|
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
|
graph = builder.compile(checkpointer=saver)
|
|
with self.assertRaises(RuntimeError):
|
|
graph.invoke({"count": 0}, config=cfg)
|
|
|
|
# Checkpoint should exist at step 1 (analyst completed)
|
|
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
|
step = checkpoint_step(self.tmpdir, self.ticker, self.date)
|
|
self.assertEqual(step, 1)
|
|
|
|
# Run 2: resume — trader succeeds this time
|
|
_should_crash = False
|
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
|
graph = builder.compile(checkpointer=saver)
|
|
result = graph.invoke(None, config=cfg)
|
|
|
|
# analyst added 1, trader added 10 → 11
|
|
self.assertEqual(result["count"], 11)
|
|
|
|
def test_clear_checkpoint_allows_fresh_start(self):
|
|
"""After clearing, the graph starts from scratch."""
|
|
global _should_crash
|
|
builder = _build_graph()
|
|
tid = thread_id(self.ticker, self.date)
|
|
cfg = {"configurable": {"thread_id": tid}}
|
|
|
|
# Create a checkpoint by crashing
|
|
_should_crash = True
|
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
|
graph = builder.compile(checkpointer=saver)
|
|
with self.assertRaises(RuntimeError):
|
|
graph.invoke({"count": 0}, config=cfg)
|
|
|
|
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
|
|
|
# Clear it
|
|
clear_checkpoint(self.tmpdir, self.ticker, self.date)
|
|
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
|
|
|
# Fresh run succeeds from scratch
|
|
_should_crash = False
|
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
|
graph = builder.compile(checkpointer=saver)
|
|
result = graph.invoke({"count": 0}, config=cfg)
|
|
|
|
self.assertEqual(result["count"], 11)
|
|
|
|
|
|
def test_different_date_starts_fresh(self):
|
|
"""A different date must NOT resume from an existing checkpoint."""
|
|
global _should_crash
|
|
builder = _build_graph()
|
|
date2 = "2026-04-21"
|
|
|
|
# Run with date1 — crash to leave a checkpoint
|
|
_should_crash = True
|
|
tid1 = thread_id(self.ticker, self.date)
|
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
|
graph = builder.compile(checkpointer=saver)
|
|
with self.assertRaises(RuntimeError):
|
|
graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}})
|
|
|
|
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
|
|
|
# date2 should have no checkpoint
|
|
self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2))
|
|
|
|
# Run with date2 — should start fresh and succeed
|
|
_should_crash = False
|
|
tid2 = thread_id(self.ticker, date2)
|
|
self.assertNotEqual(tid1, tid2)
|
|
|
|
with get_checkpointer(self.tmpdir, self.ticker) as saver:
|
|
graph = builder.compile(checkpointer=saver)
|
|
result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}})
|
|
|
|
# Fresh run: analyst +1, trader +10 = 11
|
|
self.assertEqual(result["count"], 11)
|
|
|
|
# Original date checkpoint still exists (untouched)
|
|
self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|