feat(027-checkpoint-resume-contrib): wire checkpointer into TradingAgentsGraph.compile()
This commit is contained in:
parent
e74a715e0a
commit
79b57a34fb
|
|
@ -217,3 +217,6 @@ __marimo__/
|
||||||
|
|
||||||
# Cache
|
# Cache
|
||||||
**/data_cache/
|
**/data_cache/
|
||||||
|
.kiro/
|
||||||
|
ralph-kiro.sh
|
||||||
|
progress.txt
|
||||||
|
|
|
||||||
|
|
@ -197,5 +197,4 @@ class GraphSetup:
|
||||||
|
|
||||||
workflow.add_edge("Portfolio Manager", END)
|
workflow.add_edge("Portfolio Manager", END)
|
||||||
|
|
||||||
# Compile and return
|
return workflow
|
||||||
return workflow.compile()
|
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ from tradingagents.agents.utils.agent_utils import (
|
||||||
get_global_news
|
get_global_news
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .checkpointer import get_checkpointer
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
from .setup import GraphSetup
|
from .setup import GraphSetup
|
||||||
from .propagation import Propagator
|
from .propagation import Propagator
|
||||||
|
|
@ -129,7 +130,9 @@ class TradingAgentsGraph:
|
||||||
self.log_states_dict = {} # date to full state dict
|
self.log_states_dict = {} # date to full state dict
|
||||||
|
|
||||||
# Set up the graph
|
# Set up the graph
|
||||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
self.workflow = self.graph_setup.setup_graph(selected_analysts)
|
||||||
|
self.graph = self.workflow.compile()
|
||||||
|
self._checkpointer_ctx = None
|
||||||
|
|
||||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||||
"""Get provider-specific kwargs for LLM client creation."""
|
"""Get provider-specific kwargs for LLM client creation."""
|
||||||
|
|
@ -194,6 +197,24 @@ class TradingAgentsGraph:
|
||||||
|
|
||||||
self.ticker = company_name
|
self.ticker = company_name
|
||||||
|
|
||||||
|
# Recompile with checkpointer if enabled
|
||||||
|
if self.config.get("checkpoint_enabled"):
|
||||||
|
self._checkpointer_ctx = get_checkpointer(
|
||||||
|
self.config["data_cache_dir"], company_name
|
||||||
|
)
|
||||||
|
saver = self._checkpointer_ctx.__enter__()
|
||||||
|
self.graph = self.workflow.compile(checkpointer=saver)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self._run_graph(company_name, trade_date)
|
||||||
|
finally:
|
||||||
|
if self._checkpointer_ctx is not None:
|
||||||
|
self._checkpointer_ctx.__exit__(None, None, None)
|
||||||
|
self._checkpointer_ctx = None
|
||||||
|
self.graph = self.workflow.compile()
|
||||||
|
|
||||||
|
def _run_graph(self, company_name, trade_date):
|
||||||
|
"""Execute the graph and return results."""
|
||||||
# Initialize state
|
# Initialize state
|
||||||
init_agent_state = self.propagator.create_initial_state(
|
init_agent_state = self.propagator.create_initial_state(
|
||||||
company_name, trade_date
|
company_name, trade_date
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue