feat: LangGraph checkpoint resume for crash recovery
Add SqliteSaver-based checkpointing so crashed analyses resume from the last successful graph node instead of restarting from scratch.

- checkpointer.py: get_checkpointer(), thread_id(), has/clear_checkpoint()
- --checkpoint flag (default: off for backward compatibility)
- --clear-checkpoints flag to force fresh start
- Per-ticker SQLite DBs for parallel worker safety
- Logs 'Resuming from step N' vs 'Starting fresh'
- Clears checkpoint on successful completion (no stale state)
- Tests: crash resume + different date starts fresh
This commit is contained in:
parent
fa4d01c23a
commit
a4a973f083
14
cli/main.py
14
cli/main.py
|
|
@ -926,7 +926,7 @@ def format_tool_args(args, max_length=80) -> str:
|
|||
return result[:max_length - 3] + "..."
|
||||
return result
|
||||
|
||||
def run_analysis():
|
||||
def run_analysis(checkpoint: bool = False):
|
||||
# First get all user selections
|
||||
selections = get_user_selections()
|
||||
|
||||
|
|
@ -943,6 +943,7 @@ def run_analysis():
|
|||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||
config["anthropic_effort"] = selections.get("anthropic_effort")
|
||||
config["output_language"] = selections.get("output_language", "English")
|
||||
config["checkpoint_enabled"] = checkpoint
|
||||
|
||||
# Create stats callback handler for tracking LLM/tool calls
|
||||
stats_handler = StatsCallbackHandler()
|
||||
|
|
@ -1197,8 +1198,15 @@ def run_analysis():
|
|||
|
||||
|
||||
@app.command()
def analyze(
    checkpoint: bool = typer.Option(False, "--checkpoint", help="Enable checkpoint/resume: save state after each node so crashed runs can resume."),
    clear_checkpoints: bool = typer.Option(False, "--clear-checkpoints", help="Delete all saved checkpoints before running (force fresh start)."),
):
    """CLI entry point: run a full analysis, optionally with checkpoint/resume."""
    if clear_checkpoints:
        # Lazy import so the checkpointer module only loads when actually needed.
        from tradingagents.graph.checkpointer import clear_all_checkpoints

        removed = clear_all_checkpoints(DEFAULT_CONFIG["data_cache_dir"])
        console.print(f"[yellow]Cleared {removed} checkpoint(s).[/yellow]")
    run_analysis(checkpoint=checkpoint)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -0,0 +1,147 @@
|
|||
"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node."""
|
||||
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
from langgraph.graph import END, StateGraph
|
||||
|
||||
from tradingagents.graph.checkpointer import (
|
||||
checkpoint_step,
|
||||
clear_checkpoint,
|
||||
get_checkpointer,
|
||||
has_checkpoint,
|
||||
thread_id,
|
||||
)
|
||||
|
||||
# Mutable flag to simulate crash on first run
|
||||
_should_crash = False
|
||||
|
||||
|
||||
class _SimpleState(TypedDict):
    # Minimal graph state: a single running counter each node increments.
    count: int
|
||||
|
||||
|
||||
def _node_a(state: _SimpleState) -> dict:
    """First node ('analyst'): bump the counter by one."""
    bumped = state["count"] + 1
    return {"count": bumped}
|
||||
|
||||
|
||||
def _node_b(state: _SimpleState) -> dict:
    """Second node ('trader'): add ten, or raise when crash simulation is on."""
    if not _should_crash:
        return {"count": state["count"] + 10}
    # Simulated failure path used by the resume tests.
    raise RuntimeError("simulated mid-analysis crash")
|
||||
|
||||
|
||||
def _build_graph() -> StateGraph:
    """Assemble the uncompiled two-node pipeline: analyst -> trader -> END."""
    builder = StateGraph(_SimpleState)
    for node_name, node_fn in (("analyst", _node_a), ("trader", _node_b)):
        builder.add_node(node_name, node_fn)
    builder.set_entry_point("analyst")
    builder.add_edge("analyst", "trader")
    builder.add_edge("trader", END)
    return builder
