diff --git a/cli/main.py b/cli/main.py index 4316ebed..81c04aa7 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1200,7 +1200,12 @@ def run_analysis(checkpoint: bool = False): @app.command() def analyze( checkpoint: bool = typer.Option(False, "--checkpoint", help="Enable checkpoint/resume: save state after each node so crashed runs can resume."), + clear_checkpoints: bool = typer.Option(False, "--clear-checkpoints", help="Delete all saved checkpoints before running (force fresh start)."), ): + if clear_checkpoints: + from tradingagents.graph.checkpointer import clear_all_checkpoints + n = clear_all_checkpoints(DEFAULT_CONFIG["data_cache_dir"]) + console.print(f"[yellow]Cleared {n} checkpoint(s).[/yellow]") run_analysis(checkpoint=checkpoint) diff --git a/tradingagents/graph/checkpointer.py b/tradingagents/graph/checkpointer.py index b0c28039..8787ffe5 100644 --- a/tradingagents/graph/checkpointer.py +++ b/tradingagents/graph/checkpointer.py @@ -51,6 +51,17 @@ def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool: return cp is not None +def clear_all_checkpoints(data_dir: str | Path) -> int: + """Remove all checkpoint DBs. Returns number of files deleted.""" + cp_dir = Path(data_dir) / "checkpoints" + if not cp_dir.exists(): + return 0 + dbs = list(cp_dir.glob("*.db")) + for db in dbs: + db.unlink() + return len(dbs) + + def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None: """Remove checkpoint for a specific ticker+date (delete the whole DB if it's the only thread).""" db = _db_path(data_dir, ticker)