feat: add FastAPI-based API for trading agents
- Introduced a new API structure with FastAPI, including routers for runs and settings. - Implemented endpoints for creating, listing, and retrieving run configurations. - Added settings management with load and update functionality. - Integrated SQLite checkpointing for durable state management during analysis. - Updated dependencies in `pyproject.toml` and `requirements.txt` to include FastAPI and related packages. - Enhanced `.gitignore` to exclude SQLite checkpoints and results directories.
This commit is contained in:
parent
0b13145dc0
commit
49283f47d5
|
|
@ -1,3 +1,6 @@
|
||||||
|
# LangGraph SQLite checkpoints (under results_dir/.checkpoints/)
|
||||||
|
**/.checkpoints/
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[codz]
|
*.py[codz]
|
||||||
|
|
@ -217,3 +220,6 @@ __marimo__/
|
||||||
|
|
||||||
# Cache
|
# Cache
|
||||||
**/data_cache/
|
**/data_cache/
|
||||||
|
|
||||||
|
# Results
|
||||||
|
**/results/
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from api.routers import runs, settings
|
||||||
|
|
||||||
|
app = FastAPI(title="TradingAgents API", version="1.0.0")
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["http://localhost:3000"],
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(runs.router, prefix="/runs", tags=["runs"])
|
||||||
|
app.include_router(settings.router, prefix="/settings", tags=["settings"])
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
fastapi==0.115.0
|
||||||
|
uvicorn[standard]==0.30.0
|
||||||
|
sse-starlette==2.1.0
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
pytest==8.3.0
|
||||||
|
httpx==0.27.0
|
||||||
|
pytest-asyncio==0.23.0
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
74
cli/main.py
74
cli/main.py
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import datetime
|
import datetime
|
||||||
|
import uuid
|
||||||
import typer
|
import typer
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
@ -863,6 +864,16 @@ def extract_content_string(content):
|
||||||
return str(content).strip() if not is_empty(content) else None
|
return str(content).strip() if not is_empty(content) else None
|
||||||
|
|
||||||
|
|
||||||
|
def debate_state_text(value) -> str:
|
||||||
|
"""Normalize graph debate fields that may be str or list blocks (e.g. Gemini response.content)."""
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.strip()
|
||||||
|
extracted = extract_content_string(value)
|
||||||
|
return (extracted or "").strip()
|
||||||
|
|
||||||
|
|
||||||
def classify_message_type(message) -> tuple[str, str | None]:
|
def classify_message_type(message) -> tuple[str, str | None]:
|
||||||
"""Classify LangChain message into display type and extract content.
|
"""Classify LangChain message into display type and extract content.
|
||||||
|
|
||||||
|
|
@ -981,6 +992,38 @@ def run_analysis():
|
||||||
message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
|
message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
|
||||||
message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
|
message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
|
||||||
|
|
||||||
|
base_thread_id = compute_analysis_thread_id(selections, selected_analyst_keys)
|
||||||
|
thread_id = base_thread_id
|
||||||
|
init_agent_state = graph.propagator.create_initial_state(
|
||||||
|
selections["ticker"], selections["analysis_date"]
|
||||||
|
)
|
||||||
|
snap = graph.graph.get_state({"configurable": {"thread_id": thread_id}})
|
||||||
|
stream_input = init_agent_state
|
||||||
|
if snap.next:
|
||||||
|
choice = typer.prompt(
|
||||||
|
"Incomplete checkpoint found for this run. [R]esume or start [N]ew run?",
|
||||||
|
default="R",
|
||||||
|
).strip().upper()
|
||||||
|
if choice.startswith("N"):
|
||||||
|
thread_id = f"{base_thread_id}_{uuid.uuid4().hex[:12]}"
|
||||||
|
stream_input = init_agent_state
|
||||||
|
else:
|
||||||
|
stream_input = None
|
||||||
|
elif snap.values:
|
||||||
|
restart = typer.prompt(
|
||||||
|
"A completed run exists for these settings. Start a new run?",
|
||||||
|
default="Y",
|
||||||
|
).strip().upper()
|
||||||
|
if restart not in ("Y", "YES", ""):
|
||||||
|
console.print("[yellow]Exiting without re-running.[/yellow]")
|
||||||
|
return
|
||||||
|
thread_id = f"{base_thread_id}_{uuid.uuid4().hex[:12]}"
|
||||||
|
stream_input = init_agent_state
|
||||||
|
|
||||||
|
args = graph.propagator.get_graph_args(
|
||||||
|
callbacks=[stats_handler], thread_id=thread_id
|
||||||
|
)
|
||||||
|
|
||||||
# Now start the display layout
|
# Now start the display layout
|
||||||
layout = create_layout()
|
layout = create_layout()
|
||||||
|
|
||||||
|
|
@ -997,6 +1040,11 @@ def run_analysis():
|
||||||
"System",
|
"System",
|
||||||
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
|
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
|
||||||
)
|
)
|
||||||
|
message_buffer.add_message("System", f"Checkpoint thread: {thread_id}")
|
||||||
|
if stream_input is None:
|
||||||
|
message_buffer.add_message(
|
||||||
|
"System", "Resuming from saved checkpoint (same thread_id)."
|
||||||
|
)
|
||||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
# Update agent status to in_progress for the first analyst
|
# Update agent status to in_progress for the first analyst
|
||||||
|
|
@ -1010,17 +1058,9 @@ def run_analysis():
|
||||||
)
|
)
|
||||||
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
|
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
|
||||||
|
|
||||||
# Initialize state and get graph args with callbacks
|
# Stream the analysis (init_agent_state / args / thread_id set above)
|
||||||
init_agent_state = graph.propagator.create_initial_state(
|
|
||||||
selections["ticker"], selections["analysis_date"]
|
|
||||||
)
|
|
||||||
# Pass callbacks to graph config for tool execution tracking
|
|
||||||
# (LLM tracking is handled separately via LLM constructor)
|
|
||||||
args = graph.propagator.get_graph_args(callbacks=[stats_handler])
|
|
||||||
|
|
||||||
# Stream the analysis
|
|
||||||
trace = []
|
trace = []
|
||||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
for chunk in graph.graph.stream(stream_input, **args):
|
||||||
# Process messages if present (skip duplicates via message ID)
|
# Process messages if present (skip duplicates via message ID)
|
||||||
if len(chunk["messages"]) > 0:
|
if len(chunk["messages"]) > 0:
|
||||||
last_message = chunk["messages"][-1]
|
last_message = chunk["messages"][-1]
|
||||||
|
|
@ -1050,9 +1090,9 @@ def run_analysis():
|
||||||
# Research Team - Handle Investment Debate State
|
# Research Team - Handle Investment Debate State
|
||||||
if chunk.get("investment_debate_state"):
|
if chunk.get("investment_debate_state"):
|
||||||
debate_state = chunk["investment_debate_state"]
|
debate_state = chunk["investment_debate_state"]
|
||||||
bull_hist = debate_state.get("bull_history", "").strip()
|
bull_hist = debate_state_text(debate_state.get("bull_history", ""))
|
||||||
bear_hist = debate_state.get("bear_history", "").strip()
|
bear_hist = debate_state_text(debate_state.get("bear_history", ""))
|
||||||
judge = debate_state.get("judge_decision", "").strip()
|
judge = debate_state_text(debate_state.get("judge_decision", ""))
|
||||||
|
|
||||||
# Only update status when there's actual content
|
# Only update status when there's actual content
|
||||||
if bull_hist or bear_hist:
|
if bull_hist or bear_hist:
|
||||||
|
|
@ -1084,10 +1124,10 @@ def run_analysis():
|
||||||
# Risk Management Team - Handle Risk Debate State
|
# Risk Management Team - Handle Risk Debate State
|
||||||
if chunk.get("risk_debate_state"):
|
if chunk.get("risk_debate_state"):
|
||||||
risk_state = chunk["risk_debate_state"]
|
risk_state = chunk["risk_debate_state"]
|
||||||
agg_hist = risk_state.get("aggressive_history", "").strip()
|
agg_hist = debate_state_text(risk_state.get("aggressive_history", ""))
|
||||||
con_hist = risk_state.get("conservative_history", "").strip()
|
con_hist = debate_state_text(risk_state.get("conservative_history", ""))
|
||||||
neu_hist = risk_state.get("neutral_history", "").strip()
|
neu_hist = debate_state_text(risk_state.get("neutral_history", ""))
|
||||||
judge = risk_state.get("judge_decision", "").strip()
|
judge = debate_state_text(risk_state.get("judge_decision", ""))
|
||||||
|
|
||||||
if agg_hist:
|
if agg_hist:
|
||||||
if message_buffer.agent_status.get("Aggressive Analyst") != "completed":
|
if message_buffer.agent_status.get("Aggressive Analyst") != "completed":
|
||||||
|
|
|
||||||
27
cli/utils.py
27
cli/utils.py
|
|
@ -1,5 +1,7 @@
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
import questionary
|
import questionary
|
||||||
from typing import List, Optional, Tuple, Dict
|
from typing import List, Optional, Tuple, Dict, Any
|
||||||
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
|
|
@ -329,3 +331,26 @@ def ask_gemini_thinking_config() -> str | None:
|
||||||
("pointer", "fg:green noinherit"),
|
("pointer", "fg:green noinherit"),
|
||||||
]),
|
]),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def compute_analysis_thread_id(
|
||||||
|
selections: Dict[str, Any], selected_analyst_keys: List[str]
|
||||||
|
) -> str:
|
||||||
|
"""Stable LangGraph thread_id from ticker, date, analysts, and LLM settings."""
|
||||||
|
payload = {
|
||||||
|
"ticker": selections["ticker"].strip().upper(),
|
||||||
|
"analysis_date": selections["analysis_date"],
|
||||||
|
"analyst_keys": selected_analyst_keys,
|
||||||
|
"max_debate_rounds": selections["research_depth"],
|
||||||
|
"max_risk_discuss_rounds": selections["research_depth"],
|
||||||
|
"llm_provider": selections["llm_provider"],
|
||||||
|
"backend_url": selections.get("backend_url"),
|
||||||
|
"quick_think_llm": selections["shallow_thinker"],
|
||||||
|
"deep_think_llm": selections["deep_thinker"],
|
||||||
|
"google_thinking_level": selections.get("google_thinking_level"),
|
||||||
|
"openai_reasoning_effort": selections.get("openai_reasoning_effort"),
|
||||||
|
}
|
||||||
|
h = hashlib.sha256(
|
||||||
|
json.dumps(payload, sort_keys=True).encode("utf-8")
|
||||||
|
).hexdigest()[:32]
|
||||||
|
return f"ta_{h}"
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ dependencies = [
|
||||||
"langchain-google-genai>=2.1.5",
|
"langchain-google-genai>=2.1.5",
|
||||||
"langchain-openai>=0.3.23",
|
"langchain-openai>=0.3.23",
|
||||||
"langgraph>=0.4.8",
|
"langgraph>=0.4.8",
|
||||||
|
"langgraph-checkpoint-sqlite>=2.0.0",
|
||||||
"pandas>=2.3.0",
|
"pandas>=2.3.0",
|
||||||
"parsel>=1.10.0",
|
"parsel>=1.10.0",
|
||||||
"pytz>=2025.2",
|
"pytz>=2025.2",
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ pandas
|
||||||
yfinance
|
yfinance
|
||||||
stockstats
|
stockstats
|
||||||
langgraph
|
langgraph
|
||||||
|
langgraph-checkpoint-sqlite
|
||||||
rank-bm25
|
rank-bm25
|
||||||
setuptools
|
setuptools
|
||||||
backtrader
|
backtrader
|
||||||
|
|
|
||||||
|
|
@ -53,16 +53,23 @@ class Propagator:
|
||||||
"news_report": "",
|
"news_report": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
|
def get_graph_args(
|
||||||
|
self,
|
||||||
|
callbacks: Optional[List] = None,
|
||||||
|
thread_id: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Get arguments for the graph invocation.
|
"""Get arguments for the graph invocation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callbacks: Optional list of callback handlers for tool execution tracking.
|
callbacks: Optional list of callback handlers for tool execution tracking.
|
||||||
Note: LLM callbacks are handled separately via LLM constructor.
|
Note: LLM callbacks are handled separately via LLM constructor.
|
||||||
|
thread_id: LangGraph checkpoint thread (required when graph uses a checkpointer).
|
||||||
"""
|
"""
|
||||||
config = {"recursion_limit": self.max_recur_limit}
|
config: Dict[str, Any] = {"recursion_limit": self.max_recur_limit}
|
||||||
if callbacks:
|
if callbacks:
|
||||||
config["callbacks"] = callbacks
|
config["callbacks"] = callbacks
|
||||||
|
if thread_id is not None:
|
||||||
|
config["configurable"] = {"thread_id": thread_id}
|
||||||
return {
|
return {
|
||||||
"stream_mode": "values",
|
"stream_mode": "values",
|
||||||
"config": config,
|
"config": config,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
# TradingAgents/graph/setup.py
|
# TradingAgents/graph/setup.py
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||||
from langgraph.graph import END, StateGraph, START
|
from langgraph.graph import END, StateGraph, START
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
|
|
@ -38,7 +40,9 @@ class GraphSetup:
|
||||||
self.conditional_logic = conditional_logic
|
self.conditional_logic = conditional_logic
|
||||||
|
|
||||||
def setup_graph(
|
def setup_graph(
|
||||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
self,
|
||||||
|
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||||
|
checkpointer: Optional[BaseCheckpointSaver] = None,
|
||||||
):
|
):
|
||||||
"""Set up and compile the agent workflow graph.
|
"""Set up and compile the agent workflow graph.
|
||||||
|
|
||||||
|
|
@ -199,4 +203,4 @@ class GraphSetup:
|
||||||
workflow.add_edge("Risk Judge", END)
|
workflow.add_edge("Risk Judge", END)
|
||||||
|
|
||||||
# Compile and return
|
# Compile and return
|
||||||
return workflow.compile()
|
return workflow.compile(checkpointer=checkpointer)
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
# TradingAgents/graph/trading_graph.py
|
# TradingAgents/graph/trading_graph.py
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import hashlib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from typing import Dict, Any, Tuple, List, Optional
|
from typing import Dict, Any, Tuple, List, Optional
|
||||||
|
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||||
|
|
||||||
from tradingagents.llm_clients import create_llm_client
|
from tradingagents.llm_clients import create_llm_client
|
||||||
|
|
||||||
|
|
@ -61,6 +64,7 @@ class TradingAgentsGraph:
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.config = config or DEFAULT_CONFIG
|
self.config = config or DEFAULT_CONFIG
|
||||||
self.callbacks = callbacks or []
|
self.callbacks = callbacks or []
|
||||||
|
self.selected_analysts = list(selected_analysts)
|
||||||
|
|
||||||
# Update the interface's config
|
# Update the interface's config
|
||||||
set_config(self.config)
|
set_config(self.config)
|
||||||
|
|
@ -130,8 +134,37 @@ class TradingAgentsGraph:
|
||||||
self.ticker = None
|
self.ticker = None
|
||||||
self.log_states_dict = {} # date to full state dict
|
self.log_states_dict = {} # date to full state dict
|
||||||
|
|
||||||
# Set up the graph
|
self._sqlite_conn, self.checkpointer = self._create_sqlite_checkpointer(self.config)
|
||||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
# Set up the graph (durable checkpoints for resume after crash)
|
||||||
|
self.graph = self.graph_setup.setup_graph(
|
||||||
|
selected_analysts, checkpointer=self.checkpointer
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_sqlite_checkpointer(
|
||||||
|
config: Dict[str, Any],
|
||||||
|
) -> Tuple[sqlite3.Connection, SqliteSaver]:
|
||||||
|
"""SQLite checkpoint store under results_dir/.checkpoints/langgraph.sqlite.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(conn, checkpointer) – caller must close conn when done.
|
||||||
|
"""
|
||||||
|
results_dir = Path(config.get("results_dir", "./results")).expanduser().resolve()
|
||||||
|
checkpoint_dir = results_dir / ".checkpoints"
|
||||||
|
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
db_path = checkpoint_dir / "langgraph.sqlite"
|
||||||
|
conn = sqlite3.connect(str(db_path), check_same_thread=False)
|
||||||
|
return conn, SqliteSaver(conn)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the underlying SQLite connection held by the checkpointer."""
|
||||||
|
try:
|
||||||
|
self._sqlite_conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
self.close()
|
||||||
|
|
||||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||||
"""Get provider-specific kwargs for LLM client creation."""
|
"""Get provider-specific kwargs for LLM client creation."""
|
||||||
|
|
@ -186,31 +219,63 @@ class TradingAgentsGraph:
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
def propagate(self, company_name, trade_date):
|
def propagate(self, company_name, trade_date, thread_id: Optional[str] = None):
|
||||||
"""Run the trading agents graph for a company on a specific date."""
|
"""Run the trading agents graph for a company on a specific date."""
|
||||||
|
|
||||||
self.ticker = company_name
|
self.ticker = company_name
|
||||||
|
|
||||||
|
if thread_id is None:
|
||||||
|
payload = json.dumps(
|
||||||
|
{
|
||||||
|
"ticker": company_name.strip().upper(),
|
||||||
|
"trade_date": str(trade_date),
|
||||||
|
"analysts": sorted(self.selected_analysts),
|
||||||
|
"llm_provider": self.config.get("llm_provider"),
|
||||||
|
"deep_think_llm": self.config.get("deep_think_llm"),
|
||||||
|
"quick_think_llm": self.config.get("quick_think_llm"),
|
||||||
|
"max_debate_rounds": self.config.get("max_debate_rounds"),
|
||||||
|
"max_risk_discuss_rounds": self.config.get("max_risk_discuss_rounds"),
|
||||||
|
},
|
||||||
|
sort_keys=True,
|
||||||
|
).encode()
|
||||||
|
thread_id = "ta_prog_" + hashlib.sha256(payload).hexdigest()[:24]
|
||||||
|
|
||||||
# Initialize state
|
# Initialize state
|
||||||
init_agent_state = self.propagator.create_initial_state(
|
init_agent_state = self.propagator.create_initial_state(
|
||||||
company_name, trade_date
|
company_name, trade_date
|
||||||
)
|
)
|
||||||
args = self.propagator.get_graph_args()
|
args = self.propagator.get_graph_args(thread_id=thread_id)
|
||||||
|
|
||||||
|
# Determine stream input: resume from checkpoint if an incomplete run exists,
|
||||||
|
# otherwise start fresh. Passing None tells LangGraph to resume from the last
|
||||||
|
# saved checkpoint for this thread_id.
|
||||||
|
thread_config = {"configurable": {"thread_id": thread_id}}
|
||||||
|
snap = self.graph.get_state(thread_config)
|
||||||
|
if snap.next:
|
||||||
|
# Incomplete run found — resume automatically (no user prompt in API mode)
|
||||||
|
stream_input = None
|
||||||
|
else:
|
||||||
|
stream_input = init_agent_state
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
# Debug mode with tracing
|
# Debug mode with tracing
|
||||||
trace = []
|
trace = []
|
||||||
for chunk in self.graph.stream(init_agent_state, **args):
|
for chunk in self.graph.stream(stream_input, **args):
|
||||||
if len(chunk["messages"]) == 0:
|
if len(chunk["messages"]) == 0:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
chunk["messages"][-1].pretty_print()
|
chunk["messages"][-1].pretty_print()
|
||||||
trace.append(chunk)
|
trace.append(chunk)
|
||||||
|
|
||||||
|
if not trace:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Graph stream produced no output — all chunks had empty messages. "
|
||||||
|
f"ticker={company_name}, trade_date={trade_date}, thread_id={thread_id}"
|
||||||
|
)
|
||||||
final_state = trace[-1]
|
final_state = trace[-1]
|
||||||
else:
|
else:
|
||||||
# Standard mode without tracing
|
# Standard mode without tracing
|
||||||
final_state = self.graph.invoke(init_agent_state, **args)
|
final_state = self.graph.invoke(stream_input, **args)
|
||||||
|
|
||||||
# Store current state for reflection
|
# Store current state for reflection
|
||||||
self.curr_state = final_state
|
self.curr_state = final_state
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue