feat(027-checkpoint-resume-contrib): wire checkpointer into TradingAgentsGraph.compile()
This commit is contained in:
parent
e74a715e0a
commit
79b57a34fb
|
|
@ -217,3 +217,6 @@ __marimo__/
|
|||
|
||||
# Cache
|
||||
**/data_cache/
|
||||
.kiro/
|
||||
ralph-kiro.sh
|
||||
progress.txt
|
||||
|
|
|
|||
|
|
@ -197,5 +197,4 @@ class GraphSetup:
|
|||
|
||||
workflow.add_edge("Portfolio Manager", END)
|
||||
|
||||
# Compile and return
|
||||
return workflow.compile()
|
||||
return workflow
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_global_news
|
||||
)
|
||||
|
||||
from .checkpointer import get_checkpointer
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .setup import GraphSetup
|
||||
from .propagation import Propagator
|
||||
|
|
@ -129,7 +130,9 @@ class TradingAgentsGraph:
|
|||
self.log_states_dict = {} # date to full state dict
|
||||
|
||||
# Set up the graph
|
||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||
self.workflow = self.graph_setup.setup_graph(selected_analysts)
|
||||
self.graph = self.workflow.compile()
|
||||
self._checkpointer_ctx = None
|
||||
|
||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||
"""Get provider-specific kwargs for LLM client creation."""
|
||||
|
|
@ -194,6 +197,24 @@ class TradingAgentsGraph:
|
|||
|
||||
self.ticker = company_name
|
||||
|
||||
# Recompile with checkpointer if enabled
|
||||
if self.config.get("checkpoint_enabled"):
|
||||
self._checkpointer_ctx = get_checkpointer(
|
||||
self.config["data_cache_dir"], company_name
|
||||
)
|
||||
saver = self._checkpointer_ctx.__enter__()
|
||||
self.graph = self.workflow.compile(checkpointer=saver)
|
||||
|
||||
try:
|
||||
return self._run_graph(company_name, trade_date)
|
||||
finally:
|
||||
if self._checkpointer_ctx is not None:
|
||||
self._checkpointer_ctx.__exit__(None, None, None)
|
||||
self._checkpointer_ctx = None
|
||||
self.graph = self.workflow.compile()
|
||||
|
||||
def _run_graph(self, company_name, trade_date):
|
||||
"""Execute the graph and return results."""
|
||||
# Initialize state
|
||||
init_agent_state = self.propagator.create_initial_state(
|
||||
company_name, trade_date
|
||||
|
|
|
|||
Loading…
Reference in New Issue