feat(027-checkpoint-resume-contrib): log resume vs fresh start when checkpoint enabled
This commit is contained in:
parent
dd395bcaaf
commit
1aa2acfdb1
|
|
@ -41,14 +41,21 @@ def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver
|
||||||
|
|
||||||
def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool:
|
def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool:
|
||||||
"""Check whether a resumable checkpoint exists for ticker+date."""
|
"""Check whether a resumable checkpoint exists for ticker+date."""
|
||||||
|
return checkpoint_step(data_dir, ticker, date) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None:
|
||||||
|
"""Return the step number of the latest checkpoint, or None if none exists."""
|
||||||
db = _db_path(data_dir, ticker)
|
db = _db_path(data_dir, ticker)
|
||||||
if not db.exists():
|
if not db.exists():
|
||||||
return False
|
return None
|
||||||
tid = thread_id(ticker, date)
|
tid = thread_id(ticker, date)
|
||||||
with get_checkpointer(data_dir, ticker) as saver:
|
with get_checkpointer(data_dir, ticker) as saver:
|
||||||
config = {"configurable": {"thread_id": tid}}
|
config = {"configurable": {"thread_id": tid}}
|
||||||
cp = saver.get_tuple(config)
|
cp = saver.get_tuple(config)
|
||||||
return cp is not None
|
if cp is None:
|
||||||
|
return None
|
||||||
|
return cp.metadata.get("step")
|
||||||
|
|
||||||
|
|
||||||
def clear_all_checkpoints(data_dir: str | Path) -> int:
|
def clear_all_checkpoints(data_dir: str | Path) -> int:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
# TradingAgents/graph/trading_graph.py
|
# TradingAgents/graph/trading_graph.py
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
|
|
@ -33,7 +34,7 @@ from tradingagents.agents.utils.agent_utils import (
|
||||||
get_global_news
|
get_global_news
|
||||||
)
|
)
|
||||||
|
|
||||||
from .checkpointer import clear_checkpoint, get_checkpointer, thread_id
|
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
from .setup import GraphSetup
|
from .setup import GraphSetup
|
||||||
from .propagation import Propagator
|
from .propagation import Propagator
|
||||||
|
|
@ -44,6 +45,8 @@ from .signal_processing import SignalProcessor
|
||||||
class TradingAgentsGraph:
|
class TradingAgentsGraph:
|
||||||
"""Main class that orchestrates the trading agents framework."""
|
"""Main class that orchestrates the trading agents framework."""
|
||||||
|
|
||||||
|
logger = logging.getLogger("tradingagents.graph")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||||
|
|
@ -205,6 +208,15 @@ class TradingAgentsGraph:
|
||||||
saver = self._checkpointer_ctx.__enter__()
|
saver = self._checkpointer_ctx.__enter__()
|
||||||
self.graph = self.workflow.compile(checkpointer=saver)
|
self.graph = self.workflow.compile(checkpointer=saver)
|
||||||
|
|
||||||
|
# Log resume vs fresh start
|
||||||
|
step = checkpoint_step(
|
||||||
|
self.config["data_cache_dir"], company_name, str(trade_date)
|
||||||
|
)
|
||||||
|
if step is not None:
|
||||||
|
self.logger.info("Resuming from step %d for %s on %s", step, company_name, trade_date)
|
||||||
|
else:
|
||||||
|
self.logger.info("Starting fresh for %s on %s", company_name, trade_date)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return self._run_graph(company_name, trade_date)
|
return self._run_graph(company_name, trade_date)
|
||||||
finally:
|
finally:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue