feat: add FastAPI-based API for trading agents

- Introduced a new API structure with FastAPI, including routers for runs and settings.
- Implemented endpoints for creating, listing, and retrieving run configurations.
- Added settings management with load and update functionality.
- Integrated SQLite checkpointing for durable state management during analysis.
- Updated dependencies in `pyproject.toml` and `requirements.txt` to include FastAPI and related packages.
- Enhanced `.gitignore` to exclude SQLite checkpoints and results directories.
This commit is contained in:
Ali AL OGAILI 2026-03-23 04:45:08 +01:00
parent 0b13145dc0
commit 49283f47d5
19 changed files with 206 additions and 29 deletions

6
.gitignore vendored
View File

@ -1,3 +1,6 @@
# LangGraph SQLite checkpoints (under results_dir/.checkpoints/)
**/.checkpoints/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
@ -217,3 +220,6 @@ __marimo__/
# Cache
**/data_cache/
# Results
**/results/

0
api/__init__.py Normal file
View File

15
api/main.py Normal file
View File

@ -0,0 +1,15 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from api.routers import runs, settings
app = FastAPI(title="TradingAgents API", version="1.0.0")
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"],
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(runs.router, prefix="/runs", tags=["runs"])
app.include_router(settings.router, prefix="/settings", tags=["settings"])

0
api/models/__init__.py Normal file
View File

7
api/requirements.txt Normal file
View File

@ -0,0 +1,7 @@
fastapi==0.115.0
uvicorn[standard]==0.30.0
sse-starlette==2.1.0
python-dotenv==1.0.1
pytest==8.3.0
httpx==0.27.0
pytest-asyncio==0.23.0

0
api/routers/__init__.py Normal file
View File

3
api/routers/runs.py Normal file
View File

@ -0,0 +1,3 @@
from fastapi import APIRouter
router = APIRouter()

3
api/routers/settings.py Normal file
View File

@ -0,0 +1,3 @@
from fastapi import APIRouter
router = APIRouter()

0
api/services/__init__.py Normal file
View File

0
api/store/__init__.py Normal file
View File

View File

@ -1,5 +1,6 @@
from typing import Optional
import datetime
import uuid
import typer
from pathlib import Path
from functools import wraps
@ -863,6 +864,16 @@ def extract_content_string(content):
return str(content).strip() if not is_empty(content) else None
def debate_state_text(value) -> str:
"""Normalize graph debate fields that may be str or list blocks (e.g. Gemini response.content)."""
if value is None:
return ""
if isinstance(value, str):
return value.strip()
extracted = extract_content_string(value)
return (extracted or "").strip()
def classify_message_type(message) -> tuple[str, str | None]:
"""Classify LangChain message into display type and extract content.
@ -981,6 +992,38 @@ def run_analysis():
message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
base_thread_id = compute_analysis_thread_id(selections, selected_analyst_keys)
thread_id = base_thread_id
init_agent_state = graph.propagator.create_initial_state(
selections["ticker"], selections["analysis_date"]
)
snap = graph.graph.get_state({"configurable": {"thread_id": thread_id}})
stream_input = init_agent_state
if snap.next:
choice = typer.prompt(
"Incomplete checkpoint found for this run. [R]esume or start [N]ew run?",
default="R",
).strip().upper()
if choice.startswith("N"):
thread_id = f"{base_thread_id}_{uuid.uuid4().hex[:12]}"
stream_input = init_agent_state
else:
stream_input = None
elif snap.values:
restart = typer.prompt(
"A completed run exists for these settings. Start a new run?",
default="Y",
).strip().upper()
if restart not in ("Y", "YES", ""):
console.print("[yellow]Exiting without re-running.[/yellow]")
return
thread_id = f"{base_thread_id}_{uuid.uuid4().hex[:12]}"
stream_input = init_agent_state
args = graph.propagator.get_graph_args(
callbacks=[stats_handler], thread_id=thread_id
)
# Now start the display layout
layout = create_layout()
@ -997,6 +1040,11 @@ def run_analysis():
"System",
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
)
message_buffer.add_message("System", f"Checkpoint thread: {thread_id}")
if stream_input is None:
message_buffer.add_message(
"System", "Resuming from saved checkpoint (same thread_id)."
)
update_display(layout, stats_handler=stats_handler, start_time=start_time)
# Update agent status to in_progress for the first analyst
@ -1010,17 +1058,9 @@ def run_analysis():
)
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
# Initialize state and get graph args with callbacks
init_agent_state = graph.propagator.create_initial_state(
selections["ticker"], selections["analysis_date"]
)
# Pass callbacks to graph config for tool execution tracking
# (LLM tracking is handled separately via LLM constructor)
args = graph.propagator.get_graph_args(callbacks=[stats_handler])
# Stream the analysis
# Stream the analysis (init_agent_state / args / thread_id set above)
trace = []
for chunk in graph.graph.stream(init_agent_state, **args):
for chunk in graph.graph.stream(stream_input, **args):
# Process messages if present (skip duplicates via message ID)
if len(chunk["messages"]) > 0:
last_message = chunk["messages"][-1]
@ -1050,9 +1090,9 @@ def run_analysis():
# Research Team - Handle Investment Debate State
if chunk.get("investment_debate_state"):
debate_state = chunk["investment_debate_state"]
bull_hist = debate_state.get("bull_history", "").strip()
bear_hist = debate_state.get("bear_history", "").strip()
judge = debate_state.get("judge_decision", "").strip()
bull_hist = debate_state_text(debate_state.get("bull_history", ""))
bear_hist = debate_state_text(debate_state.get("bear_history", ""))
judge = debate_state_text(debate_state.get("judge_decision", ""))
# Only update status when there's actual content
if bull_hist or bear_hist:
@ -1084,10 +1124,10 @@ def run_analysis():
# Risk Management Team - Handle Risk Debate State
if chunk.get("risk_debate_state"):
risk_state = chunk["risk_debate_state"]
agg_hist = risk_state.get("aggressive_history", "").strip()
con_hist = risk_state.get("conservative_history", "").strip()
neu_hist = risk_state.get("neutral_history", "").strip()
judge = risk_state.get("judge_decision", "").strip()
agg_hist = debate_state_text(risk_state.get("aggressive_history", ""))
con_hist = debate_state_text(risk_state.get("conservative_history", ""))
neu_hist = debate_state_text(risk_state.get("neutral_history", ""))
judge = debate_state_text(risk_state.get("judge_decision", ""))
if agg_hist:
if message_buffer.agent_status.get("Aggressive Analyst") != "completed":

View File

@ -1,5 +1,7 @@
import hashlib
import json
import questionary
from typing import List, Optional, Tuple, Dict
from typing import List, Optional, Tuple, Dict, Any
from rich.console import Console
@ -329,3 +331,26 @@ def ask_gemini_thinking_config() -> str | None:
("pointer", "fg:green noinherit"),
]),
).ask()
def compute_analysis_thread_id(
selections: Dict[str, Any], selected_analyst_keys: List[str]
) -> str:
"""Stable LangGraph thread_id from ticker, date, analysts, and LLM settings."""
payload = {
"ticker": selections["ticker"].strip().upper(),
"analysis_date": selections["analysis_date"],
"analyst_keys": selected_analyst_keys,
"max_debate_rounds": selections["research_depth"],
"max_risk_discuss_rounds": selections["research_depth"],
"llm_provider": selections["llm_provider"],
"backend_url": selections.get("backend_url"),
"quick_think_llm": selections["shallow_thinker"],
"deep_think_llm": selections["deep_thinker"],
"google_thinking_level": selections.get("google_thinking_level"),
"openai_reasoning_effort": selections.get("openai_reasoning_effort"),
}
h = hashlib.sha256(
json.dumps(payload, sort_keys=True).encode("utf-8")
).hexdigest()[:32]
return f"ta_{h}"

View File

@ -16,6 +16,7 @@ dependencies = [
"langchain-google-genai>=2.1.5",
"langchain-openai>=0.3.23",
"langgraph>=0.4.8",
"langgraph-checkpoint-sqlite>=2.0.0",
"pandas>=2.3.0",
"parsel>=1.10.0",
"pytz>=2025.2",

View File

@ -6,6 +6,7 @@ pandas
yfinance
stockstats
langgraph
langgraph-checkpoint-sqlite
rank-bm25
setuptools
backtrader

0
tests/__init__.py Normal file
View File

0
tests/api/__init__.py Normal file
View File

View File

@ -53,16 +53,23 @@ class Propagator:
"news_report": "",
}
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
def get_graph_args(
self,
callbacks: Optional[List] = None,
thread_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Get arguments for the graph invocation.
Args:
callbacks: Optional list of callback handlers for tool execution tracking.
Note: LLM callbacks are handled separately via LLM constructor.
thread_id: LangGraph checkpoint thread (required when graph uses a checkpointer).
"""
config = {"recursion_limit": self.max_recur_limit}
config: Dict[str, Any] = {"recursion_limit": self.max_recur_limit}
if callbacks:
config["callbacks"] = callbacks
if thread_id is not None:
config["configurable"] = {"thread_id": thread_id}
return {
"stream_mode": "values",
"config": config,

View File

@ -1,7 +1,9 @@
# TradingAgents/graph/setup.py
from typing import Dict, Any
from typing import Dict, Any, Optional
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode
@ -38,7 +40,9 @@ class GraphSetup:
self.conditional_logic = conditional_logic
def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"]
self,
selected_analysts=["market", "social", "news", "fundamentals"],
checkpointer: Optional[BaseCheckpointSaver] = None,
):
"""Set up and compile the agent workflow graph.
@ -199,4 +203,4 @@ class GraphSetup:
workflow.add_edge("Risk Judge", END)
# Compile and return
return workflow.compile()
return workflow.compile(checkpointer=checkpointer)

View File

@ -1,12 +1,15 @@
# TradingAgents/graph/trading_graph.py
import os
import sqlite3
import hashlib
from pathlib import Path
import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.sqlite import SqliteSaver
from tradingagents.llm_clients import create_llm_client
@ -61,6 +64,7 @@ class TradingAgentsGraph:
self.debug = debug
self.config = config or DEFAULT_CONFIG
self.callbacks = callbacks or []
self.selected_analysts = list(selected_analysts)
# Update the interface's config
set_config(self.config)
@ -130,8 +134,37 @@ class TradingAgentsGraph:
self.ticker = None
self.log_states_dict = {} # date to full state dict
# Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts)
self._sqlite_conn, self.checkpointer = self._create_sqlite_checkpointer(self.config)
# Set up the graph (durable checkpoints for resume after crash)
self.graph = self.graph_setup.setup_graph(
selected_analysts, checkpointer=self.checkpointer
)
@staticmethod
def _create_sqlite_checkpointer(
config: Dict[str, Any],
) -> Tuple[sqlite3.Connection, SqliteSaver]:
"""SQLite checkpoint store under results_dir/.checkpoints/langgraph.sqlite.
Returns:
(conn, checkpointer) caller must close conn when done.
"""
results_dir = Path(config.get("results_dir", "./results")).expanduser().resolve()
checkpoint_dir = results_dir / ".checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
db_path = checkpoint_dir / "langgraph.sqlite"
conn = sqlite3.connect(str(db_path), check_same_thread=False)
return conn, SqliteSaver(conn)
def close(self) -> None:
"""Close the underlying SQLite connection held by the checkpointer."""
try:
self._sqlite_conn.close()
except Exception:
pass
def __del__(self) -> None:
self.close()
def _get_provider_kwargs(self) -> Dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation."""
@ -186,31 +219,63 @@ class TradingAgentsGraph:
),
}
def propagate(self, company_name, trade_date):
def propagate(self, company_name, trade_date, thread_id: Optional[str] = None):
"""Run the trading agents graph for a company on a specific date."""
self.ticker = company_name
if thread_id is None:
payload = json.dumps(
{
"ticker": company_name.strip().upper(),
"trade_date": str(trade_date),
"analysts": sorted(self.selected_analysts),
"llm_provider": self.config.get("llm_provider"),
"deep_think_llm": self.config.get("deep_think_llm"),
"quick_think_llm": self.config.get("quick_think_llm"),
"max_debate_rounds": self.config.get("max_debate_rounds"),
"max_risk_discuss_rounds": self.config.get("max_risk_discuss_rounds"),
},
sort_keys=True,
).encode()
thread_id = "ta_prog_" + hashlib.sha256(payload).hexdigest()[:24]
# Initialize state
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date
)
args = self.propagator.get_graph_args()
args = self.propagator.get_graph_args(thread_id=thread_id)
# Determine stream input: resume from checkpoint if an incomplete run exists,
# otherwise start fresh. Passing None tells LangGraph to resume from the last
# saved checkpoint for this thread_id.
thread_config = {"configurable": {"thread_id": thread_id}}
snap = self.graph.get_state(thread_config)
if snap.next:
# Incomplete run found — resume automatically (no user prompt in API mode)
stream_input = None
else:
stream_input = init_agent_state
if self.debug:
# Debug mode with tracing
trace = []
for chunk in self.graph.stream(init_agent_state, **args):
for chunk in self.graph.stream(stream_input, **args):
if len(chunk["messages"]) == 0:
pass
else:
chunk["messages"][-1].pretty_print()
trace.append(chunk)
if not trace:
raise RuntimeError(
"Graph stream produced no output — all chunks had empty messages. "
f"ticker={company_name}, trade_date={trade_date}, thread_id={thread_id}"
)
final_state = trace[-1]
else:
# Standard mode without tracing
final_state = self.graph.invoke(init_agent_state, **args)
final_state = self.graph.invoke(stream_input, **args)
# Store current state for reflection
self.curr_state = final_state