|
||||
|
||||
|
||||
class TestCheckpointResume(unittest.TestCase):
    """Checkpoint lifecycle: crash -> resume, clear -> fresh start, per-date isolation."""

    def setUp(self):
        # Use TemporaryDirectory + addCleanup so checkpoint DBs are removed
        # after each test; bare tempfile.mkdtemp() leaked a directory per test.
        self._tmp = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmp.cleanup)
        self.tmpdir = self._tmp.name
        self.ticker = "TEST"
        self.date = "2026-04-20"
        # Reset the crash flag even if a test fails while it is True, so one
        # test's simulated crash cannot leak into later tests.
        self.addCleanup(self._reset_crash_flag)

    @staticmethod
    def _reset_crash_flag():
        # Cleanup hook: restore the module-level crash switch to its default.
        global _should_crash
        _should_crash = False

    def test_crash_and_resume(self):
        """Crash at 'trader' node, then resume from checkpoint."""
        global _should_crash
        builder = _build_graph()
        tid = thread_id(self.ticker, self.date)
        cfg = {"configurable": {"thread_id": tid}}

        # Run 1: crash at trader node
        _should_crash = True
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            with self.assertRaises(RuntimeError):
                graph.invoke({"count": 0}, config=cfg)

        # Checkpoint should exist at step 1 (analyst completed)
        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
        step = checkpoint_step(self.tmpdir, self.ticker, self.date)
        self.assertEqual(step, 1)

        # Run 2: resume -- trader succeeds this time. Passing None as input
        # tells LangGraph to continue from the saved state instead of resetting.
        _should_crash = False
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            result = graph.invoke(None, config=cfg)

        # analyst added 1, trader added 10 -> 11
        self.assertEqual(result["count"], 11)

    def test_clear_checkpoint_allows_fresh_start(self):
        """After clearing, the graph starts from scratch."""
        global _should_crash
        builder = _build_graph()
        tid = thread_id(self.ticker, self.date)
        cfg = {"configurable": {"thread_id": tid}}

        # Create a checkpoint by crashing
        _should_crash = True
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            with self.assertRaises(RuntimeError):
                graph.invoke({"count": 0}, config=cfg)

        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))

        # Clear it
        clear_checkpoint(self.tmpdir, self.ticker, self.date)
        self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date))

        # Fresh run succeeds from scratch
        _should_crash = False
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            result = graph.invoke({"count": 0}, config=cfg)

        self.assertEqual(result["count"], 11)

    def test_different_date_starts_fresh(self):
        """A different date must NOT resume from an existing checkpoint."""
        global _should_crash
        builder = _build_graph()
        date2 = "2026-04-21"

        # Run with date1 -- crash to leave a checkpoint
        _should_crash = True
        tid1 = thread_id(self.ticker, self.date)
        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            with self.assertRaises(RuntimeError):
                graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}})

        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))

        # date2 should have no checkpoint
        self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2))

        # Run with date2 -- should start fresh and succeed
        _should_crash = False
        tid2 = thread_id(self.ticker, date2)
        self.assertNotEqual(tid1, tid2)

        with get_checkpointer(self.tmpdir, self.ticker) as saver:
            graph = builder.compile(checkpointer=saver)
            result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}})

        # Fresh run: analyst +1, trader +10 = 11
        self.assertEqual(result["count"], 11)

        # Original date checkpoint still exists (untouched)
        self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date))
