Replace Chainlit with FastAPI SSE backend
Swap Chainlit chatbot UI for a minimal FastAPI service with:
- POST /analyze to start analysis
- GET /analyze/{id}/stream for SSE progress events
- GET /health for Railway healthcheck
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ac782d179d
commit
52228414ed
|
|
@ -12,5 +12,4 @@ RUN pip install --no-cache-dir -r requirements.txt
|
|||
|
||||
COPY . .
|
||||
|
||||
# Chainlit listens on $PORT (Railway sets this automatically)
|
||||
CMD chainlit run app.py --host 0.0.0.0 --port ${PORT:-8000}
|
||||
CMD uvicorn app:app --host 0.0.0.0 --port ${PORT:-8000}
|
||||
|
|
|
|||
341
app.py
341
app.py
|
|
@ -1,12 +1,16 @@
|
|||
"""Chainlit web UI for TradingAgents — mirrors the CLI experience."""
|
||||
"""FastAPI SSE backend for TradingAgents."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import datetime
|
||||
import uuid
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import date
|
||||
|
||||
import chainlit as cl
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
|
@ -16,22 +20,23 @@ from cli.main import (
|
|||
classify_message_type,
|
||||
update_analyst_statuses,
|
||||
update_research_team_status,
|
||||
ANALYST_ORDER,
|
||||
)
|
||||
|
||||
app = FastAPI(title="TradingAgents API")
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
def parse_ticker_date(text: str):
|
||||
"""Extract ticker symbol and optional date from user message."""
|
||||
date_match = re.search(r"(\d{4}-\d{2}-\d{2})", text)
|
||||
trade_date = date_match.group(1) if date_match else str(date.today())
|
||||
# Active analysis queues: id -> asyncio.Queue
|
||||
analyses: dict[str, asyncio.Queue] = {}
|
||||
|
||||
candidates = re.findall(r"\b([A-Z]{1,5})\b", text)
|
||||
skip = {"I", "A", "THE", "AND", "OR", "FOR", "TO", "IN", "ON", "AT", "IS",
|
||||
"IT", "OF", "BY", "AS", "AN", "BE", "IF", "SO", "DO", "MY", "UP",
|
||||
"NO", "NOT", "ALL", "BUT", "HOW", "GET", "HAS", "HAD", "CAN",
|
||||
"WHAT", "ABOUT", "BUY", "SELL", "HOLD"}
|
||||
tickers = [c for c in candidates if c not in skip]
|
||||
return tickers[0] if tickers else None, trade_date
|
||||
|
||||
class AnalyzeRequest(BaseModel):
|
||||
ticker: str
|
||||
date: str | None = None
|
||||
|
||||
|
||||
def build_config():
|
||||
|
|
@ -52,75 +57,42 @@ def build_config():
|
|||
return config
|
||||
|
||||
|
||||
def format_agent_status_table(buf):
|
||||
"""Build a markdown table showing agent status (like the CLI progress panel)."""
|
||||
teams = {
|
||||
"Analyst Team": ["Market Analyst", "Social Analyst", "News Analyst", "Fundamentals Analyst"],
|
||||
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
|
||||
"Trading Team": ["Trader"],
|
||||
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
|
||||
"Portfolio Management": ["Portfolio Manager"],
|
||||
}
|
||||
|
||||
icons = {"pending": "\u23f3", "in_progress": "\u26a1", "completed": "\u2705", "error": "\u274c"}
|
||||
lines = ["| Team | Agent | Status |", "|---|---|---|"]
|
||||
|
||||
for team, agents in teams.items():
|
||||
active = [a for a in agents if a in buf.agent_status]
|
||||
for i, agent in enumerate(active):
|
||||
status = buf.agent_status.get(agent, "pending")
|
||||
icon = icons.get(status, "")
|
||||
team_col = team if i == 0 else ""
|
||||
lines.append(f"| {team_col} | {agent} | {icon} {status} |")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_stats(stats_handler, buf, start_time):
|
||||
"""Format footer stats like the CLI."""
|
||||
def get_stats_dict(stats_handler, buf, start_time):
|
||||
"""Build stats dict for SSE events."""
|
||||
s = stats_handler.get_stats()
|
||||
agents_done = sum(1 for v in buf.agent_status.values() if v == "completed")
|
||||
agents_total = len(buf.agent_status)
|
||||
reports_done = buf.get_completed_reports_count()
|
||||
reports_total = len(buf.report_sections)
|
||||
elapsed = time.time() - start_time
|
||||
elapsed_str = f"{int(elapsed // 60):02d}:{int(elapsed % 60):02d}"
|
||||
|
||||
return (
|
||||
f"Agents: {agents_done}/{agents_total} | "
|
||||
f"LLM: {s['llm_calls']} | Tools: {s['tool_calls']} | "
|
||||
f"Tokens: {s['tokens_in']:,}\u2191 {s['tokens_out']:,}\u2193 | "
|
||||
f"Reports: {reports_done}/{reports_total} | "
|
||||
f"\u23f1 {elapsed_str}"
|
||||
)
|
||||
return {
|
||||
"agents_done": agents_done,
|
||||
"agents_total": len(buf.agent_status),
|
||||
"llm_calls": s["llm_calls"],
|
||||
"tool_calls": s["tool_calls"],
|
||||
"tokens_in": s["tokens_in"],
|
||||
"tokens_out": s["tokens_out"],
|
||||
"reports_done": buf.get_completed_reports_count(),
|
||||
"reports_total": len(buf.report_sections),
|
||||
"elapsed": round(elapsed, 1),
|
||||
}
|
||||
|
||||
|
||||
@cl.on_chat_start
|
||||
async def on_chat_start():
|
||||
await cl.Message(
|
||||
content=(
|
||||
"**TradingAgents** \u2014 Multi-Agent LLM Trading Analysis\n\n"
|
||||
"Send a ticker symbol to analyze. Examples:\n"
|
||||
"- `NVDA`\n"
|
||||
"- `Analyze AAPL 2024-12-01`\n"
|
||||
"- `What's the outlook for TSLA?`\n\n"
|
||||
"I'll run a team of AI analysts, researchers, traders, and risk managers "
|
||||
"to produce a trading decision."
|
||||
)
|
||||
).send()
|
||||
def _agent_stage(agent_name):
|
||||
"""Map agent name to pipeline stage."""
|
||||
if agent_name in ("Market Analyst", "Social Analyst", "News Analyst", "Fundamentals Analyst"):
|
||||
return "analysts"
|
||||
if agent_name in ("Bull Researcher", "Bear Researcher", "Research Manager"):
|
||||
return "research"
|
||||
if agent_name == "Trader":
|
||||
return "trading"
|
||||
if agent_name in ("Aggressive Analyst", "Conservative Analyst", "Neutral Analyst"):
|
||||
return "risk"
|
||||
if agent_name == "Portfolio Manager":
|
||||
return "decision"
|
||||
return "unknown"
|
||||
|
||||
|
||||
@cl.on_message
|
||||
async def on_message(message: cl.Message):
|
||||
ticker, trade_date = parse_ticker_date(message.content)
|
||||
|
||||
if not ticker:
|
||||
await cl.Message(
|
||||
content="I couldn't find a ticker symbol. Try something like `NVDA` or `Analyze AAPL 2024-12-01`."
|
||||
).send()
|
||||
return
|
||||
|
||||
# --- Build graph ---
|
||||
async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||
"""Background task that runs the TradingAgents pipeline and pushes SSE events."""
|
||||
q = analyses[analysis_id]
|
||||
config = build_config()
|
||||
stats_handler = StatsCallbackHandler()
|
||||
selected_analysts = ["market", "social", "news", "fundamentals"]
|
||||
|
|
@ -133,35 +105,28 @@ async def on_message(message: cl.Message):
|
|||
callbacks=[stats_handler],
|
||||
)
|
||||
except Exception as e:
|
||||
await cl.Message(content=f"Failed to initialize agents: {e}").send()
|
||||
await q.put({"type": "error", "message": f"Init failed: {e}"})
|
||||
await q.put(None)
|
||||
return
|
||||
|
||||
# --- Initialize message buffer (same as CLI) ---
|
||||
buf = MessageBuffer()
|
||||
buf.init_for_analysis(selected_analysts)
|
||||
|
||||
# --- Status message (will be updated as agents progress) ---
|
||||
status_msg = cl.Message(content=f"**Analyzing {ticker} for {trade_date}...**\n\n{format_agent_status_table(buf)}")
|
||||
await status_msg.send()
|
||||
|
||||
# --- Stream the graph ---
|
||||
init_state = graph.propagator.create_initial_state(ticker, trade_date)
|
||||
args = graph.propagator.get_graph_args(callbacks=[stats_handler])
|
||||
start_time = time.time()
|
||||
|
||||
# Steps we'll create as agents complete
|
||||
analyst_steps = {} # field -> Step
|
||||
research_step = None
|
||||
trader_step = None
|
||||
risk_step = None
|
||||
last_status_update = 0
|
||||
emitted_reports = set()
|
||||
research_emitted = False
|
||||
trader_emitted = False
|
||||
risk_emitted = False
|
||||
final_state = None
|
||||
prev_statuses = {}
|
||||
|
||||
try:
|
||||
async for chunk in graph.graph.astream(init_state, **args):
|
||||
final_state = chunk
|
||||
|
||||
# --- Process messages (same as CLI lines 1024-1044) ---
|
||||
# Process messages (same logic as Chainlit app)
|
||||
if chunk.get("messages") and len(chunk["messages"]) > 0:
|
||||
last_msg = chunk["messages"][-1]
|
||||
msg_id = getattr(last_msg, "id", None)
|
||||
|
|
@ -177,24 +142,41 @@ async def on_message(message: cl.Message):
|
|||
else:
|
||||
buf.add_tool_call(tc.name, tc.args)
|
||||
|
||||
# --- Update analyst statuses (same as CLI line 1047) ---
|
||||
update_analyst_statuses(buf, chunk)
|
||||
st = get_stats_dict(stats_handler, buf, start_time)
|
||||
|
||||
# --- Emit analyst report Steps as they complete ---
|
||||
report_names = {
|
||||
"market_report": "Market Analyst",
|
||||
"sentiment_report": "Sentiment Analyst",
|
||||
"news_report": "News Analyst",
|
||||
"fundamentals_report": "Fundamentals Analyst",
|
||||
# Emit agent status changes only (avoid flooding)
|
||||
for agent, status in buf.agent_status.items():
|
||||
if prev_statuses.get(agent) != status:
|
||||
prev_statuses[agent] = status
|
||||
await q.put({
|
||||
"type": "agent_update",
|
||||
"agent": agent,
|
||||
"stage": _agent_stage(agent),
|
||||
"status": status,
|
||||
"stats": st,
|
||||
})
|
||||
|
||||
# Analyst reports
|
||||
report_map = {
|
||||
"market_report": ("Market Analyst", "analysts"),
|
||||
"sentiment_report": ("Social Analyst", "analysts"),
|
||||
"news_report": ("News Analyst", "analysts"),
|
||||
"fundamentals_report": ("Fundamentals Analyst", "analysts"),
|
||||
}
|
||||
for field, name in report_names.items():
|
||||
if field not in analyst_steps and chunk.get(field):
|
||||
analyst_steps[field] = True
|
||||
async with cl.Step(name=f"\u2705 {name} Report", type="tool") as step:
|
||||
report = chunk[field]
|
||||
step.output = report[:4000] if len(report) > 4000 else report
|
||||
for field, (agent_name, stage) in report_map.items():
|
||||
if field not in emitted_reports and chunk.get(field):
|
||||
emitted_reports.add(field)
|
||||
await q.put({
|
||||
"type": "report",
|
||||
"agent": agent_name,
|
||||
"stage": stage,
|
||||
"field": field,
|
||||
"report": chunk[field],
|
||||
"stats": st,
|
||||
})
|
||||
|
||||
# --- Research debate (same as CLI lines 1050-1072) ---
|
||||
# Research debate
|
||||
if chunk.get("investment_debate_state"):
|
||||
debate = chunk["investment_debate_state"]
|
||||
bull = debate.get("bull_history", "").strip()
|
||||
|
|
@ -203,28 +185,33 @@ async def on_message(message: cl.Message):
|
|||
|
||||
if bull or bear:
|
||||
update_research_team_status("in_progress")
|
||||
buf.update_report_section("investment_plan",
|
||||
(f"### Bull Researcher\n{bull}\n\n### Bear Researcher\n{bear}") if bear else f"### Bull Researcher\n{bull}")
|
||||
|
||||
if judge and not research_step:
|
||||
research_step = True
|
||||
buf.update_report_section("investment_plan", f"### Research Manager Decision\n{judge}")
|
||||
if judge and not research_emitted:
|
||||
research_emitted = True
|
||||
update_research_team_status("completed")
|
||||
buf.update_agent_status("Trader", "in_progress")
|
||||
async with cl.Step(name="\u2705 Research Debate", type="tool") as step:
|
||||
step.output = f"**Bull Case:**\n{bull}\n\n---\n\n**Bear Case:**\n{bear}\n\n---\n\n**Research Manager Decision:**\n{judge}"
|
||||
await q.put({
|
||||
"type": "debate",
|
||||
"stage": "research",
|
||||
"bull": bull,
|
||||
"bear": bear,
|
||||
"judge": judge,
|
||||
"stats": get_stats_dict(stats_handler, buf, start_time),
|
||||
})
|
||||
|
||||
# --- Trader plan (same as CLI lines 1075-1081) ---
|
||||
if chunk.get("trader_investment_plan") and not trader_step:
|
||||
trader_step = True
|
||||
buf.update_report_section("trader_investment_plan", chunk["trader_investment_plan"])
|
||||
# Trader plan
|
||||
if chunk.get("trader_investment_plan") and not trader_emitted:
|
||||
trader_emitted = True
|
||||
buf.update_agent_status("Trader", "completed")
|
||||
buf.update_agent_status("Aggressive Analyst", "in_progress")
|
||||
async with cl.Step(name="\u2705 Trader Plan", type="tool") as step:
|
||||
plan = chunk["trader_investment_plan"]
|
||||
step.output = plan[:4000] if len(plan) > 4000 else plan
|
||||
await q.put({
|
||||
"type": "trader",
|
||||
"stage": "trading",
|
||||
"plan": chunk["trader_investment_plan"],
|
||||
"stats": get_stats_dict(stats_handler, buf, start_time),
|
||||
})
|
||||
|
||||
# --- Risk debate (same as CLI lines 1084-1118) ---
|
||||
# Risk debate
|
||||
if chunk.get("risk_debate_state"):
|
||||
risk = chunk["risk_debate_state"]
|
||||
agg = risk.get("aggressive_history", "").strip()
|
||||
|
|
@ -239,66 +226,74 @@ async def on_message(message: cl.Message):
|
|||
if neu:
|
||||
buf.update_agent_status("Neutral Analyst", "in_progress")
|
||||
|
||||
if judge and not risk_step:
|
||||
risk_step = True
|
||||
if judge and not risk_emitted:
|
||||
risk_emitted = True
|
||||
buf.update_agent_status("Aggressive Analyst", "completed")
|
||||
buf.update_agent_status("Conservative Analyst", "completed")
|
||||
buf.update_agent_status("Neutral Analyst", "completed")
|
||||
buf.update_agent_status("Portfolio Manager", "completed")
|
||||
buf.update_report_section("final_trade_decision", f"### Portfolio Manager Decision\n{judge}")
|
||||
|
||||
async with cl.Step(name="\u2705 Risk Assessment", type="tool") as step:
|
||||
parts = []
|
||||
if agg:
|
||||
parts.append(f"**Aggressive Analyst:**\n{agg}")
|
||||
if con:
|
||||
parts.append(f"**Conservative Analyst:**\n{con}")
|
||||
if neu:
|
||||
parts.append(f"**Neutral Analyst:**\n{neu}")
|
||||
parts.append(f"**Portfolio Manager Decision:**\n{judge}")
|
||||
step.output = "\n\n---\n\n".join(parts)
|
||||
|
||||
# --- Update status message periodically ---
|
||||
now = time.time()
|
||||
if now - last_status_update > 5:
|
||||
last_status_update = now
|
||||
status_msg.content = (
|
||||
f"**Analyzing {ticker} for {trade_date}...**\n\n"
|
||||
f"{format_agent_status_table(buf)}\n\n"
|
||||
f"*{format_stats(stats_handler, buf, start_time)}*"
|
||||
)
|
||||
await status_msg.update()
|
||||
await q.put({
|
||||
"type": "risk",
|
||||
"stage": "risk",
|
||||
"aggressive": agg,
|
||||
"conservative": con,
|
||||
"neutral": neu,
|
||||
"judge": judge,
|
||||
"stats": get_stats_dict(stats_handler, buf, start_time),
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
await cl.Message(content=f"Error during analysis: {e}").send()
|
||||
await q.put({"type": "error", "message": str(e)})
|
||||
await q.put(None)
|
||||
return
|
||||
|
||||
if not final_state:
|
||||
await cl.Message(content="Analysis produced no results.").send()
|
||||
return
|
||||
# Final decision
|
||||
if final_state:
|
||||
decision_text = final_state.get("final_trade_decision", "No decision reached.")
|
||||
signal = graph.process_signal(decision_text)
|
||||
for agent in buf.agent_status:
|
||||
buf.update_agent_status(agent, "completed")
|
||||
st = get_stats_dict(stats_handler, buf, start_time)
|
||||
await q.put({
|
||||
"type": "decision",
|
||||
"stage": "decision",
|
||||
"signal": signal,
|
||||
"decision_text": decision_text,
|
||||
"stats": st,
|
||||
})
|
||||
|
||||
# --- Final decision ---
|
||||
decision_text = final_state.get("final_trade_decision", "No decision reached.")
|
||||
signal = graph.process_signal(decision_text)
|
||||
await q.put(None) # sentinel — stream done
|
||||
|
||||
# Mark all agents completed
|
||||
for agent in buf.agent_status:
|
||||
buf.update_agent_status(agent, "completed")
|
||||
|
||||
# Final status update
|
||||
status_msg.content = (
|
||||
f"**Analysis complete for {ticker} ({trade_date})**\n\n"
|
||||
f"{format_agent_status_table(buf)}\n\n"
|
||||
f"*{format_stats(stats_handler, buf, start_time)}*"
|
||||
)
|
||||
await status_msg.update()
|
||||
@app.post("/analyze")
|
||||
async def start_analysis(req: AnalyzeRequest):
|
||||
ticker = req.ticker.upper().strip()
|
||||
if not ticker or len(ticker) > 5:
|
||||
raise HTTPException(400, "Invalid ticker")
|
||||
trade_date = req.date or str(date.today())
|
||||
analysis_id = str(uuid.uuid4())
|
||||
analyses[analysis_id] = asyncio.Queue()
|
||||
asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
|
||||
return {"id": analysis_id, "ticker": ticker, "date": trade_date}
|
||||
|
||||
# Send the final decision
|
||||
await cl.Message(
|
||||
content=(
|
||||
f"## {ticker} \u2014 Trading Decision\n\n"
|
||||
f"### Signal: {signal}\n\n"
|
||||
f"---\n\n"
|
||||
f"{decision_text}"
|
||||
)
|
||||
).send()
|
||||
|
||||
@app.get("/analyze/{analysis_id}/stream")
|
||||
async def stream_analysis(analysis_id: str):
|
||||
if analysis_id not in analyses:
|
||||
raise HTTPException(404, "Analysis not found")
|
||||
q = analyses[analysis_id]
|
||||
|
||||
async def event_generator():
|
||||
while True:
|
||||
event = await q.get()
|
||||
if event is None:
|
||||
break
|
||||
yield {"data": json.dumps(event)}
|
||||
analyses.pop(analysis_id, None)
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,9 @@ requests
|
|||
tqdm
|
||||
pytz
|
||||
redis
|
||||
chainlit
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
sse-starlette
|
||||
rich
|
||||
typer
|
||||
questionary
|
||||
|
|
|
|||
Loading…
Reference in New Issue