feat: add FastAPI-based API for trading agents
- Introduced a new API structure with FastAPI, including routers for runs and settings. - Implemented endpoints for creating, listing, and retrieving run configurations. - Added settings management with load and update functionality. - Integrated SQLite checkpointing for durable state management during analysis. - Updated dependencies in `pyproject.toml` and `requirements.txt` to include FastAPI and related packages. - Enhanced `.gitignore` to exclude SQLite checkpoints and results directories.
This commit is contained in:
parent
0b13145dc0
commit
49283f47d5
|
|
@ -1,3 +1,6 @@
|
|||
# LangGraph SQLite checkpoints (under results_dir/.checkpoints/)
|
||||
**/.checkpoints/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[codz]
|
||||
|
|
@ -217,3 +220,6 @@ __marimo__/
|
|||
|
||||
# Cache
|
||||
**/data_cache/
|
||||
|
||||
# Results
|
||||
**/results/
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from api.routers import runs, settings

# Application entry point for the TradingAgents HTTP API.
app = FastAPI(title="TradingAgents API", version="1.0.0")

# CORS: allow the local web frontend (dev server on port 3000) to call the API.
_CORS_OPTIONS = {
    "allow_origins": ["http://localhost:3000"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Mount the feature routers under their URL prefixes.
app.include_router(runs.router, prefix="/runs", tags=["runs"])
app.include_router(settings.router, prefix="/settings", tags=["settings"])
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
fastapi==0.115.0
|
||||
uvicorn[standard]==0.30.0
|
||||
sse-starlette==2.1.0
|
||||
python-dotenv==1.0.1
|
||||
pytest==8.3.0
|
||||
httpx==0.27.0
|
||||
pytest-asyncio==0.23.0
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from fastapi import APIRouter

# Shared APIRouter instance for this module's endpoints; it is mounted onto
# the app with a URL prefix and tag via app.include_router() in api/main.py.
router = APIRouter()
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from fastapi import APIRouter

# Shared APIRouter instance for this module's endpoints; it is mounted onto
# the app with a URL prefix and tag via app.include_router() in api/main.py.
router = APIRouter()
|
||||
74
cli/main.py
74
cli/main.py
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Optional
|
||||
import datetime
|
||||
import uuid
|
||||
import typer
|
||||
from pathlib import Path
|
||||
from functools import wraps
|
||||
|
|
@ -863,6 +864,16 @@ def extract_content_string(content):
|
|||
return str(content).strip() if not is_empty(content) else None
|
||||
|
||||
|
||||
def debate_state_text(value) -> str:
    """Normalize graph debate fields that may be str or list blocks (e.g. Gemini response.content)."""
    # Fast paths: missing field, or already a plain string.
    if value is None:
        return ""
    if isinstance(value, str):
        return value.strip()
    # Structured content (list of blocks) — flatten via the shared extractor,
    # which may itself return None for empty content.
    text = extract_content_string(value)
    return text.strip() if text else ""
|
||||
|
||||
|
||||
def classify_message_type(message) -> tuple[str, str | None]:
|
||||
"""Classify LangChain message into display type and extract content.
|
||||
|
||||
|
|
@ -981,6 +992,38 @@ def run_analysis():
|
|||
message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
|
||||
message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
|
||||
|
||||
base_thread_id = compute_analysis_thread_id(selections, selected_analyst_keys)
|
||||
thread_id = base_thread_id
|
||||
init_agent_state = graph.propagator.create_initial_state(
|
||||
selections["ticker"], selections["analysis_date"]
|
||||
)
|
||||
snap = graph.graph.get_state({"configurable": {"thread_id": thread_id}})
|
||||
stream_input = init_agent_state
|
||||
if snap.next:
|
||||
choice = typer.prompt(
|
||||
"Incomplete checkpoint found for this run. [R]esume or start [N]ew run?",
|
||||
default="R",
|
||||
).strip().upper()
|
||||
if choice.startswith("N"):
|
||||
thread_id = f"{base_thread_id}_{uuid.uuid4().hex[:12]}"
|
||||
stream_input = init_agent_state
|
||||
else:
|
||||
stream_input = None
|
||||
elif snap.values:
|
||||
restart = typer.prompt(
|
||||
"A completed run exists for these settings. Start a new run?",
|
||||
default="Y",
|
||||
).strip().upper()
|
||||
if restart not in ("Y", "YES", ""):
|
||||
console.print("[yellow]Exiting without re-running.[/yellow]")
|
||||
return
|
||||
thread_id = f"{base_thread_id}_{uuid.uuid4().hex[:12]}"
|
||||
stream_input = init_agent_state
|
||||
|
||||
args = graph.propagator.get_graph_args(
|
||||
callbacks=[stats_handler], thread_id=thread_id
|
||||
)
|
||||
|
||||
# Now start the display layout
|
||||
layout = create_layout()
|
||||
|
||||
|
|
@ -997,6 +1040,11 @@ def run_analysis():
|
|||
"System",
|
||||
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
|
||||
)
|
||||
message_buffer.add_message("System", f"Checkpoint thread: {thread_id}")
|
||||
if stream_input is None:
|
||||
message_buffer.add_message(
|
||||
"System", "Resuming from saved checkpoint (same thread_id)."
|
||||
)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
# Update agent status to in_progress for the first analyst
|
||||
|
|
@ -1010,17 +1058,9 @@ def run_analysis():
|
|||
)
|
||||
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
# Initialize state and get graph args with callbacks
|
||||
init_agent_state = graph.propagator.create_initial_state(
|
||||
selections["ticker"], selections["analysis_date"]
|
||||
)
|
||||
# Pass callbacks to graph config for tool execution tracking
|
||||
# (LLM tracking is handled separately via LLM constructor)
|
||||
args = graph.propagator.get_graph_args(callbacks=[stats_handler])
|
||||
|
||||
# Stream the analysis
|
||||
# Stream the analysis (init_agent_state / args / thread_id set above)
|
||||
trace = []
|
||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||
for chunk in graph.graph.stream(stream_input, **args):
|
||||
# Process messages if present (skip duplicates via message ID)
|
||||
if len(chunk["messages"]) > 0:
|
||||
last_message = chunk["messages"][-1]
|
||||
|
|
@ -1050,9 +1090,9 @@ def run_analysis():
|
|||
# Research Team - Handle Investment Debate State
|
||||
if chunk.get("investment_debate_state"):
|
||||
debate_state = chunk["investment_debate_state"]
|
||||
bull_hist = debate_state.get("bull_history", "").strip()
|
||||
bear_hist = debate_state.get("bear_history", "").strip()
|
||||
judge = debate_state.get("judge_decision", "").strip()
|
||||
bull_hist = debate_state_text(debate_state.get("bull_history", ""))
|
||||
bear_hist = debate_state_text(debate_state.get("bear_history", ""))
|
||||
judge = debate_state_text(debate_state.get("judge_decision", ""))
|
||||
|
||||
# Only update status when there's actual content
|
||||
if bull_hist or bear_hist:
|
||||
|
|
@ -1084,10 +1124,10 @@ def run_analysis():
|
|||
# Risk Management Team - Handle Risk Debate State
|
||||
if chunk.get("risk_debate_state"):
|
||||
risk_state = chunk["risk_debate_state"]
|
||||
agg_hist = risk_state.get("aggressive_history", "").strip()
|
||||
con_hist = risk_state.get("conservative_history", "").strip()
|
||||
neu_hist = risk_state.get("neutral_history", "").strip()
|
||||
judge = risk_state.get("judge_decision", "").strip()
|
||||
agg_hist = debate_state_text(risk_state.get("aggressive_history", ""))
|
||||
con_hist = debate_state_text(risk_state.get("conservative_history", ""))
|
||||
neu_hist = debate_state_text(risk_state.get("neutral_history", ""))
|
||||
judge = debate_state_text(risk_state.get("judge_decision", ""))
|
||||
|
||||
if agg_hist:
|
||||
if message_buffer.agent_status.get("Aggressive Analyst") != "completed":
|
||||
|
|
|
|||
27
cli/utils.py
27
cli/utils.py
|
|
@ -1,5 +1,7 @@
|
|||
import hashlib
|
||||
import json
|
||||
import questionary
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
|
|
@ -329,3 +331,26 @@ def ask_gemini_thinking_config() -> str | None:
|
|||
("pointer", "fg:green noinherit"),
|
||||
]),
|
||||
).ask()
|
||||
|
||||
|
||||
def compute_analysis_thread_id(
    selections: Dict[str, Any], selected_analyst_keys: List[str]
) -> str:
    """Stable LangGraph thread_id from ticker, date, analysts, and LLM settings.

    Identical run settings always hash to the same thread_id, so re-running
    the same analysis resumes the same checkpoint thread.
    """
    # Every field that influences the analysis outcome participates in the
    # fingerprint; ticker is normalized so " aapl " and "AAPL" collide.
    fingerprint: Dict[str, Any] = {
        "ticker": selections["ticker"].strip().upper(),
        "analysis_date": selections["analysis_date"],
        "analyst_keys": selected_analyst_keys,
        "max_debate_rounds": selections["research_depth"],
        "max_risk_discuss_rounds": selections["research_depth"],
        "llm_provider": selections["llm_provider"],
        "backend_url": selections.get("backend_url"),
        "quick_think_llm": selections["shallow_thinker"],
        "deep_think_llm": selections["deep_thinker"],
        "google_thinking_level": selections.get("google_thinking_level"),
        "openai_reasoning_effort": selections.get("openai_reasoning_effort"),
    }
    # sort_keys gives a canonical JSON encoding, making the digest stable.
    canonical = json.dumps(fingerprint, sort_keys=True).encode("utf-8")
    return "ta_" + hashlib.sha256(canonical).hexdigest()[:32]
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ dependencies = [
|
|||
"langchain-google-genai>=2.1.5",
|
||||
"langchain-openai>=0.3.23",
|
||||
"langgraph>=0.4.8",
|
||||
"langgraph-checkpoint-sqlite>=2.0.0",
|
||||
"pandas>=2.3.0",
|
||||
"parsel>=1.10.0",
|
||||
"pytz>=2025.2",
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ pandas
|
|||
yfinance
|
||||
stockstats
|
||||
langgraph
|
||||
langgraph-checkpoint-sqlite
|
||||
rank-bm25
|
||||
setuptools
|
||||
backtrader
|
||||
|
|
|
|||
|
|
@ -53,16 +53,23 @@ class Propagator:
|
|||
"news_report": "",
|
||||
}
|
||||
|
||||
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
|
||||
def get_graph_args(
|
||||
self,
|
||||
callbacks: Optional[List] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get arguments for the graph invocation.
|
||||
|
||||
Args:
|
||||
callbacks: Optional list of callback handlers for tool execution tracking.
|
||||
Note: LLM callbacks are handled separately via LLM constructor.
|
||||
thread_id: LangGraph checkpoint thread (required when graph uses a checkpointer).
|
||||
"""
|
||||
config = {"recursion_limit": self.max_recur_limit}
|
||||
config: Dict[str, Any] = {"recursion_limit": self.max_recur_limit}
|
||||
if callbacks:
|
||||
config["callbacks"] = callbacks
|
||||
if thread_id is not None:
|
||||
config["configurable"] = {"thread_id": thread_id}
|
||||
return {
|
||||
"stream_mode": "values",
|
||||
"config": config,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
# TradingAgents/graph/setup.py
|
||||
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.graph import END, StateGraph, START
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
|
|
@ -38,7 +40,9 @@ class GraphSetup:
|
|||
self.conditional_logic = conditional_logic
|
||||
|
||||
def setup_graph(
|
||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
||||
self,
|
||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
checkpointer: Optional[BaseCheckpointSaver] = None,
|
||||
):
|
||||
"""Set up and compile the agent workflow graph.
|
||||
|
||||
|
|
@ -199,4 +203,4 @@ class GraphSetup:
|
|||
workflow.add_edge("Risk Judge", END)
|
||||
|
||||
# Compile and return
|
||||
return workflow.compile()
|
||||
return workflow.compile(checkpointer=checkpointer)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
# TradingAgents/graph/trading_graph.py
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
import json
|
||||
from datetime import date
|
||||
from typing import Dict, Any, Tuple, List, Optional
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
|
||||
from tradingagents.llm_clients import create_llm_client
|
||||
|
||||
|
|
@ -61,6 +64,7 @@ class TradingAgentsGraph:
|
|||
self.debug = debug
|
||||
self.config = config or DEFAULT_CONFIG
|
||||
self.callbacks = callbacks or []
|
||||
self.selected_analysts = list(selected_analysts)
|
||||
|
||||
# Update the interface's config
|
||||
set_config(self.config)
|
||||
|
|
@ -130,8 +134,37 @@ class TradingAgentsGraph:
|
|||
self.ticker = None
|
||||
self.log_states_dict = {} # date to full state dict
|
||||
|
||||
# Set up the graph
|
||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||
self._sqlite_conn, self.checkpointer = self._create_sqlite_checkpointer(self.config)
|
||||
# Set up the graph (durable checkpoints for resume after crash)
|
||||
self.graph = self.graph_setup.setup_graph(
|
||||
selected_analysts, checkpointer=self.checkpointer
|
||||
)
|
||||
|
||||
@staticmethod
def _create_sqlite_checkpointer(
    config: Dict[str, Any],
) -> Tuple[sqlite3.Connection, SqliteSaver]:
    """Open the SQLite checkpoint store under results_dir/.checkpoints/langgraph.sqlite.

    Returns:
        (conn, checkpointer) – the caller owns ``conn`` and must close it when done.
    """
    base_dir = Path(config.get("results_dir", "./results")).expanduser().resolve()
    checkpoint_dir = base_dir / ".checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    # check_same_thread=False lets the connection be used outside the creating
    # thread — NOTE(review): assumes SqliteSaver serializes access; confirm.
    connection = sqlite3.connect(
        str(checkpoint_dir / "langgraph.sqlite"), check_same_thread=False
    )
    return connection, SqliteSaver(connection)
|
||||
|
||||
def close(self) -> None:
    """Best-effort release of the SQLite connection behind the checkpointer."""
    conn = getattr(self, "_sqlite_conn", None)
    if conn is None:
        # Nothing to release (e.g. __init__ failed before the conn was created).
        return
    try:
        conn.close()
    except Exception:
        # Swallow close-time errors (double close, interpreter shutdown):
        # close() must stay safe to call from __del__.
        pass
|
||||
|
||||
def __del__(self) -> None:
    # Safety net: release the checkpoint DB connection on garbage collection
    # if the caller never invoked close() explicitly; close() is exception-safe.
    self.close()
|
||||
|
||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||
"""Get provider-specific kwargs for LLM client creation."""
|
||||
|
|
@ -186,31 +219,63 @@ class TradingAgentsGraph:
|
|||
),
|
||||
}
|
||||
|
||||
def propagate(self, company_name, trade_date):
|
||||
def propagate(self, company_name, trade_date, thread_id: Optional[str] = None):
|
||||
"""Run the trading agents graph for a company on a specific date."""
|
||||
|
||||
self.ticker = company_name
|
||||
|
||||
if thread_id is None:
|
||||
payload = json.dumps(
|
||||
{
|
||||
"ticker": company_name.strip().upper(),
|
||||
"trade_date": str(trade_date),
|
||||
"analysts": sorted(self.selected_analysts),
|
||||
"llm_provider": self.config.get("llm_provider"),
|
||||
"deep_think_llm": self.config.get("deep_think_llm"),
|
||||
"quick_think_llm": self.config.get("quick_think_llm"),
|
||||
"max_debate_rounds": self.config.get("max_debate_rounds"),
|
||||
"max_risk_discuss_rounds": self.config.get("max_risk_discuss_rounds"),
|
||||
},
|
||||
sort_keys=True,
|
||||
).encode()
|
||||
thread_id = "ta_prog_" + hashlib.sha256(payload).hexdigest()[:24]
|
||||
|
||||
# Initialize state
|
||||
init_agent_state = self.propagator.create_initial_state(
|
||||
company_name, trade_date
|
||||
)
|
||||
args = self.propagator.get_graph_args()
|
||||
args = self.propagator.get_graph_args(thread_id=thread_id)
|
||||
|
||||
# Determine stream input: resume from checkpoint if an incomplete run exists,
|
||||
# otherwise start fresh. Passing None tells LangGraph to resume from the last
|
||||
# saved checkpoint for this thread_id.
|
||||
thread_config = {"configurable": {"thread_id": thread_id}}
|
||||
snap = self.graph.get_state(thread_config)
|
||||
if snap.next:
|
||||
# Incomplete run found — resume automatically (no user prompt in API mode)
|
||||
stream_input = None
|
||||
else:
|
||||
stream_input = init_agent_state
|
||||
|
||||
if self.debug:
|
||||
# Debug mode with tracing
|
||||
trace = []
|
||||
for chunk in self.graph.stream(init_agent_state, **args):
|
||||
for chunk in self.graph.stream(stream_input, **args):
|
||||
if len(chunk["messages"]) == 0:
|
||||
pass
|
||||
else:
|
||||
chunk["messages"][-1].pretty_print()
|
||||
trace.append(chunk)
|
||||
|
||||
if not trace:
|
||||
raise RuntimeError(
|
||||
"Graph stream produced no output — all chunks had empty messages. "
|
||||
f"ticker={company_name}, trade_date={trade_date}, thread_id={thread_id}"
|
||||
)
|
||||
final_state = trace[-1]
|
||||
else:
|
||||
# Standard mode without tracing
|
||||
final_state = self.graph.invoke(init_agent_state, **args)
|
||||
final_state = self.graph.invoke(stream_input, **args)
|
||||
|
||||
# Store current state for reflection
|
||||
self.curr_state = final_state
|
||||
|
|
|
|||
Loading…
Reference in New Issue