|
||||
|
|
@ -15,6 +15,9 @@ DEFAULT_CONFIG = {
|
|||
"google_thinking_level": None, # "high", "minimal", etc.
|
||||
"openai_reasoning_effort": None, # "medium", "high", "low"
|
||||
"anthropic_effort": None, # "high", "medium", "low"
|
||||
# Checkpoint/resume: when True, LangGraph saves state after each node
|
||||
# so a crashed run can resume from the last successful step.
|
||||
"checkpoint_enabled": False,
|
||||
# Output language for analyst reports and final decision
|
||||
# Internal agent debate stays in English for reasoning quality
|
||||
"output_language": "English",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,87 @@
|
|||
"""LangGraph checkpoint support for resumable analysis runs.
|
||||
|
||||
Per-ticker SQLite databases so concurrent tickers don't contend.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
|
||||
|
||||
def _db_path(data_dir: str | Path, ticker: str) -> Path:
    """Return the per-ticker SQLite checkpoint DB path, creating its parent dir."""
    checkpoint_dir = Path(data_dir) / "checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    filename = f"{ticker.upper()}.db"
    return checkpoint_dir / filename
|
||||
|
||||
|
||||
def thread_id(ticker: str, date: str) -> str:
    """Deterministic 16-hex-char thread ID for a ticker+date pair.

    Ticker is upper-cased first so 'aapl' and 'AAPL' map to the same thread.
    """
    key = f"{ticker.upper()}:{date}"
    digest = hashlib.sha256(key.encode()).hexdigest()
    return digest[:16]
|
||||
|
||||
|
||||
@contextmanager
def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]:
    """Yield a SqliteSaver backed by the per-ticker DB.

    The SQLite connection is opened with check_same_thread=False (LangGraph may
    touch it from worker threads) and is always closed on exit, even if the
    caller raises while the saver is in use.
    """
    connection = sqlite3.connect(str(_db_path(data_dir, ticker)), check_same_thread=False)
    try:
        checkpointer = SqliteSaver(connection)
        checkpointer.setup()  # create the checkpoint tables if they don't exist yet
        yield checkpointer
    finally:
        connection.close()
|
||||
|
||||
|
||||
def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool:
    """Return True when a resumable checkpoint exists for this ticker+date."""
    latest_step = checkpoint_step(data_dir, ticker, date)
    return latest_step is not None
|
||||
|
||||
|
||||
def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None:
    """Return the step number of the latest checkpoint for ticker+date, or None."""
    if not _db_path(data_dir, ticker).exists():
        return None  # no DB file -> nothing was ever saved for this ticker
    config = {"configurable": {"thread_id": thread_id(ticker, date)}}
    with get_checkpointer(data_dir, ticker) as saver:
        latest = saver.get_tuple(config)
    if latest is None:
        return None
    # "step" comes from LangGraph's checkpoint metadata — presumably always
    # present once a node has completed; verify against the installed version.
    return latest.metadata.get("step")
|
||||
|
||||
|
||||
def clear_all_checkpoints(data_dir: str | Path) -> int:
    """Delete every per-ticker checkpoint DB; return how many files were removed."""
    checkpoint_dir = Path(data_dir) / "checkpoints"
    if not checkpoint_dir.exists():
        return 0
    removed = 0
    for db_file in checkpoint_dir.glob("*.db"):
        db_file.unlink()
        removed += 1
    return removed
|
||||
|
||||
|
||||
def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None:
    """Remove the checkpoint rows for a specific ticker+date.

    Other dates for the same ticker (other thread_ids in the same per-ticker
    DB) are left untouched. A missing DB or missing tables is a no-op.
    """
    db = _db_path(data_dir, ticker)
    if not db.exists():
        return  # nothing was ever saved for this ticker
    tid = thread_id(ticker, date)
    conn = sqlite3.connect(str(db))
    try:
        # Delete writes and checkpoints for this thread. Handle each table
        # independently: previously one try wrapped both DELETEs, so a missing
        # "writes" table aborted before the "checkpoints" rows (and the commit),
        # leaving the checkpoint intact despite "clearing" it.
        for table in ("writes", "checkpoints"):
            try:
                conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,))
            except sqlite3.OperationalError:
                pass  # this table doesn't exist yet
        conn.commit()
    finally:
        conn.close()
|
||||
|
|
@ -197,5 +197,4 @@ class GraphSetup:
|
|||
|
||||
workflow.add_edge("Portfolio Manager", END)
|
||||
|
||||
# Compile and return
|
||||
return workflow.compile()
|
||||
return workflow
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
# TradingAgents/graph/trading_graph.py
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
|
@ -33,6 +34,7 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_global_news
|
||||
)
|
||||
|
||||
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .setup import GraphSetup
|
||||
from .propagation import Propagator
|
||||
|
|
@ -43,6 +45,8 @@ from .signal_processing import SignalProcessor
|
|||
class TradingAgentsGraph:
|
||||
"""Main class that orchestrates the trading agents framework."""
|
||||
|
||||
logger = logging.getLogger("tradingagents.graph")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
|
|
@ -129,7 +133,9 @@ class TradingAgentsGraph:
|
|||
self.log_states_dict = {} # date to full state dict
|
||||
|
||||
# Set up the graph
|
||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||
self.workflow = self.graph_setup.setup_graph(selected_analysts)
|
||||
self.graph = self.workflow.compile()
|
||||
self._checkpointer_ctx = None
|
||||
|
||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||
"""Get provider-specific kwargs for LLM client creation."""
|
||||
|
|
@ -194,12 +200,44 @@ class TradingAgentsGraph:
|
|||
|
||||
self.ticker = company_name
|
||||
|
||||
# Recompile with checkpointer if enabled
|
||||
if self.config.get("checkpoint_enabled"):
|
||||
self._checkpointer_ctx = get_checkpointer(
|
||||
self.config["data_cache_dir"], company_name
|
||||
)
|
||||
saver = self._checkpointer_ctx.__enter__()
|
||||
self.graph = self.workflow.compile(checkpointer=saver)
|
||||
|
||||
# Log resume vs fresh start
|
||||
step = checkpoint_step(
|
||||
self.config["data_cache_dir"], company_name, str(trade_date)
|
||||
)
|
||||
if step is not None:
|
||||
self.logger.info("Resuming from step %d for %s on %s", step, company_name, trade_date)
|
||||
else:
|
||||
self.logger.info("Starting fresh for %s on %s", company_name, trade_date)
|
||||
|
||||
try:
|
||||
return self._run_graph(company_name, trade_date)
|
||||
finally:
|
||||
if self._checkpointer_ctx is not None:
|
||||
self._checkpointer_ctx.__exit__(None, None, None)
|
||||
self._checkpointer_ctx = None
|
||||
self.graph = self.workflow.compile()
|
||||
|
||||
def _run_graph(self, company_name, trade_date):
|
||||
"""Execute the graph and return results."""
|
||||
# Initialize state
|
||||
init_agent_state = self.propagator.create_initial_state(
|
||||
company_name, trade_date
|
||||
)
|
||||
args = self.propagator.get_graph_args()
|
||||
|
||||
# Inject thread_id so same ticker+date resumes, different date starts fresh
|
||||
if self.config.get("checkpoint_enabled"):
|
||||
tid = thread_id(company_name, str(trade_date))
|
||||
args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid
|
||||
|
||||
if self.debug:
|
||||
# Debug mode with tracing
|
||||
trace = []
|
||||
|
|
@ -221,6 +259,10 @@ class TradingAgentsGraph:
|
|||
# Log state
|
||||
self._log_state(trade_date, final_state)
|
||||
|
||||
# Clear checkpoint on successful completion to avoid stale state
|
||||
if self.config.get("checkpoint_enabled"):
|
||||
clear_checkpoint(self.config["data_cache_dir"], company_name, str(trade_date))
|
||||
|
||||
# Return decision and processed signal
|
||||
return final_state, self.process_signal(final_state["final_trade_decision"])
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue