diff --git a/tradingagents/graph/checkpointer.py b/tradingagents/graph/checkpointer.py index 8787ffe5..fc02e06b 100644 --- a/tradingagents/graph/checkpointer.py +++ b/tradingagents/graph/checkpointer.py @@ -41,14 +41,21 @@ def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool: """Check whether a resumable checkpoint exists for ticker+date.""" + return checkpoint_step(data_dir, ticker, date) is not None + + +def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None: + """Return the step number of the latest checkpoint, or None if none exists.""" db = _db_path(data_dir, ticker) if not db.exists(): - return False + return None tid = thread_id(ticker, date) with get_checkpointer(data_dir, ticker) as saver: config = {"configurable": {"thread_id": tid}} cp = saver.get_tuple(config) - return cp is not None + if cp is None: + return None + return cp.metadata.get("step") def clear_all_checkpoints(data_dir: str | Path) -> int: diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 560ec842..febb5f23 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -1,5 +1,6 @@ # TradingAgents/graph/trading_graph.py +import logging import os from pathlib import Path import json @@ -33,7 +34,7 @@ from tradingagents.agents.utils.agent_utils import ( get_global_news ) -from .checkpointer import clear_checkpoint, get_checkpointer, thread_id +from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id from .conditional_logic import ConditionalLogic from .setup import GraphSetup from .propagation import Propagator @@ -44,6 +45,8 @@ from .signal_processing import SignalProcessor class TradingAgentsGraph: """Main class that orchestrates the trading agents framework.""" + logger = logging.getLogger("tradingagents.graph") + def __init__( self, selected_analysts=["market", "social", "news", "fundamentals"], @@ -205,6 +208,15 @@ class TradingAgentsGraph: saver = self._checkpointer_ctx.__enter__() self.graph = self.workflow.compile(checkpointer=saver) + # Log resume vs fresh start + step = checkpoint_step( + self.config["data_cache_dir"], company_name, str(trade_date) + ) + if step is not None: + self.logger.info("Resuming from step %d for %s on %s", step, company_name, trade_date) + else: + self.logger.info("Starting fresh for %s on %s", company_name, trade_date) + try: return self._run_graph(company_name, trade_date) finally: