Replace Chainlit with FastAPI SSE backend

Swap Chainlit chatbot UI for a minimal FastAPI service with:
- POST /analyze to start analysis
- GET /analyze/{id}/stream for SSE progress events
- GET /health for Railway healthcheck

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dtarkent2-sys 2026-02-20 02:43:24 +00:00
parent ac782d179d
commit 52228414ed
3 changed files with 172 additions and 176 deletions

View File

@ -12,5 +12,4 @@ RUN pip install --no-cache-dir -r requirements.txt
COPY . . COPY . .
# Chainlit listens on $PORT (Railway sets this automatically) CMD uvicorn app:app --host 0.0.0.0 --port ${PORT:-8000}
CMD chainlit run app.py --host 0.0.0.0 --port ${PORT:-8000}

341
app.py
View File

@ -1,12 +1,16 @@
"""Chainlit web UI for TradingAgents — mirrors the CLI experience.""" """FastAPI SSE backend for TradingAgents."""
import os import os
import re
import time import time
import datetime import uuid
import asyncio
import json
from datetime import date from datetime import date
import chainlit as cl from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.default_config import DEFAULT_CONFIG
@ -16,22 +20,23 @@ from cli.main import (
classify_message_type, classify_message_type,
update_analyst_statuses, update_analyst_statuses,
update_research_team_status, update_research_team_status,
ANALYST_ORDER,
) )
app = FastAPI(title="TradingAgents API")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
def parse_ticker_date(text: str): # Active analysis queues: id -> asyncio.Queue
"""Extract ticker symbol and optional date from user message.""" analyses: dict[str, asyncio.Queue] = {}
date_match = re.search(r"(\d{4}-\d{2}-\d{2})", text)
trade_date = date_match.group(1) if date_match else str(date.today())
candidates = re.findall(r"\b([A-Z]{1,5})\b", text)
skip = {"I", "A", "THE", "AND", "OR", "FOR", "TO", "IN", "ON", "AT", "IS", class AnalyzeRequest(BaseModel):
"IT", "OF", "BY", "AS", "AN", "BE", "IF", "SO", "DO", "MY", "UP", ticker: str
"NO", "NOT", "ALL", "BUT", "HOW", "GET", "HAS", "HAD", "CAN", date: str | None = None
"WHAT", "ABOUT", "BUY", "SELL", "HOLD"}
tickers = [c for c in candidates if c not in skip]
return tickers[0] if tickers else None, trade_date
def build_config(): def build_config():
@ -52,75 +57,42 @@ def build_config():
return config return config
def format_agent_status_table(buf): def get_stats_dict(stats_handler, buf, start_time):
"""Build a markdown table showing agent status (like the CLI progress panel).""" """Build stats dict for SSE events."""
teams = {
"Analyst Team": ["Market Analyst", "Social Analyst", "News Analyst", "Fundamentals Analyst"],
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
"Trading Team": ["Trader"],
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
"Portfolio Management": ["Portfolio Manager"],
}
icons = {"pending": "\u23f3", "in_progress": "\u26a1", "completed": "\u2705", "error": "\u274c"}
lines = ["| Team | Agent | Status |", "|---|---|---|"]
for team, agents in teams.items():
active = [a for a in agents if a in buf.agent_status]
for i, agent in enumerate(active):
status = buf.agent_status.get(agent, "pending")
icon = icons.get(status, "")
team_col = team if i == 0 else ""
lines.append(f"| {team_col} | {agent} | {icon} {status} |")
return "\n".join(lines)
def format_stats(stats_handler, buf, start_time):
"""Format footer stats like the CLI."""
s = stats_handler.get_stats() s = stats_handler.get_stats()
agents_done = sum(1 for v in buf.agent_status.values() if v == "completed") agents_done = sum(1 for v in buf.agent_status.values() if v == "completed")
agents_total = len(buf.agent_status)
reports_done = buf.get_completed_reports_count()
reports_total = len(buf.report_sections)
elapsed = time.time() - start_time elapsed = time.time() - start_time
elapsed_str = f"{int(elapsed // 60):02d}:{int(elapsed % 60):02d}" return {
"agents_done": agents_done,
return ( "agents_total": len(buf.agent_status),
f"Agents: {agents_done}/{agents_total} | " "llm_calls": s["llm_calls"],
f"LLM: {s['llm_calls']} | Tools: {s['tool_calls']} | " "tool_calls": s["tool_calls"],
f"Tokens: {s['tokens_in']:,}\u2191 {s['tokens_out']:,}\u2193 | " "tokens_in": s["tokens_in"],
f"Reports: {reports_done}/{reports_total} | " "tokens_out": s["tokens_out"],
f"\u23f1 {elapsed_str}" "reports_done": buf.get_completed_reports_count(),
) "reports_total": len(buf.report_sections),
"elapsed": round(elapsed, 1),
}
@cl.on_chat_start def _agent_stage(agent_name):
async def on_chat_start(): """Map agent name to pipeline stage."""
await cl.Message( if agent_name in ("Market Analyst", "Social Analyst", "News Analyst", "Fundamentals Analyst"):
content=( return "analysts"
"**TradingAgents** \u2014 Multi-Agent LLM Trading Analysis\n\n" if agent_name in ("Bull Researcher", "Bear Researcher", "Research Manager"):
"Send a ticker symbol to analyze. Examples:\n" return "research"
"- `NVDA`\n" if agent_name == "Trader":
"- `Analyze AAPL 2024-12-01`\n" return "trading"
"- `What's the outlook for TSLA?`\n\n" if agent_name in ("Aggressive Analyst", "Conservative Analyst", "Neutral Analyst"):
"I'll run a team of AI analysts, researchers, traders, and risk managers " return "risk"
"to produce a trading decision." if agent_name == "Portfolio Manager":
) return "decision"
).send() return "unknown"
@cl.on_message async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
async def on_message(message: cl.Message): """Background task that runs the TradingAgents pipeline and pushes SSE events."""
ticker, trade_date = parse_ticker_date(message.content) q = analyses[analysis_id]
if not ticker:
await cl.Message(
content="I couldn't find a ticker symbol. Try something like `NVDA` or `Analyze AAPL 2024-12-01`."
).send()
return
# --- Build graph ---
config = build_config() config = build_config()
stats_handler = StatsCallbackHandler() stats_handler = StatsCallbackHandler()
selected_analysts = ["market", "social", "news", "fundamentals"] selected_analysts = ["market", "social", "news", "fundamentals"]
@ -133,35 +105,28 @@ async def on_message(message: cl.Message):
callbacks=[stats_handler], callbacks=[stats_handler],
) )
except Exception as e: except Exception as e:
await cl.Message(content=f"Failed to initialize agents: {e}").send() await q.put({"type": "error", "message": f"Init failed: {e}"})
await q.put(None)
return return
# --- Initialize message buffer (same as CLI) ---
buf = MessageBuffer() buf = MessageBuffer()
buf.init_for_analysis(selected_analysts) buf.init_for_analysis(selected_analysts)
# --- Status message (will be updated as agents progress) ---
status_msg = cl.Message(content=f"**Analyzing {ticker} for {trade_date}...**\n\n{format_agent_status_table(buf)}")
await status_msg.send()
# --- Stream the graph ---
init_state = graph.propagator.create_initial_state(ticker, trade_date) init_state = graph.propagator.create_initial_state(ticker, trade_date)
args = graph.propagator.get_graph_args(callbacks=[stats_handler]) args = graph.propagator.get_graph_args(callbacks=[stats_handler])
start_time = time.time() start_time = time.time()
# Steps we'll create as agents complete emitted_reports = set()
analyst_steps = {} # field -> Step research_emitted = False
research_step = None trader_emitted = False
trader_step = None risk_emitted = False
risk_step = None
last_status_update = 0
final_state = None final_state = None
prev_statuses = {}
try: try:
async for chunk in graph.graph.astream(init_state, **args): async for chunk in graph.graph.astream(init_state, **args):
final_state = chunk final_state = chunk
# --- Process messages (same as CLI lines 1024-1044) --- # Process messages (same logic as Chainlit app)
if chunk.get("messages") and len(chunk["messages"]) > 0: if chunk.get("messages") and len(chunk["messages"]) > 0:
last_msg = chunk["messages"][-1] last_msg = chunk["messages"][-1]
msg_id = getattr(last_msg, "id", None) msg_id = getattr(last_msg, "id", None)
@ -177,24 +142,41 @@ async def on_message(message: cl.Message):
else: else:
buf.add_tool_call(tc.name, tc.args) buf.add_tool_call(tc.name, tc.args)
# --- Update analyst statuses (same as CLI line 1047) ---
update_analyst_statuses(buf, chunk) update_analyst_statuses(buf, chunk)
st = get_stats_dict(stats_handler, buf, start_time)
# --- Emit analyst report Steps as they complete --- # Emit agent status changes only (avoid flooding)
report_names = { for agent, status in buf.agent_status.items():
"market_report": "Market Analyst", if prev_statuses.get(agent) != status:
"sentiment_report": "Sentiment Analyst", prev_statuses[agent] = status
"news_report": "News Analyst", await q.put({
"fundamentals_report": "Fundamentals Analyst", "type": "agent_update",
"agent": agent,
"stage": _agent_stage(agent),
"status": status,
"stats": st,
})
# Analyst reports
report_map = {
"market_report": ("Market Analyst", "analysts"),
"sentiment_report": ("Social Analyst", "analysts"),
"news_report": ("News Analyst", "analysts"),
"fundamentals_report": ("Fundamentals Analyst", "analysts"),
} }
for field, name in report_names.items(): for field, (agent_name, stage) in report_map.items():
if field not in analyst_steps and chunk.get(field): if field not in emitted_reports and chunk.get(field):
analyst_steps[field] = True emitted_reports.add(field)
async with cl.Step(name=f"\u2705 {name} Report", type="tool") as step: await q.put({
report = chunk[field] "type": "report",
step.output = report[:4000] if len(report) > 4000 else report "agent": agent_name,
"stage": stage,
"field": field,
"report": chunk[field],
"stats": st,
})
# --- Research debate (same as CLI lines 1050-1072) --- # Research debate
if chunk.get("investment_debate_state"): if chunk.get("investment_debate_state"):
debate = chunk["investment_debate_state"] debate = chunk["investment_debate_state"]
bull = debate.get("bull_history", "").strip() bull = debate.get("bull_history", "").strip()
@ -203,28 +185,33 @@ async def on_message(message: cl.Message):
if bull or bear: if bull or bear:
update_research_team_status("in_progress") update_research_team_status("in_progress")
buf.update_report_section("investment_plan",
(f"### Bull Researcher\n{bull}\n\n### Bear Researcher\n{bear}") if bear else f"### Bull Researcher\n{bull}")
if judge and not research_step: if judge and not research_emitted:
research_step = True research_emitted = True
buf.update_report_section("investment_plan", f"### Research Manager Decision\n{judge}")
update_research_team_status("completed") update_research_team_status("completed")
buf.update_agent_status("Trader", "in_progress") buf.update_agent_status("Trader", "in_progress")
async with cl.Step(name="\u2705 Research Debate", type="tool") as step: await q.put({
step.output = f"**Bull Case:**\n{bull}\n\n---\n\n**Bear Case:**\n{bear}\n\n---\n\n**Research Manager Decision:**\n{judge}" "type": "debate",
"stage": "research",
"bull": bull,
"bear": bear,
"judge": judge,
"stats": get_stats_dict(stats_handler, buf, start_time),
})
# --- Trader plan (same as CLI lines 1075-1081) --- # Trader plan
if chunk.get("trader_investment_plan") and not trader_step: if chunk.get("trader_investment_plan") and not trader_emitted:
trader_step = True trader_emitted = True
buf.update_report_section("trader_investment_plan", chunk["trader_investment_plan"])
buf.update_agent_status("Trader", "completed") buf.update_agent_status("Trader", "completed")
buf.update_agent_status("Aggressive Analyst", "in_progress") buf.update_agent_status("Aggressive Analyst", "in_progress")
async with cl.Step(name="\u2705 Trader Plan", type="tool") as step: await q.put({
plan = chunk["trader_investment_plan"] "type": "trader",
step.output = plan[:4000] if len(plan) > 4000 else plan "stage": "trading",
"plan": chunk["trader_investment_plan"],
"stats": get_stats_dict(stats_handler, buf, start_time),
})
# --- Risk debate (same as CLI lines 1084-1118) --- # Risk debate
if chunk.get("risk_debate_state"): if chunk.get("risk_debate_state"):
risk = chunk["risk_debate_state"] risk = chunk["risk_debate_state"]
agg = risk.get("aggressive_history", "").strip() agg = risk.get("aggressive_history", "").strip()
@ -239,66 +226,74 @@ async def on_message(message: cl.Message):
if neu: if neu:
buf.update_agent_status("Neutral Analyst", "in_progress") buf.update_agent_status("Neutral Analyst", "in_progress")
if judge and not risk_step: if judge and not risk_emitted:
risk_step = True risk_emitted = True
buf.update_agent_status("Aggressive Analyst", "completed") buf.update_agent_status("Aggressive Analyst", "completed")
buf.update_agent_status("Conservative Analyst", "completed") buf.update_agent_status("Conservative Analyst", "completed")
buf.update_agent_status("Neutral Analyst", "completed") buf.update_agent_status("Neutral Analyst", "completed")
buf.update_agent_status("Portfolio Manager", "completed") buf.update_agent_status("Portfolio Manager", "completed")
buf.update_report_section("final_trade_decision", f"### Portfolio Manager Decision\n{judge}") await q.put({
"type": "risk",
async with cl.Step(name="\u2705 Risk Assessment", type="tool") as step: "stage": "risk",
parts = [] "aggressive": agg,
if agg: "conservative": con,
parts.append(f"**Aggressive Analyst:**\n{agg}") "neutral": neu,
if con: "judge": judge,
parts.append(f"**Conservative Analyst:**\n{con}") "stats": get_stats_dict(stats_handler, buf, start_time),
if neu: })
parts.append(f"**Neutral Analyst:**\n{neu}")
parts.append(f"**Portfolio Manager Decision:**\n{judge}")
step.output = "\n\n---\n\n".join(parts)
# --- Update status message periodically ---
now = time.time()
if now - last_status_update > 5:
last_status_update = now
status_msg.content = (
f"**Analyzing {ticker} for {trade_date}...**\n\n"
f"{format_agent_status_table(buf)}\n\n"
f"*{format_stats(stats_handler, buf, start_time)}*"
)
await status_msg.update()
except Exception as e: except Exception as e:
await cl.Message(content=f"Error during analysis: {e}").send() await q.put({"type": "error", "message": str(e)})
await q.put(None)
return return
if not final_state: # Final decision
await cl.Message(content="Analysis produced no results.").send() if final_state:
return decision_text = final_state.get("final_trade_decision", "No decision reached.")
signal = graph.process_signal(decision_text)
for agent in buf.agent_status:
buf.update_agent_status(agent, "completed")
st = get_stats_dict(stats_handler, buf, start_time)
await q.put({
"type": "decision",
"stage": "decision",
"signal": signal,
"decision_text": decision_text,
"stats": st,
})
# --- Final decision --- await q.put(None) # sentinel — stream done
decision_text = final_state.get("final_trade_decision", "No decision reached.")
signal = graph.process_signal(decision_text)
# Mark all agents completed
for agent in buf.agent_status:
buf.update_agent_status(agent, "completed")
# Final status update @app.post("/analyze")
status_msg.content = ( async def start_analysis(req: AnalyzeRequest):
f"**Analysis complete for {ticker} ({trade_date})**\n\n" ticker = req.ticker.upper().strip()
f"{format_agent_status_table(buf)}\n\n" if not ticker or len(ticker) > 5:
f"*{format_stats(stats_handler, buf, start_time)}*" raise HTTPException(400, "Invalid ticker")
) trade_date = req.date or str(date.today())
await status_msg.update() analysis_id = str(uuid.uuid4())
analyses[analysis_id] = asyncio.Queue()
asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
return {"id": analysis_id, "ticker": ticker, "date": trade_date}
# Send the final decision
await cl.Message( @app.get("/analyze/{analysis_id}/stream")
content=( async def stream_analysis(analysis_id: str):
f"## {ticker} \u2014 Trading Decision\n\n" if analysis_id not in analyses:
f"### Signal: {signal}\n\n" raise HTTPException(404, "Analysis not found")
f"---\n\n" q = analyses[analysis_id]
f"{decision_text}"
) async def event_generator():
).send() while True:
event = await q.get()
if event is None:
break
yield {"data": json.dumps(event)}
analyses.pop(analysis_id, None)
return EventSourceResponse(event_generator())
@app.get("/health")
async def health():
return {"status": "ok"}

View File

@ -14,7 +14,9 @@ requests
tqdm tqdm
pytz pytz
redis redis
chainlit fastapi
uvicorn[standard]
sse-starlette
rich rich
typer typer
questionary questionary