From a4a973f083cfcc3046cf6e3351599d4bf198b093 Mon Sep 17 00:00:00 2001 From: Clayton Brown Date: Mon, 20 Apr 2026 22:53:04 +1000 Subject: [PATCH] feat: LangGraph checkpoint resume for crash recovery Add SqliteSaver-based checkpointing so crashed analyses resume from the last successful graph node instead of restarting from scratch. - checkpointer.py: get_checkpointer(), thread_id(), has/clear_checkpoint() - --checkpoint flag (default: off for backward compatibility) - --clear-checkpoints flag to force fresh start - Per-ticker SQLite DBs for parallel worker safety - Logs 'Resuming from step N' vs 'Starting fresh' - Clears checkpoint on successful completion (no stale state) - Tests: crash resume + different date starts fresh --- cli/main.py | 14 ++- tests/test_checkpoint_resume.py | 147 +++++++++++++++++++++++++++ tradingagents/default_config.py | 3 + tradingagents/graph/checkpointer.py | 87 ++++++++++++++++ tradingagents/graph/setup.py | 3 +- tradingagents/graph/trading_graph.py | 44 +++++++- 6 files changed, 292 insertions(+), 6 deletions(-) create mode 100644 tests/test_checkpoint_resume.py create mode 100644 tradingagents/graph/checkpointer.py diff --git a/cli/main.py b/cli/main.py index 33d110fb..81c04aa7 100644 --- a/cli/main.py +++ b/cli/main.py @@ -926,7 +926,7 @@ def format_tool_args(args, max_length=80) -> str: return result[:max_length - 3] + "..." return result -def run_analysis(): +def run_analysis(checkpoint: bool = False): # First get all user selections selections = get_user_selections() @@ -943,6 +943,7 @@ def run_analysis(): config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") config["anthropic_effort"] = selections.get("anthropic_effort") config["output_language"] = selections.get("output_language", "English") + config["checkpoint_enabled"] = checkpoint # Create stats callback handler for tracking LLM/tool calls stats_handler = StatsCallbackHandler() @@ -1197,8 +1198,15 @@ def run_analysis(): @app.command() -def analyze(): - run_analysis() +def analyze( + checkpoint: bool = typer.Option(False, "--checkpoint", help="Enable checkpoint/resume: save state after each node so crashed runs can resume."), + clear_checkpoints: bool = typer.Option(False, "--clear-checkpoints", help="Delete all saved checkpoints before running (force fresh start)."), +): + if clear_checkpoints: + from tradingagents.graph.checkpointer import clear_all_checkpoints + n = clear_all_checkpoints(DEFAULT_CONFIG["data_cache_dir"]) + console.print(f"[yellow]Cleared {n} checkpoint(s).[/yellow]") + run_analysis(checkpoint=checkpoint) if __name__ == "__main__": diff --git a/tests/test_checkpoint_resume.py b/tests/test_checkpoint_resume.py new file mode 100644 index 00000000..6f2692bd --- /dev/null +++ b/tests/test_checkpoint_resume.py @@ -0,0 +1,147 @@ +"""Test checkpoint resume: crash mid-analysis, re-run resumes from last node.""" + +import sqlite3 +import tempfile +import unittest +from pathlib import Path +from typing import TypedDict + +from langgraph.checkpoint.sqlite import SqliteSaver +from langgraph.graph import END, StateGraph + +from tradingagents.graph.checkpointer import ( + checkpoint_step, + clear_checkpoint, + get_checkpointer, + has_checkpoint, + thread_id, +) + +# Mutable flag to simulate crash on first run +_should_crash = False + + +class _SimpleState(TypedDict): + count: int + + +def _node_a(state: _SimpleState) -> dict: + return {"count": state["count"] + 1} + + +def _node_b(state: _SimpleState) -> dict: + if _should_crash: + raise RuntimeError("simulated mid-analysis crash") + return {"count": state["count"] + 10} + + +def _build_graph() -> StateGraph: + builder = StateGraph(_SimpleState) + builder.add_node("analyst", _node_a) + builder.add_node("trader", _node_b) + builder.set_entry_point("analyst") + builder.add_edge("analyst", "trader") + builder.add_edge("trader", END) + return builder + + +class TestCheckpointResume(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.ticker = "TEST" + self.date = "2026-04-20" + + def test_crash_and_resume(self): + """Crash at 'trader' node, then resume from checkpoint.""" + global _should_crash + builder = _build_graph() + tid = thread_id(self.ticker, self.date) + cfg = {"configurable": {"thread_id": tid}} + + # Run 1: crash at trader node + _should_crash = True + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + with self.assertRaises(RuntimeError): + graph.invoke({"count": 0}, config=cfg) + + # Checkpoint should exist at step 1 (analyst completed) + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + step = checkpoint_step(self.tmpdir, self.ticker, self.date) + self.assertEqual(step, 1) + + # Run 2: resume — trader succeeds this time + _should_crash = False + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + result = graph.invoke(None, config=cfg) + + # analyst added 1, trader added 10 → 11 + self.assertEqual(result["count"], 11) + + def test_clear_checkpoint_allows_fresh_start(self): + """After clearing, the graph starts from scratch.""" + global _should_crash + builder = _build_graph() + tid = thread_id(self.ticker, self.date) + cfg = {"configurable": {"thread_id": tid}} + + # Create a checkpoint by crashing + _should_crash = True + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + with self.assertRaises(RuntimeError): + graph.invoke({"count": 0}, config=cfg) + + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + # Clear it + clear_checkpoint(self.tmpdir, self.ticker, self.date) + self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + # Fresh run succeeds from scratch + _should_crash = False + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + result = graph.invoke({"count": 0}, config=cfg) + + self.assertEqual(result["count"], 11) + + + def test_different_date_starts_fresh(self): + """A different date must NOT resume from an existing checkpoint.""" + global _should_crash + builder = _build_graph() + date2 = "2026-04-21" + + # Run with date1 — crash to leave a checkpoint + _should_crash = True + tid1 = thread_id(self.ticker, self.date) + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + with self.assertRaises(RuntimeError): + graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid1}}) + + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + # date2 should have no checkpoint + self.assertFalse(has_checkpoint(self.tmpdir, self.ticker, date2)) + + # Run with date2 — should start fresh and succeed + _should_crash = False + tid2 = thread_id(self.ticker, date2) + self.assertNotEqual(tid1, tid2) + + with get_checkpointer(self.tmpdir, self.ticker) as saver: + graph = builder.compile(checkpointer=saver) + result = graph.invoke({"count": 0}, config={"configurable": {"thread_id": tid2}}) + + # Fresh run: analyst +1, trader +10 = 11 + self.assertEqual(result["count"], 11) + + # Original date checkpoint still exists (untouched) + self.assertTrue(has_checkpoint(self.tmpdir, self.ticker, self.date)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index a9b75e4b..c0c441ea 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -15,6 +15,9 @@ DEFAULT_CONFIG = { "google_thinking_level": None, # "high", "minimal", etc. "openai_reasoning_effort": None, # "medium", "high", "low" "anthropic_effort": None, # "high", "medium", "low" + # Checkpoint/resume: when True, LangGraph saves state after each node + # so a crashed run can resume from the last successful step. + "checkpoint_enabled": False, # Output language for analyst reports and final decision # Internal agent debate stays in English for reasoning quality "output_language": "English", diff --git a/tradingagents/graph/checkpointer.py b/tradingagents/graph/checkpointer.py new file mode 100644 index 00000000..fc02e06b --- /dev/null +++ b/tradingagents/graph/checkpointer.py @@ -0,0 +1,87 @@ +"""LangGraph checkpoint support for resumable analysis runs. + +Per-ticker SQLite databases so concurrent tickers don't contend. +""" + +from __future__ import annotations + +import hashlib +import sqlite3 +from contextlib import contextmanager +from pathlib import Path +from typing import Generator + +from langgraph.checkpoint.sqlite import SqliteSaver + + +def _db_path(data_dir: str | Path, ticker: str) -> Path: + """Return the SQLite checkpoint DB path for a ticker.""" + p = Path(data_dir) / "checkpoints" + p.mkdir(parents=True, exist_ok=True) + return p / f"{ticker.upper()}.db" + + +def thread_id(ticker: str, date: str) -> str: + """Deterministic thread ID for a ticker+date pair.""" + return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16] + + +@contextmanager +def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]: + """Context manager yielding a SqliteSaver backed by a per-ticker DB.""" + db = _db_path(data_dir, ticker) + conn = sqlite3.connect(str(db), check_same_thread=False) + try: + saver = SqliteSaver(conn) + saver.setup() + yield saver + finally: + conn.close() + + +def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool: + """Check whether a resumable checkpoint exists for ticker+date.""" + return checkpoint_step(data_dir, ticker, date) is not None + + +def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None: + """Return the step number of the latest checkpoint, or None if none exists.""" + db = _db_path(data_dir, ticker) + if not db.exists(): + return None + tid = thread_id(ticker, date) + with get_checkpointer(data_dir, ticker) as saver: + config = {"configurable": {"thread_id": tid}} + cp = saver.get_tuple(config) + if cp is None: + return None + return cp.metadata.get("step") + + +def clear_all_checkpoints(data_dir: str | Path) -> int: + """Remove all checkpoint DBs. Returns number of files deleted.""" + cp_dir = Path(data_dir) / "checkpoints" + if not cp_dir.exists(): + return 0 + dbs = list(cp_dir.glob("*.db")) + for db in dbs: + db.unlink() + return len(dbs) + + +def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None: + """Remove checkpoint for a specific ticker+date (delete the whole DB if it's the only thread).""" + db = _db_path(data_dir, ticker) + if not db.exists(): + return + tid = thread_id(ticker, date) + conn = sqlite3.connect(str(db)) + try: + # Delete writes and checkpoints for this thread + for table in ("writes", "checkpoints"): + conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,)) + conn.commit() + except sqlite3.OperationalError: + pass # table doesn't exist yet + finally: + conn.close() diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index ae90489c..c0331c51 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -197,5 +197,4 @@ class GraphSetup: workflow.add_edge("Portfolio Manager", END) - # Compile and return - return workflow.compile() + return workflow diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 78bc13e5..febb5f23 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -1,5 +1,6 @@ # TradingAgents/graph/trading_graph.py +import logging import os from pathlib import Path import json @@ -33,6 +34,7 @@ from tradingagents.agents.utils.agent_utils import ( get_global_news ) +from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id from .conditional_logic import ConditionalLogic from .setup import GraphSetup from .propagation import Propagator @@ -43,6 +45,8 @@ from .signal_processing import SignalProcessor class TradingAgentsGraph: """Main class that orchestrates the trading agents framework.""" + logger = logging.getLogger("tradingagents.graph") + def __init__( self, selected_analysts=["market", "social", "news", "fundamentals"], @@ -129,7 +133,9 @@ class TradingAgentsGraph: self.log_states_dict = {} # date to full state dict # Set up the graph - self.graph = self.graph_setup.setup_graph(selected_analysts) + self.workflow = self.graph_setup.setup_graph(selected_analysts) + self.graph = self.workflow.compile() + self._checkpointer_ctx = None def _get_provider_kwargs(self) -> Dict[str, Any]: """Get provider-specific kwargs for LLM client creation.""" @@ -194,12 +200,44 @@ class TradingAgentsGraph: self.ticker = company_name + # Recompile with checkpointer if enabled + if self.config.get("checkpoint_enabled"): + self._checkpointer_ctx = get_checkpointer( + self.config["data_cache_dir"], company_name + ) + saver = self._checkpointer_ctx.__enter__() + self.graph = self.workflow.compile(checkpointer=saver) + + # Log resume vs fresh start + step = checkpoint_step( + self.config["data_cache_dir"], company_name, str(trade_date) + ) + if step is not None: + self.logger.info("Resuming from step %d for %s on %s", step, company_name, trade_date) + else: + self.logger.info("Starting fresh for %s on %s", company_name, trade_date) + + try: + return self._run_graph(company_name, trade_date) + finally: + if self._checkpointer_ctx is not None: + self._checkpointer_ctx.__exit__(None, None, None) + self._checkpointer_ctx = None + self.graph = self.workflow.compile() + + def _run_graph(self, company_name, trade_date): + """Execute the graph and return results.""" # Initialize state init_agent_state = self.propagator.create_initial_state( company_name, trade_date ) args = self.propagator.get_graph_args() + # Inject thread_id so same ticker+date resumes, different date starts fresh + if self.config.get("checkpoint_enabled"): + tid = thread_id(company_name, str(trade_date)) + args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid + if self.debug: # Debug mode with tracing trace = [] @@ -221,6 +259,10 @@ class TradingAgentsGraph: # Log state self._log_state(trade_date, final_state) + # Clear checkpoint on successful completion to avoid stale state + if self.config.get("checkpoint_enabled"): + clear_checkpoint(self.config["data_cache_dir"], company_name, str(trade_date)) + # Return decision and processed signal return final_state, self.process_signal(final_state["final_trade_decision"])