diff --git a/tradingagents/graph/checkpointer.py b/tradingagents/graph/checkpointer.py new file mode 100644 index 00000000..b0c28039 --- /dev/null +++ b/tradingagents/graph/checkpointer.py @@ -0,0 +1,69 @@ +"""LangGraph checkpoint support for resumable analysis runs. + +Per-ticker SQLite databases so concurrent tickers don't contend. +""" + +from __future__ import annotations + +import hashlib +import sqlite3 +from contextlib import contextmanager +from pathlib import Path +from typing import Generator + +from langgraph.checkpoint.sqlite import SqliteSaver + + +def _db_path(data_dir: str | Path, ticker: str) -> Path: + """Return the SQLite checkpoint DB path for a ticker.""" + p = Path(data_dir) / "checkpoints" + p.mkdir(parents=True, exist_ok=True) + return p / f"{ticker.upper()}.db" + + +def thread_id(ticker: str, date: str) -> str: + """Deterministic thread ID for a ticker+date pair.""" + return hashlib.sha256(f"{ticker.upper()}:{date}".encode()).hexdigest()[:16] + + +@contextmanager +def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]: + """Context manager yielding a SqliteSaver backed by a per-ticker DB.""" + db = _db_path(data_dir, ticker) + conn = sqlite3.connect(str(db), check_same_thread=False) + try: + saver = SqliteSaver(conn) + saver.setup() + yield saver + finally: + conn.close() + + +def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool: + """Check whether a resumable checkpoint exists for ticker+date.""" + db = _db_path(data_dir, ticker) + if not db.exists(): + return False + tid = thread_id(ticker, date) + with get_checkpointer(data_dir, ticker) as saver: + config = {"configurable": {"thread_id": tid}} + cp = saver.get_tuple(config) + return cp is not None + + +def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None: + """Remove checkpoint for a specific ticker+date (delete the whole DB if it's the only thread).""" + db = _db_path(data_dir, ticker) + if not db.exists(): + return + tid = thread_id(ticker, date) + conn = sqlite3.connect(str(db)) + try: + # Delete writes and checkpoints for this thread + for table in ("writes", "checkpoints"): + conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,)) + conn.commit() + except sqlite3.OperationalError: + pass # table doesn't exist yet + finally: + conn.close()