Replace Chainlit with FastAPI SSE backend

Swap Chainlit chatbot UI for a minimal FastAPI service with:
- POST /analyze to start analysis
- GET /analyze/{id}/stream for SSE progress events
- GET /health for Railway healthcheck

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dtarkent2-sys 2026-02-20 02:43:24 +00:00
parent ac782d179d
commit 52228414ed
3 changed files with 172 additions and 176 deletions

View File

@ -12,5 +12,4 @@ RUN pip install --no-cache-dir -r requirements.txt
COPY . .
# Chainlit listens on $PORT (Railway sets this automatically)
CMD chainlit run app.py --host 0.0.0.0 --port ${PORT:-8000}
CMD uvicorn app:app --host 0.0.0.0 --port ${PORT:-8000}

341
app.py
View File

@ -1,12 +1,16 @@
"""Chainlit web UI for TradingAgents — mirrors the CLI experience."""
"""FastAPI SSE backend for TradingAgents."""
import os
import re
import time
import datetime
import uuid
import asyncio
import json
from datetime import date
import chainlit as cl
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
@ -16,22 +20,23 @@ from cli.main import (
classify_message_type,
update_analyst_statuses,
update_research_team_status,
ANALYST_ORDER,
)
app = FastAPI(title="TradingAgents API")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
def parse_ticker_date(text: str):
"""Extract ticker symbol and optional date from user message."""
date_match = re.search(r"(\d{4}-\d{2}-\d{2})", text)
trade_date = date_match.group(1) if date_match else str(date.today())
# Active analysis queues: id -> asyncio.Queue
analyses: dict[str, asyncio.Queue] = {}
candidates = re.findall(r"\b([A-Z]{1,5})\b", text)
skip = {"I", "A", "THE", "AND", "OR", "FOR", "TO", "IN", "ON", "AT", "IS",
"IT", "OF", "BY", "AS", "AN", "BE", "IF", "SO", "DO", "MY", "UP",
"NO", "NOT", "ALL", "BUT", "HOW", "GET", "HAS", "HAD", "CAN",
"WHAT", "ABOUT", "BUY", "SELL", "HOLD"}
tickers = [c for c in candidates if c not in skip]
return tickers[0] if tickers else None, trade_date
class AnalyzeRequest(BaseModel):
ticker: str
date: str | None = None
def build_config():
@ -52,75 +57,42 @@ def build_config():
return config
def format_agent_status_table(buf):
"""Build a markdown table showing agent status (like the CLI progress panel)."""
teams = {
"Analyst Team": ["Market Analyst", "Social Analyst", "News Analyst", "Fundamentals Analyst"],
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
"Trading Team": ["Trader"],
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
"Portfolio Management": ["Portfolio Manager"],
}
icons = {"pending": "\u23f3", "in_progress": "\u26a1", "completed": "\u2705", "error": "\u274c"}
lines = ["| Team | Agent | Status |", "|---|---|---|"]
for team, agents in teams.items():
active = [a for a in agents if a in buf.agent_status]
for i, agent in enumerate(active):
status = buf.agent_status.get(agent, "pending")
icon = icons.get(status, "")
team_col = team if i == 0 else ""
lines.append(f"| {team_col} | {agent} | {icon} {status} |")
return "\n".join(lines)
def format_stats(stats_handler, buf, start_time):
"""Format footer stats like the CLI."""
def get_stats_dict(stats_handler, buf, start_time):
"""Build stats dict for SSE events."""
s = stats_handler.get_stats()
agents_done = sum(1 for v in buf.agent_status.values() if v == "completed")
agents_total = len(buf.agent_status)
reports_done = buf.get_completed_reports_count()
reports_total = len(buf.report_sections)
elapsed = time.time() - start_time
elapsed_str = f"{int(elapsed // 60):02d}:{int(elapsed % 60):02d}"
return (
f"Agents: {agents_done}/{agents_total} | "
f"LLM: {s['llm_calls']} | Tools: {s['tool_calls']} | "
f"Tokens: {s['tokens_in']:,}\u2191 {s['tokens_out']:,}\u2193 | "
f"Reports: {reports_done}/{reports_total} | "
f"\u23f1 {elapsed_str}"
)
return {
"agents_done": agents_done,
"agents_total": len(buf.agent_status),
"llm_calls": s["llm_calls"],
"tool_calls": s["tool_calls"],
"tokens_in": s["tokens_in"],
"tokens_out": s["tokens_out"],
"reports_done": buf.get_completed_reports_count(),
"reports_total": len(buf.report_sections),
"elapsed": round(elapsed, 1),
}
@cl.on_chat_start
async def on_chat_start():
await cl.Message(
content=(
"**TradingAgents** \u2014 Multi-Agent LLM Trading Analysis\n\n"
"Send a ticker symbol to analyze. Examples:\n"
"- `NVDA`\n"
"- `Analyze AAPL 2024-12-01`\n"
"- `What's the outlook for TSLA?`\n\n"
"I'll run a team of AI analysts, researchers, traders, and risk managers "
"to produce a trading decision."
)
).send()
def _agent_stage(agent_name):
"""Map agent name to pipeline stage."""
if agent_name in ("Market Analyst", "Social Analyst", "News Analyst", "Fundamentals Analyst"):
return "analysts"
if agent_name in ("Bull Researcher", "Bear Researcher", "Research Manager"):
return "research"
if agent_name == "Trader":
return "trading"
if agent_name in ("Aggressive Analyst", "Conservative Analyst", "Neutral Analyst"):
return "risk"
if agent_name == "Portfolio Manager":
return "decision"
return "unknown"
@cl.on_message
async def on_message(message: cl.Message):
ticker, trade_date = parse_ticker_date(message.content)
if not ticker:
await cl.Message(
content="I couldn't find a ticker symbol. Try something like `NVDA` or `Analyze AAPL 2024-12-01`."
).send()
return
# --- Build graph ---
async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
"""Background task that runs the TradingAgents pipeline and pushes SSE events."""
q = analyses[analysis_id]
config = build_config()
stats_handler = StatsCallbackHandler()
selected_analysts = ["market", "social", "news", "fundamentals"]
@ -133,35 +105,28 @@ async def on_message(message: cl.Message):
callbacks=[stats_handler],
)
except Exception as e:
await cl.Message(content=f"Failed to initialize agents: {e}").send()
await q.put({"type": "error", "message": f"Init failed: {e}"})
await q.put(None)
return
# --- Initialize message buffer (same as CLI) ---
buf = MessageBuffer()
buf.init_for_analysis(selected_analysts)
# --- Status message (will be updated as agents progress) ---
status_msg = cl.Message(content=f"**Analyzing {ticker} for {trade_date}...**\n\n{format_agent_status_table(buf)}")
await status_msg.send()
# --- Stream the graph ---
init_state = graph.propagator.create_initial_state(ticker, trade_date)
args = graph.propagator.get_graph_args(callbacks=[stats_handler])
start_time = time.time()
# Steps we'll create as agents complete
analyst_steps = {} # field -> Step
research_step = None
trader_step = None
risk_step = None
last_status_update = 0
emitted_reports = set()
research_emitted = False
trader_emitted = False
risk_emitted = False
final_state = None
prev_statuses = {}
try:
async for chunk in graph.graph.astream(init_state, **args):
final_state = chunk
# --- Process messages (same as CLI lines 1024-1044) ---
# Process messages (same logic as Chainlit app)
if chunk.get("messages") and len(chunk["messages"]) > 0:
last_msg = chunk["messages"][-1]
msg_id = getattr(last_msg, "id", None)
@ -177,24 +142,41 @@ async def on_message(message: cl.Message):
else:
buf.add_tool_call(tc.name, tc.args)
# --- Update analyst statuses (same as CLI line 1047) ---
update_analyst_statuses(buf, chunk)
st = get_stats_dict(stats_handler, buf, start_time)
# --- Emit analyst report Steps as they complete ---
report_names = {
"market_report": "Market Analyst",
"sentiment_report": "Sentiment Analyst",
"news_report": "News Analyst",
"fundamentals_report": "Fundamentals Analyst",
# Emit agent status changes only (avoid flooding)
for agent, status in buf.agent_status.items():
if prev_statuses.get(agent) != status:
prev_statuses[agent] = status
await q.put({
"type": "agent_update",
"agent": agent,
"stage": _agent_stage(agent),
"status": status,
"stats": st,
})
# Analyst reports
report_map = {
"market_report": ("Market Analyst", "analysts"),
"sentiment_report": ("Social Analyst", "analysts"),
"news_report": ("News Analyst", "analysts"),
"fundamentals_report": ("Fundamentals Analyst", "analysts"),
}
for field, name in report_names.items():
if field not in analyst_steps and chunk.get(field):
analyst_steps[field] = True
async with cl.Step(name=f"\u2705 {name} Report", type="tool") as step:
report = chunk[field]
step.output = report[:4000] if len(report) > 4000 else report
for field, (agent_name, stage) in report_map.items():
if field not in emitted_reports and chunk.get(field):
emitted_reports.add(field)
await q.put({
"type": "report",
"agent": agent_name,
"stage": stage,
"field": field,
"report": chunk[field],
"stats": st,
})
# --- Research debate (same as CLI lines 1050-1072) ---
# Research debate
if chunk.get("investment_debate_state"):
debate = chunk["investment_debate_state"]
bull = debate.get("bull_history", "").strip()
@ -203,28 +185,33 @@ async def on_message(message: cl.Message):
if bull or bear:
update_research_team_status("in_progress")
buf.update_report_section("investment_plan",
(f"### Bull Researcher\n{bull}\n\n### Bear Researcher\n{bear}") if bear else f"### Bull Researcher\n{bull}")
if judge and not research_step:
research_step = True
buf.update_report_section("investment_plan", f"### Research Manager Decision\n{judge}")
if judge and not research_emitted:
research_emitted = True
update_research_team_status("completed")
buf.update_agent_status("Trader", "in_progress")
async with cl.Step(name="\u2705 Research Debate", type="tool") as step:
step.output = f"**Bull Case:**\n{bull}\n\n---\n\n**Bear Case:**\n{bear}\n\n---\n\n**Research Manager Decision:**\n{judge}"
await q.put({
"type": "debate",
"stage": "research",
"bull": bull,
"bear": bear,
"judge": judge,
"stats": get_stats_dict(stats_handler, buf, start_time),
})
# --- Trader plan (same as CLI lines 1075-1081) ---
if chunk.get("trader_investment_plan") and not trader_step:
trader_step = True
buf.update_report_section("trader_investment_plan", chunk["trader_investment_plan"])
# Trader plan
if chunk.get("trader_investment_plan") and not trader_emitted:
trader_emitted = True
buf.update_agent_status("Trader", "completed")
buf.update_agent_status("Aggressive Analyst", "in_progress")
async with cl.Step(name="\u2705 Trader Plan", type="tool") as step:
plan = chunk["trader_investment_plan"]
step.output = plan[:4000] if len(plan) > 4000 else plan
await q.put({
"type": "trader",
"stage": "trading",
"plan": chunk["trader_investment_plan"],
"stats": get_stats_dict(stats_handler, buf, start_time),
})
# --- Risk debate (same as CLI lines 1084-1118) ---
# Risk debate
if chunk.get("risk_debate_state"):
risk = chunk["risk_debate_state"]
agg = risk.get("aggressive_history", "").strip()
@ -239,66 +226,74 @@ async def on_message(message: cl.Message):
if neu:
buf.update_agent_status("Neutral Analyst", "in_progress")
if judge and not risk_step:
risk_step = True
if judge and not risk_emitted:
risk_emitted = True
buf.update_agent_status("Aggressive Analyst", "completed")
buf.update_agent_status("Conservative Analyst", "completed")
buf.update_agent_status("Neutral Analyst", "completed")
buf.update_agent_status("Portfolio Manager", "completed")
buf.update_report_section("final_trade_decision", f"### Portfolio Manager Decision\n{judge}")
async with cl.Step(name="\u2705 Risk Assessment", type="tool") as step:
parts = []
if agg:
parts.append(f"**Aggressive Analyst:**\n{agg}")
if con:
parts.append(f"**Conservative Analyst:**\n{con}")
if neu:
parts.append(f"**Neutral Analyst:**\n{neu}")
parts.append(f"**Portfolio Manager Decision:**\n{judge}")
step.output = "\n\n---\n\n".join(parts)
# --- Update status message periodically ---
now = time.time()
if now - last_status_update > 5:
last_status_update = now
status_msg.content = (
f"**Analyzing {ticker} for {trade_date}...**\n\n"
f"{format_agent_status_table(buf)}\n\n"
f"*{format_stats(stats_handler, buf, start_time)}*"
)
await status_msg.update()
await q.put({
"type": "risk",
"stage": "risk",
"aggressive": agg,
"conservative": con,
"neutral": neu,
"judge": judge,
"stats": get_stats_dict(stats_handler, buf, start_time),
})
except Exception as e:
await cl.Message(content=f"Error during analysis: {e}").send()
await q.put({"type": "error", "message": str(e)})
await q.put(None)
return
if not final_state:
await cl.Message(content="Analysis produced no results.").send()
return
# Final decision
if final_state:
decision_text = final_state.get("final_trade_decision", "No decision reached.")
signal = graph.process_signal(decision_text)
for agent in buf.agent_status:
buf.update_agent_status(agent, "completed")
st = get_stats_dict(stats_handler, buf, start_time)
await q.put({
"type": "decision",
"stage": "decision",
"signal": signal,
"decision_text": decision_text,
"stats": st,
})
# --- Final decision ---
decision_text = final_state.get("final_trade_decision", "No decision reached.")
signal = graph.process_signal(decision_text)
await q.put(None) # sentinel — stream done
# Mark all agents completed
for agent in buf.agent_status:
buf.update_agent_status(agent, "completed")
# Final status update
status_msg.content = (
f"**Analysis complete for {ticker} ({trade_date})**\n\n"
f"{format_agent_status_table(buf)}\n\n"
f"*{format_stats(stats_handler, buf, start_time)}*"
)
await status_msg.update()
@app.post("/analyze")
async def start_analysis(req: AnalyzeRequest):
ticker = req.ticker.upper().strip()
if not ticker or len(ticker) > 5:
raise HTTPException(400, "Invalid ticker")
trade_date = req.date or str(date.today())
analysis_id = str(uuid.uuid4())
analyses[analysis_id] = asyncio.Queue()
asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
return {"id": analysis_id, "ticker": ticker, "date": trade_date}
# Send the final decision
await cl.Message(
content=(
f"## {ticker} \u2014 Trading Decision\n\n"
f"### Signal: {signal}\n\n"
f"---\n\n"
f"{decision_text}"
)
).send()
@app.get("/analyze/{analysis_id}/stream")
async def stream_analysis(analysis_id: str):
if analysis_id not in analyses:
raise HTTPException(404, "Analysis not found")
q = analyses[analysis_id]
async def event_generator():
while True:
event = await q.get()
if event is None:
break
yield {"data": json.dumps(event)}
analyses.pop(analysis_id, None)
return EventSourceResponse(event_generator())
@app.get("/health")
async def health():
return {"status": "ok"}

View File

@ -14,7 +14,9 @@ requests
tqdm
pytz
redis
chainlit
fastapi
uvicorn[standard]
sse-starlette
rich
typer
questionary