# TradingAgents/graph/trading_graph.py

import os
import sqlite3
import hashlib
from pathlib import Path
import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional

from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.sqlite import SqliteSaver

from tradingagents.llm_clients import create_llm_client

from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
    AgentState,
    InvestDebateState,
    RiskDebateState,
)
from tradingagents.dataflows.config import set_config

# Import the new abstract tool methods from agent_utils
from tradingagents.agents.utils.agent_utils import (
    get_stock_data,
    get_indicators,
    get_fundamentals,
    get_balance_sheet,
    get_cashflow,
    get_income_statement,
    get_news,
    get_insider_transactions,
    get_global_news
)

from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .signal_processing import SignalProcessor


class TradingAgentsGraph:
    """Main class that orchestrates the trading agents framework.

    Owns the quick/deep LLMs, the per-role memories, the analyst tool nodes
    and the compiled graph. The graph is built with a SQLite checkpointer
    (``results_dir/.checkpoints/langgraph.sqlite``) so an incomplete run for
    the same thread id is resumed by :meth:`propagate`.

    The SQLite connection is released by :meth:`close`. Callers can also use
    the instance as a context manager (``with TradingAgentsGraph(...) as g:``).
    As a last resort, ``__del__`` closes it too.
    """

    def __init__(
        self,
        selected_analysts=("market", "social", "news", "fundamentals"),
        debug=False,
        config: Optional[Dict[str, Any]] = None,
        callbacks: Optional[List] = None,
    ):
        """Initialize the trading agents graph and components.

        Args:
            selected_analysts: Iterable of analyst types to include. It is
                copied, so later changes by the caller have no effect.
            debug: Whether to run in debug mode (stream and pretty-print
                messages in :meth:`propagate`).
            config: Configuration dictionary. If None (or empty), uses
                ``DEFAULT_CONFIG``.
            callbacks: Optional list of callback handlers (e.g., for
                tracking LLM/tool stats). They are passed to both LLM
                constructors.
        """
        self.debug = debug
        self.config = config or DEFAULT_CONFIG
        self.callbacks = callbacks or []
        # Own copy: it feeds the checkpoint thread id, so it must not change.
        self.selected_analysts = list(selected_analysts)

        # Update the interface's config
        set_config(self.config)

        # Create necessary directories
        os.makedirs(
            os.path.join(self.config["project_dir"], "dataflows/data_cache"),
            exist_ok=True,
        )

        # Initialize LLMs with provider-specific thinking configuration
        llm_kwargs = self._get_provider_kwargs()

        # Add callbacks to kwargs if provided (passed to LLM constructor)
        if self.callbacks:
            llm_kwargs["callbacks"] = self.callbacks

        deep_client = create_llm_client(
            provider=self.config["llm_provider"],
            model=self.config["deep_think_llm"],
            base_url=self.config.get("backend_url"),
            **llm_kwargs,
        )
        quick_client = create_llm_client(
            provider=self.config["llm_provider"],
            model=self.config["quick_think_llm"],
            base_url=self.config.get("backend_url"),
            **llm_kwargs,
        )

        self.deep_thinking_llm = deep_client.get_llm()
        self.quick_thinking_llm = quick_client.get_llm()

        # Initialize memories
        self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
        self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
        self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
        self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
        self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)

        # Create tool nodes
        self.tool_nodes = self._create_tool_nodes()

        # Initialize components
        self.conditional_logic = ConditionalLogic(
            max_debate_rounds=self.config["max_debate_rounds"],
            max_risk_discuss_rounds=self.config["max_risk_discuss_rounds"],
        )
        self.graph_setup = GraphSetup(
            self.quick_thinking_llm,
            self.deep_thinking_llm,
            self.tool_nodes,
            self.bull_memory,
            self.bear_memory,
            self.trader_memory,
            self.invest_judge_memory,
            self.risk_manager_memory,
            self.conditional_logic,
        )

        self.propagator = Propagator()
        self.reflector = Reflector(self.quick_thinking_llm)
        self.signal_processor = SignalProcessor(self.quick_thinking_llm)

        # State tracking
        self.curr_state = None
        self.ticker = None
        self.log_states_dict = {}  # date to full state dict

        self._sqlite_conn, self.checkpointer = self._create_sqlite_checkpointer(self.config)

        # Set up the graph (durable checkpoints for resume after crash).
        # Pass a fresh list so graph setup cannot alter self.selected_analysts.
        self.graph = self.graph_setup.setup_graph(
            list(self.selected_analysts), checkpointer=self.checkpointer
        )

    @staticmethod
    def _create_sqlite_checkpointer(
        config: Dict[str, Any],
    ) -> Tuple[sqlite3.Connection, SqliteSaver]:
        """Open the SQLite checkpoint store at results_dir/.checkpoints/langgraph.sqlite.

        ``results_dir`` defaults to ``./results`` when it is absent from
        *config*. The directory is created if needed.

        Returns:
            (conn, checkpointer) – caller must close conn when done.
        """
        results_dir = Path(config.get("results_dir", "./results")).expanduser().resolve()
        checkpoint_dir = results_dir / ".checkpoints"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        db_path = checkpoint_dir / "langgraph.sqlite"
        # check_same_thread=False: the connection may be used by graph
        # execution on a thread other than the one that opened it.
        conn = sqlite3.connect(str(db_path), check_same_thread=False)
        try:
            checkpointer = SqliteSaver(conn)
        except BaseException:
            # Do not leak the connection if the saver cannot be built.
            conn.close()
            raise
        return conn, checkpointer

    def close(self) -> None:
        """Close the SQLite connection held by the checkpointer.

        Safe to call more than once. Also safe on an instance whose
        ``__init__`` failed before the connection was opened.
        """
        conn = getattr(self, "_sqlite_conn", None)
        if conn is None:
            return
        self._sqlite_conn = None
        try:
            conn.close()
        except sqlite3.Error:
            # Best-effort cleanup: a failing close leaves nothing to recover.
            pass

    def __enter__(self) -> "TradingAgentsGraph":
        """Return self so the graph can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the checkpoint connection on leaving the ``with`` block."""
        self.close()

    def __del__(self) -> None:
        # __del__ can run during interpreter shutdown, when module globals
        # may already be torn down. Never let it raise.
        try:
            self.close()
        except Exception:
            pass

    def _get_provider_kwargs(self) -> Dict[str, Any]:
        """Get provider-specific kwargs for LLM client creation.

        Returns:
            ``{"thinking_level": ...}`` for provider "google" when
            ``google_thinking_level`` is set, or
            ``{"reasoning_effort": ...}`` for provider "openai" when
            ``openai_reasoning_effort`` is set. Otherwise an empty dict.
        """
        kwargs = {}
        # `or ""` also covers an explicit None value in the config.
        provider = (self.config.get("llm_provider") or "").lower()

        if provider == "google":
            thinking_level = self.config.get("google_thinking_level")
            if thinking_level:
                kwargs["thinking_level"] = thinking_level

        elif provider == "openai":
            reasoning_effort = self.config.get("openai_reasoning_effort")
            if reasoning_effort:
                kwargs["reasoning_effort"] = reasoning_effort

        return kwargs

    def _create_tool_nodes(self) -> Dict[str, ToolNode]:
        """Create one ToolNode per analyst type, keyed by analyst name."""
        return {
            "market": ToolNode(
                [
                    # Core stock data tools
                    get_stock_data,
                    # Technical indicators
                    get_indicators,
                ]
            ),
            "social": ToolNode(
                [
                    # News tools for social media analysis
                    get_news,
                ]
            ),
            "news": ToolNode(
                [
                    # News and insider information
                    get_news,
                    get_global_news,
                    get_insider_transactions,
                ]
            ),
            "fundamentals": ToolNode(
                [
                    # Fundamental analysis tools
                    get_fundamentals,
                    get_balance_sheet,
                    get_cashflow,
                    get_income_statement,
                ]
            ),
        }

    def _default_thread_id(self, company_name, trade_date) -> str:
        """Derive a deterministic checkpoint thread id for a run.

        The id hashes the normalized ticker, the trade date, the sorted
        analyst list and the LLM/debate settings. Re-running the same
        configuration therefore reuses (and can resume) the same thread.
        """
        payload = json.dumps(
            {
                "ticker": company_name.strip().upper(),
                "trade_date": str(trade_date),
                "analysts": sorted(self.selected_analysts),
                "llm_provider": self.config.get("llm_provider"),
                "deep_think_llm": self.config.get("deep_think_llm"),
                "quick_think_llm": self.config.get("quick_think_llm"),
                "max_debate_rounds": self.config.get("max_debate_rounds"),
                "max_risk_discuss_rounds": self.config.get("max_risk_discuss_rounds"),
            },
            sort_keys=True,
        ).encode()
        return "ta_prog_" + hashlib.sha256(payload).hexdigest()[:24]

    def propagate(self, company_name, trade_date, thread_id: Optional[str] = None):
        """Run the trading agents graph for a company on a specific date.

        Args:
            company_name: Ticker / company identifier passed to the graph.
            trade_date: Date of the analysis. It is stringified for the
                thread id and log file name.
            thread_id: Checkpoint thread id. When None, a deterministic id is
                derived from the run configuration.

        Returns:
            ``(final_state, processed_signal)``, where *processed_signal* is
            the output of :meth:`process_signal` on the
            ``final_trade_decision``.

        Raises:
            RuntimeError: In debug mode, if the graph stream yields no chunks.
        """
        self.ticker = company_name

        if thread_id is None:
            thread_id = self._default_thread_id(company_name, trade_date)

        # Initialize state
        init_agent_state = self.propagator.create_initial_state(
            company_name, trade_date
        )
        args = self.propagator.get_graph_args(thread_id=thread_id)

        # Determine stream input: resume from checkpoint if an incomplete run exists,
        # otherwise start fresh. Passing None tells LangGraph to resume from the last
        # saved checkpoint for this thread_id.
        thread_config = {"configurable": {"thread_id": thread_id}}
        snap = self.graph.get_state(thread_config)
        if snap.next:
            # Incomplete run found — resume automatically (no user prompt in API mode)
            stream_input = None
        else:
            stream_input = init_agent_state

        if self.debug:
            # Debug mode with tracing: print the latest message of each chunk.
            # The final state is the LAST chunk streamed, whether or not it
            # carries messages.
            final_state = None
            for chunk in self.graph.stream(stream_input, **args):
                messages = chunk.get("messages") or []
                if messages:
                    messages[-1].pretty_print()
                final_state = chunk

            if final_state is None:
                raise RuntimeError(
                    "Graph stream produced no output. "
                    f"ticker={company_name}, trade_date={trade_date}, thread_id={thread_id}"
                )
        else:
            # Standard mode without tracing
            final_state = self.graph.invoke(stream_input, **args)

        # Store current state for reflection
        self.curr_state = final_state

        # Log state
        self._log_state(trade_date, final_state)

        # Return decision and processed signal
        return final_state, self.process_signal(final_state["final_trade_decision"])

    def _log_state(self, trade_date, final_state):
        """Record *final_state* and write the accumulated log to JSON.

        The selected report/debate fields are stored under
        ``str(trade_date)`` in ``self.log_states_dict``. The whole dict is
        then written to
        ``eval_results/<ticker>/TradingAgentsStrategy_logs/full_states_log_<trade_date>.json``.
        Each file therefore contains every date this instance has logged so
        far, not just *trade_date*.
        """
        invest_debate = final_state["investment_debate_state"]
        risk_debate = final_state["risk_debate_state"]

        self.log_states_dict[str(trade_date)] = {
            "company_of_interest": final_state["company_of_interest"],
            "trade_date": final_state["trade_date"],
            "market_report": final_state["market_report"],
            "sentiment_report": final_state["sentiment_report"],
            "news_report": final_state["news_report"],
            "fundamentals_report": final_state["fundamentals_report"],
            "investment_debate_state": {
                "bull_history": invest_debate["bull_history"],
                "bear_history": invest_debate["bear_history"],
                "history": invest_debate["history"],
                "current_response": invest_debate["current_response"],
                "judge_decision": invest_debate["judge_decision"],
            },
            "trader_investment_decision": final_state["trader_investment_plan"],
            "risk_debate_state": {
                "aggressive_history": risk_debate["aggressive_history"],
                "conservative_history": risk_debate["conservative_history"],
                "neutral_history": risk_debate["neutral_history"],
                "history": risk_debate["history"],
                "judge_decision": risk_debate["judge_decision"],
            },
            "investment_plan": final_state["investment_plan"],
            "final_trade_decision": final_state["final_trade_decision"],
        }

        # Save to file
        directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
        directory.mkdir(parents=True, exist_ok=True)

        log_path = directory / f"full_states_log_{trade_date}.json"
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(self.log_states_dict, f, indent=4)

    def reflect_and_remember(self, returns_losses):
        """Reflect on the last decision and update each role's memory.

        Args:
            returns_losses: Realized outcome passed through to every
                ``Reflector.reflect_*`` call.

        Raises:
            RuntimeError: If called before :meth:`propagate` has produced a
                state to reflect on.
        """
        if self.curr_state is None:
            raise RuntimeError(
                "reflect_and_remember() called before propagate(); no state to reflect on."
            )

        self.reflector.reflect_bull_researcher(
            self.curr_state, returns_losses, self.bull_memory
        )
        self.reflector.reflect_bear_researcher(
            self.curr_state, returns_losses, self.bear_memory
        )
        self.reflector.reflect_trader(
            self.curr_state, returns_losses, self.trader_memory
        )
        self.reflector.reflect_invest_judge(
            self.curr_state, returns_losses, self.invest_judge_memory
        )
        self.reflector.reflect_risk_manager(
            self.curr_state, returns_losses, self.risk_manager_memory
        )

    def process_signal(self, full_signal):
        """Process a signal to extract the core decision via the SignalProcessor."""
        return self.signal_processor.process_signal(full_signal)