diff --git a/.gitignore b/.gitignore index 9a2904a9..20fb9795 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# LangGraph SQLite checkpoints (under results_dir/.checkpoints/) +**/.checkpoints/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[codz] @@ -217,3 +220,6 @@ __marimo__/ # Cache **/data_cache/ + +# Results +**/results/ diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/main.py b/api/main.py new file mode 100644 index 00000000..25dbe5a0 --- /dev/null +++ b/api/main.py @@ -0,0 +1,15 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from api.routers import runs, settings + +app = FastAPI(title="TradingAgents API", version="1.0.0") + +app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:3000"], + allow_methods=["*"], + allow_headers=["*"], +) + +app.include_router(runs.router, prefix="/runs", tags=["runs"]) +app.include_router(settings.router, prefix="/settings", tags=["settings"]) diff --git a/api/models/__init__.py b/api/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/requirements.txt b/api/requirements.txt new file mode 100644 index 00000000..ecf2dfe1 --- /dev/null +++ b/api/requirements.txt @@ -0,0 +1,7 @@ +fastapi==0.115.0 +uvicorn[standard]==0.30.0 +sse-starlette==2.1.0 +python-dotenv==1.0.1 +pytest==8.3.0 +httpx==0.27.0 +pytest-asyncio==0.23.0 diff --git a/api/routers/__init__.py b/api/routers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/routers/runs.py b/api/routers/runs.py new file mode 100644 index 00000000..af9233c5 --- /dev/null +++ b/api/routers/runs.py @@ -0,0 +1,3 @@ +from fastapi import APIRouter + +router = APIRouter() diff --git a/api/routers/settings.py b/api/routers/settings.py new file mode 100644 index 00000000..af9233c5 --- /dev/null +++ b/api/routers/settings.py @@ -0,0 +1,3 @@ +from fastapi import APIRouter + +router = APIRouter() diff --git a/api/services/__init__.py b/api/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/store/__init__.py b/api/store/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cli/main.py b/cli/main.py index df6dc891..5ccc3fe9 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1,5 +1,6 @@ from typing import Optional import datetime +import uuid import typer from pathlib import Path from functools import wraps @@ -863,6 +864,16 @@ def extract_content_string(content): return str(content).strip() if not is_empty(content) else None +def debate_state_text(value) -> str: + """Normalize graph debate fields that may be str or list blocks (e.g. Gemini response.content).""" + if value is None: + return "" + if isinstance(value, str): + return value.strip() + extracted = extract_content_string(value) + return (extracted or "").strip() + + def classify_message_type(message) -> tuple[str, str | None]: """Classify LangChain message into display type and extract content. @@ -981,6 +992,38 @@ def run_analysis(): message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call") message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section") + base_thread_id = compute_analysis_thread_id(selections, selected_analyst_keys) + thread_id = base_thread_id + init_agent_state = graph.propagator.create_initial_state( + selections["ticker"], selections["analysis_date"] + ) + snap = graph.graph.get_state({"configurable": {"thread_id": thread_id}}) + stream_input = init_agent_state + if snap.next: + choice = typer.prompt( + "Incomplete checkpoint found for this run. [R]esume or start [N]ew run?", + default="R", + ).strip().upper() + if choice.startswith("N"): + thread_id = f"{base_thread_id}_{uuid.uuid4().hex[:12]}" + stream_input = init_agent_state + else: + stream_input = None + elif snap.values: + restart = typer.prompt( + "A completed run exists for these settings. Start a new run?", + default="Y", + ).strip().upper() + if restart not in ("Y", "YES", ""): + console.print("[yellow]Exiting without re-running.[/yellow]") + return + thread_id = f"{base_thread_id}_{uuid.uuid4().hex[:12]}" + stream_input = init_agent_state + + args = graph.propagator.get_graph_args( + callbacks=[stats_handler], thread_id=thread_id + ) + # Now start the display layout layout = create_layout() @@ -997,6 +1040,11 @@ def run_analysis(): "System", f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}", ) + message_buffer.add_message("System", f"Checkpoint thread: {thread_id}") + if stream_input is None: + message_buffer.add_message( + "System", "Resuming from saved checkpoint (same thread_id)." + ) update_display(layout, stats_handler=stats_handler, start_time=start_time) # Update agent status to in_progress for the first analyst @@ -1010,17 +1058,9 @@ def run_analysis(): ) update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time) - # Initialize state and get graph args with callbacks - init_agent_state = graph.propagator.create_initial_state( - selections["ticker"], selections["analysis_date"] - ) - # Pass callbacks to graph config for tool execution tracking - # (LLM tracking is handled separately via LLM constructor) - args = graph.propagator.get_graph_args(callbacks=[stats_handler]) - - # Stream the analysis + # Stream the analysis (init_agent_state / args / thread_id set above) trace = [] - for chunk in graph.graph.stream(init_agent_state, **args): + for chunk in graph.graph.stream(stream_input, **args): # Process messages if present (skip duplicates via message ID) if len(chunk["messages"]) > 0: last_message = chunk["messages"][-1] @@ -1050,9 +1090,9 @@ def run_analysis(): # Research Team - Handle Investment Debate State if chunk.get("investment_debate_state"): debate_state = chunk["investment_debate_state"] - bull_hist = debate_state.get("bull_history", "").strip() - bear_hist = debate_state.get("bear_history", "").strip() - judge = debate_state.get("judge_decision", "").strip() + bull_hist = debate_state_text(debate_state.get("bull_history", "")) + bear_hist = debate_state_text(debate_state.get("bear_history", "")) + judge = debate_state_text(debate_state.get("judge_decision", "")) # Only update status when there's actual content if bull_hist or bear_hist: @@ -1084,10 +1124,10 @@ def run_analysis(): # Risk Management Team - Handle Risk Debate State if chunk.get("risk_debate_state"): risk_state = chunk["risk_debate_state"] - agg_hist = risk_state.get("aggressive_history", "").strip() - con_hist = risk_state.get("conservative_history", "").strip() - neu_hist = risk_state.get("neutral_history", "").strip() - judge = risk_state.get("judge_decision", "").strip() + agg_hist = debate_state_text(risk_state.get("aggressive_history", "")) + con_hist = debate_state_text(risk_state.get("conservative_history", "")) + neu_hist = debate_state_text(risk_state.get("neutral_history", "")) + judge = debate_state_text(risk_state.get("judge_decision", "")) if agg_hist: if message_buffer.agent_status.get("Aggressive Analyst") != "completed": diff --git a/cli/utils.py b/cli/utils.py index 5a8ec16c..a8ecc1d4 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,5 +1,7 @@ +import hashlib +import json import questionary -from typing import List, Optional, Tuple, Dict +from typing import List, Optional, Tuple, Dict, Any from rich.console import Console @@ -329,3 +331,26 @@ def ask_gemini_thinking_config() -> str | None: ("pointer", "fg:green noinherit"), ]), ).ask() + + +def compute_analysis_thread_id( + selections: Dict[str, Any], selected_analyst_keys: List[str] +) -> str: + """Stable LangGraph thread_id from ticker, date, analysts, and LLM settings.""" + payload = { + "ticker": selections["ticker"].strip().upper(), + "analysis_date": selections["analysis_date"], + "analyst_keys": selected_analyst_keys, + "max_debate_rounds": selections["research_depth"], + "max_risk_discuss_rounds": selections["research_depth"], + "llm_provider": selections["llm_provider"], + "backend_url": selections.get("backend_url"), + "quick_think_llm": selections["shallow_thinker"], + "deep_think_llm": selections["deep_thinker"], + "google_thinking_level": selections.get("google_thinking_level"), + "openai_reasoning_effort": selections.get("openai_reasoning_effort"), + } + h = hashlib.sha256( + json.dumps(payload, sort_keys=True).encode("utf-8") + ).hexdigest()[:32] + return f"ta_{h}" diff --git a/pyproject.toml b/pyproject.toml index 4c91a733..e95cd447 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "langchain-google-genai>=2.1.5", "langchain-openai>=0.3.23", "langgraph>=0.4.8", + "langgraph-checkpoint-sqlite>=2.0.0", "pandas>=2.3.0", "parsel>=1.10.0", "pytz>=2025.2", diff --git a/requirements.txt b/requirements.txt index 184468b8..c3cb642f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ pandas yfinance stockstats langgraph +langgraph-checkpoint-sqlite rank-bm25 setuptools backtrader diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index 0fd10c0c..4007e56e 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -53,16 +53,23 @@ class Propagator: "news_report": "", } - def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]: + def get_graph_args( + self, + callbacks: Optional[List] = None, + thread_id: Optional[str] = None, + ) -> Dict[str, Any]: """Get arguments for the graph invocation. Args: callbacks: Optional list of callback handlers for tool execution tracking. Note: LLM callbacks are handled separately via LLM constructor. + thread_id: LangGraph checkpoint thread (required when graph uses a checkpointer). """ - config = {"recursion_limit": self.max_recur_limit} + config: Dict[str, Any] = {"recursion_limit": self.max_recur_limit} if callbacks: config["callbacks"] = callbacks + if thread_id is not None: + config["configurable"] = {"thread_id": thread_id} return { "stream_mode": "values", "config": config, diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 772efe7f..0e9984a0 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -1,7 +1,9 @@ # TradingAgents/graph/setup.py -from typing import Dict, Any +from typing import Dict, Any, Optional + from langchain_openai import ChatOpenAI +from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.graph import END, StateGraph, START from langgraph.prebuilt import ToolNode @@ -38,7 +40,9 @@ class GraphSetup: self.conditional_logic = conditional_logic def setup_graph( - self, selected_analysts=["market", "social", "news", "fundamentals"] + self, + selected_analysts=["market", "social", "news", "fundamentals"], + checkpointer: Optional[BaseCheckpointSaver] = None, ): """Set up and compile the agent workflow graph. @@ -199,4 +203,4 @@ class GraphSetup: workflow.add_edge("Risk Judge", END) # Compile and return - return workflow.compile() + return workflow.compile(checkpointer=checkpointer) diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index c7ef0f98..0d12988c 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -1,12 +1,15 @@ # TradingAgents/graph/trading_graph.py import os +import sqlite3 +import hashlib from pathlib import Path import json from datetime import date from typing import Dict, Any, Tuple, List, Optional from langgraph.prebuilt import ToolNode +from langgraph.checkpoint.sqlite import SqliteSaver from tradingagents.llm_clients import create_llm_client @@ -61,6 +64,7 @@ class TradingAgentsGraph: self.debug = debug self.config = config or DEFAULT_CONFIG self.callbacks = callbacks or [] + self.selected_analysts = list(selected_analysts) # Update the interface's config set_config(self.config) @@ -130,8 +134,37 @@ class TradingAgentsGraph: self.ticker = None self.log_states_dict = {} # date to full state dict - # Set up the graph - self.graph = self.graph_setup.setup_graph(selected_analysts) + self._sqlite_conn, self.checkpointer = self._create_sqlite_checkpointer(self.config) + # Set up the graph (durable checkpoints for resume after crash) + self.graph = self.graph_setup.setup_graph( + selected_analysts, checkpointer=self.checkpointer + ) + + @staticmethod + def _create_sqlite_checkpointer( + config: Dict[str, Any], + ) -> Tuple[sqlite3.Connection, SqliteSaver]: + """SQLite checkpoint store under results_dir/.checkpoints/langgraph.sqlite. + + Returns: + (conn, checkpointer) – caller must close conn when done. + """ + results_dir = Path(config.get("results_dir", "./results")).expanduser().resolve() + checkpoint_dir = results_dir / ".checkpoints" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + db_path = checkpoint_dir / "langgraph.sqlite" + conn = sqlite3.connect(str(db_path), check_same_thread=False) + return conn, SqliteSaver(conn) + + def close(self) -> None: + """Close the underlying SQLite connection held by the checkpointer.""" + try: + self._sqlite_conn.close() + except Exception: + pass + + def __del__(self) -> None: + self.close() def _get_provider_kwargs(self) -> Dict[str, Any]: """Get provider-specific kwargs for LLM client creation.""" @@ -186,31 +219,63 @@ class TradingAgentsGraph: ), } - def propagate(self, company_name, trade_date): + def propagate(self, company_name, trade_date, thread_id: Optional[str] = None): """Run the trading agents graph for a company on a specific date.""" self.ticker = company_name + if thread_id is None: + payload = json.dumps( + { + "ticker": company_name.strip().upper(), + "trade_date": str(trade_date), + "analysts": sorted(self.selected_analysts), + "llm_provider": self.config.get("llm_provider"), + "deep_think_llm": self.config.get("deep_think_llm"), + "quick_think_llm": self.config.get("quick_think_llm"), + "max_debate_rounds": self.config.get("max_debate_rounds"), + "max_risk_discuss_rounds": self.config.get("max_risk_discuss_rounds"), + }, + sort_keys=True, + ).encode() + thread_id = "ta_prog_" + hashlib.sha256(payload).hexdigest()[:24] + # Initialize state init_agent_state = self.propagator.create_initial_state( company_name, trade_date ) - args = self.propagator.get_graph_args() + args = self.propagator.get_graph_args(thread_id=thread_id) + + # Determine stream input: resume from checkpoint if an incomplete run exists, + # otherwise start fresh. Passing None tells LangGraph to resume from the last + # saved checkpoint for this thread_id. + thread_config = {"configurable": {"thread_id": thread_id}} + snap = self.graph.get_state(thread_config) + if snap.next: + # Incomplete run found — resume automatically (no user prompt in API mode) + stream_input = None + else: + stream_input = init_agent_state if self.debug: # Debug mode with tracing trace = [] - for chunk in self.graph.stream(init_agent_state, **args): + for chunk in self.graph.stream(stream_input, **args): if len(chunk["messages"]) == 0: pass else: chunk["messages"][-1].pretty_print() trace.append(chunk) + if not trace: + raise RuntimeError( + "Graph stream produced no output — all chunks had empty messages. " + f"ticker={company_name}, trade_date={trade_date}, thread_id={thread_id}" + ) final_state = trace[-1] else: # Standard mode without tracing - final_state = self.graph.invoke(init_agent_state, **args) + final_state = self.graph.invoke(stream_input, **args) # Store current state for reflection self.curr_state = final_state