diff --git a/Dockerfile b/Dockerfile index 9e3afde6..b67c6d8a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,5 +12,4 @@ RUN pip install --no-cache-dir -r requirements.txt COPY . . -# Chainlit listens on $PORT (Railway sets this automatically) -CMD chainlit run app.py --host 0.0.0.0 --port ${PORT:-8000} +CMD uvicorn app:app --host 0.0.0.0 --port ${PORT:-8000} diff --git a/app.py b/app.py index 79616949..ed0111f7 100644 --- a/app.py +++ b/app.py @@ -1,12 +1,16 @@ -"""Chainlit web UI for TradingAgents — mirrors the CLI experience.""" +"""FastAPI SSE backend for TradingAgents.""" import os -import re import time -import datetime +import uuid +import asyncio +import json from datetime import date -import chainlit as cl +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from sse_starlette.sse import EventSourceResponse from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG @@ -16,22 +20,23 @@ from cli.main import ( classify_message_type, update_analyst_statuses, update_research_team_status, - ANALYST_ORDER, ) +app = FastAPI(title="TradingAgents API") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], +) -def parse_ticker_date(text: str): - """Extract ticker symbol and optional date from user message.""" - date_match = re.search(r"(\d{4}-\d{2}-\d{2})", text) - trade_date = date_match.group(1) if date_match else str(date.today()) +# Active analysis queues: id -> asyncio.Queue +analyses: dict[str, asyncio.Queue] = {} - candidates = re.findall(r"\b([A-Z]{1,5})\b", text) - skip = {"I", "A", "THE", "AND", "OR", "FOR", "TO", "IN", "ON", "AT", "IS", - "IT", "OF", "BY", "AS", "AN", "BE", "IF", "SO", "DO", "MY", "UP", - "NO", "NOT", "ALL", "BUT", "HOW", "GET", "HAS", "HAD", "CAN", - "WHAT", "ABOUT", "BUY", "SELL", "HOLD"} - tickers = [c for c in candidates if c not in skip] - return tickers[0] if tickers else None, trade_date + +class AnalyzeRequest(BaseModel): + ticker: str + date: str | None = None def build_config(): @@ -52,75 +57,42 @@ def build_config(): return config -def format_agent_status_table(buf): - """Build a markdown table showing agent status (like the CLI progress panel).""" - teams = { - "Analyst Team": ["Market Analyst", "Social Analyst", "News Analyst", "Fundamentals Analyst"], - "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"], - "Trading Team": ["Trader"], - "Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"], - "Portfolio Management": ["Portfolio Manager"], - } - - icons = {"pending": "\u23f3", "in_progress": "\u26a1", "completed": "\u2705", "error": "\u274c"} - lines = ["| Team | Agent | Status |", "|---|---|---|"] - - for team, agents in teams.items(): - active = [a for a in agents if a in buf.agent_status] - for i, agent in enumerate(active): - status = buf.agent_status.get(agent, "pending") - icon = icons.get(status, "") - team_col = team if i == 0 else "" - lines.append(f"| {team_col} | {agent} | {icon} {status} |") - - return "\n".join(lines) - - -def format_stats(stats_handler, buf, start_time): - """Format footer stats like the CLI.""" +def get_stats_dict(stats_handler, buf, start_time): + """Build stats dict for SSE events.""" s = stats_handler.get_stats() agents_done = sum(1 for v in buf.agent_status.values() if v == "completed") - agents_total = len(buf.agent_status) - reports_done = buf.get_completed_reports_count() - reports_total = len(buf.report_sections) elapsed = time.time() - start_time - elapsed_str = f"{int(elapsed // 60):02d}:{int(elapsed % 60):02d}" - - return ( - f"Agents: {agents_done}/{agents_total} | " - f"LLM: {s['llm_calls']} | Tools: {s['tool_calls']} | " - f"Tokens: {s['tokens_in']:,}\u2191 {s['tokens_out']:,}\u2193 | " - f"Reports: {reports_done}/{reports_total} | " - f"\u23f1 {elapsed_str}" - ) + return { + "agents_done": agents_done, + "agents_total": len(buf.agent_status), + "llm_calls": s["llm_calls"], + "tool_calls": s["tool_calls"], + "tokens_in": s["tokens_in"], + "tokens_out": s["tokens_out"], + "reports_done": buf.get_completed_reports_count(), + "reports_total": len(buf.report_sections), + "elapsed": round(elapsed, 1), + } -@cl.on_chat_start -async def on_chat_start(): - await cl.Message( - content=( - "**TradingAgents** \u2014 Multi-Agent LLM Trading Analysis\n\n" - "Send a ticker symbol to analyze. Examples:\n" - "- `NVDA`\n" - "- `Analyze AAPL 2024-12-01`\n" - "- `What's the outlook for TSLA?`\n\n" - "I'll run a team of AI analysts, researchers, traders, and risk managers " - "to produce a trading decision." - ) - ).send() +def _agent_stage(agent_name): + """Map agent name to pipeline stage.""" + if agent_name in ("Market Analyst", "Social Analyst", "News Analyst", "Fundamentals Analyst"): + return "analysts" + if agent_name in ("Bull Researcher", "Bear Researcher", "Research Manager"): + return "research" + if agent_name == "Trader": + return "trading" + if agent_name in ("Aggressive Analyst", "Conservative Analyst", "Neutral Analyst"): + return "risk" + if agent_name == "Portfolio Manager": + return "decision" + return "unknown" -@cl.on_message -async def on_message(message: cl.Message): - ticker, trade_date = parse_ticker_date(message.content) - - if not ticker: - await cl.Message( - content="I couldn't find a ticker symbol. Try something like `NVDA` or `Analyze AAPL 2024-12-01`." - ).send() - return - - # --- Build graph --- +async def run_analysis(analysis_id: str, ticker: str, trade_date: str): + """Background task that runs the TradingAgents pipeline and pushes SSE events.""" + q = analyses[analysis_id] config = build_config() stats_handler = StatsCallbackHandler() selected_analysts = ["market", "social", "news", "fundamentals"] @@ -133,35 +105,28 @@ async def on_message(message: cl.Message): callbacks=[stats_handler], ) except Exception as e: - await cl.Message(content=f"Failed to initialize agents: {e}").send() + await q.put({"type": "error", "message": f"Init failed: {e}"}) + await q.put(None) return - # --- Initialize message buffer (same as CLI) --- buf = MessageBuffer() buf.init_for_analysis(selected_analysts) - - # --- Status message (will be updated as agents progress) --- - status_msg = cl.Message(content=f"**Analyzing {ticker} for {trade_date}...**\n\n{format_agent_status_table(buf)}") - await status_msg.send() - - # --- Stream the graph --- init_state = graph.propagator.create_initial_state(ticker, trade_date) args = graph.propagator.get_graph_args(callbacks=[stats_handler]) start_time = time.time() - # Steps we'll create as agents complete - analyst_steps = {} # field -> Step - research_step = None - trader_step = None - risk_step = None - last_status_update = 0 + emitted_reports = set() + research_emitted = False + trader_emitted = False + risk_emitted = False final_state = None + prev_statuses = {} try: async for chunk in graph.graph.astream(init_state, **args): final_state = chunk - # --- Process messages (same as CLI lines 1024-1044) --- + # Process messages (same logic as Chainlit app) if chunk.get("messages") and len(chunk["messages"]) > 0: last_msg = chunk["messages"][-1] msg_id = getattr(last_msg, "id", None) @@ -177,24 +142,41 @@ async def on_message(message: cl.Message): else: buf.add_tool_call(tc.name, tc.args) - # --- Update analyst statuses (same as CLI line 1047) --- update_analyst_statuses(buf, chunk) + st = get_stats_dict(stats_handler, buf, start_time) - # --- Emit analyst report Steps as they complete --- - report_names = { - "market_report": "Market Analyst", - "sentiment_report": "Sentiment Analyst", - "news_report": "News Analyst", - "fundamentals_report": "Fundamentals Analyst", + # Emit agent status changes only (avoid flooding) + for agent, status in buf.agent_status.items(): + if prev_statuses.get(agent) != status: + prev_statuses[agent] = status + await q.put({ + "type": "agent_update", + "agent": agent, + "stage": _agent_stage(agent), + "status": status, + "stats": st, + }) + + # Analyst reports + report_map = { + "market_report": ("Market Analyst", "analysts"), + "sentiment_report": ("Social Analyst", "analysts"), + "news_report": ("News Analyst", "analysts"), + "fundamentals_report": ("Fundamentals Analyst", "analysts"), } - for field, name in report_names.items(): - if field not in analyst_steps and chunk.get(field): - analyst_steps[field] = True - async with cl.Step(name=f"\u2705 {name} Report", type="tool") as step: - report = chunk[field] - step.output = report[:4000] if len(report) > 4000 else report + for field, (agent_name, stage) in report_map.items(): + if field not in emitted_reports and chunk.get(field): + emitted_reports.add(field) + await q.put({ + "type": "report", + "agent": agent_name, + "stage": stage, + "field": field, + "report": chunk[field], + "stats": st, + }) - # --- Research debate (same as CLI lines 1050-1072) --- + # Research debate if chunk.get("investment_debate_state"): debate = chunk["investment_debate_state"] bull = debate.get("bull_history", "").strip() @@ -203,28 +185,33 @@ async def on_message(message: cl.Message): if bull or bear: update_research_team_status("in_progress") - buf.update_report_section("investment_plan", - (f"### Bull Researcher\n{bull}\n\n### Bear Researcher\n{bear}") if bear else f"### Bull Researcher\n{bull}") - if judge and not research_step: - research_step = True - buf.update_report_section("investment_plan", f"### Research Manager Decision\n{judge}") + if judge and not research_emitted: + research_emitted = True update_research_team_status("completed") buf.update_agent_status("Trader", "in_progress") - async with cl.Step(name="\u2705 Research Debate", type="tool") as step: - step.output = f"**Bull Case:**\n{bull}\n\n---\n\n**Bear Case:**\n{bear}\n\n---\n\n**Research Manager Decision:**\n{judge}" + await q.put({ + "type": "debate", + "stage": "research", + "bull": bull, + "bear": bear, + "judge": judge, + "stats": get_stats_dict(stats_handler, buf, start_time), + }) - # --- Trader plan (same as CLI lines 1075-1081) --- - if chunk.get("trader_investment_plan") and not trader_step: - trader_step = True - buf.update_report_section("trader_investment_plan", chunk["trader_investment_plan"]) + # Trader plan + if chunk.get("trader_investment_plan") and not trader_emitted: + trader_emitted = True buf.update_agent_status("Trader", "completed") buf.update_agent_status("Aggressive Analyst", "in_progress") - async with cl.Step(name="\u2705 Trader Plan", type="tool") as step: - plan = chunk["trader_investment_plan"] - step.output = plan[:4000] if len(plan) > 4000 else plan + await q.put({ + "type": "trader", + "stage": "trading", + "plan": chunk["trader_investment_plan"], + "stats": get_stats_dict(stats_handler, buf, start_time), + }) - # --- Risk debate (same as CLI lines 1084-1118) --- + # Risk debate if chunk.get("risk_debate_state"): risk = chunk["risk_debate_state"] agg = risk.get("aggressive_history", "").strip() @@ -239,66 +226,74 @@ async def on_message(message: cl.Message): if neu: buf.update_agent_status("Neutral Analyst", "in_progress") - if judge and not risk_step: - risk_step = True + if judge and not risk_emitted: + risk_emitted = True buf.update_agent_status("Aggressive Analyst", "completed") buf.update_agent_status("Conservative Analyst", "completed") buf.update_agent_status("Neutral Analyst", "completed") buf.update_agent_status("Portfolio Manager", "completed") - buf.update_report_section("final_trade_decision", f"### Portfolio Manager Decision\n{judge}") - - async with cl.Step(name="\u2705 Risk Assessment", type="tool") as step: - parts = [] - if agg: - parts.append(f"**Aggressive Analyst:**\n{agg}") - if con: - parts.append(f"**Conservative Analyst:**\n{con}") - if neu: - parts.append(f"**Neutral Analyst:**\n{neu}") - parts.append(f"**Portfolio Manager Decision:**\n{judge}") - step.output = "\n\n---\n\n".join(parts) - - # --- Update status message periodically --- - now = time.time() - if now - last_status_update > 5: - last_status_update = now - status_msg.content = ( - f"**Analyzing {ticker} for {trade_date}...**\n\n" - f"{format_agent_status_table(buf)}\n\n" - f"*{format_stats(stats_handler, buf, start_time)}*" - ) - await status_msg.update() + await q.put({ + "type": "risk", + "stage": "risk", + "aggressive": agg, + "conservative": con, + "neutral": neu, + "judge": judge, + "stats": get_stats_dict(stats_handler, buf, start_time), + }) except Exception as e: - await cl.Message(content=f"Error during analysis: {e}").send() + await q.put({"type": "error", "message": str(e)}) + await q.put(None) return - if not final_state: - await cl.Message(content="Analysis produced no results.").send() - return + # Final decision + if final_state: + decision_text = final_state.get("final_trade_decision", "No decision reached.") + signal = graph.process_signal(decision_text) + for agent in buf.agent_status: + buf.update_agent_status(agent, "completed") + st = get_stats_dict(stats_handler, buf, start_time) + await q.put({ + "type": "decision", + "stage": "decision", + "signal": signal, + "decision_text": decision_text, + "stats": st, + }) - # --- Final decision --- - decision_text = final_state.get("final_trade_decision", "No decision reached.") - signal = graph.process_signal(decision_text) + await q.put(None) # sentinel — stream done - # Mark all agents completed - for agent in buf.agent_status: - buf.update_agent_status(agent, "completed") - # Final status update - status_msg.content = ( - f"**Analysis complete for {ticker} ({trade_date})**\n\n" - f"{format_agent_status_table(buf)}\n\n" - f"*{format_stats(stats_handler, buf, start_time)}*" - ) - await status_msg.update() +@app.post("/analyze") +async def start_analysis(req: AnalyzeRequest): + ticker = req.ticker.upper().strip() + if not ticker or len(ticker) > 5: + raise HTTPException(400, "Invalid ticker") + trade_date = req.date or str(date.today()) + analysis_id = str(uuid.uuid4()) + analyses[analysis_id] = asyncio.Queue() + asyncio.create_task(run_analysis(analysis_id, ticker, trade_date)) + return {"id": analysis_id, "ticker": ticker, "date": trade_date} - # Send the final decision - await cl.Message( - content=( - f"## {ticker} \u2014 Trading Decision\n\n" - f"### Signal: {signal}\n\n" - f"---\n\n" - f"{decision_text}" - ) - ).send() + +@app.get("/analyze/{analysis_id}/stream") +async def stream_analysis(analysis_id: str): + if analysis_id not in analyses: + raise HTTPException(404, "Analysis not found") + q = analyses[analysis_id] + + async def event_generator(): + while True: + event = await q.get() + if event is None: + break + yield {"data": json.dumps(event)} + analyses.pop(analysis_id, None) + + return EventSourceResponse(event_generator()) + + +@app.get("/health") +async def health(): + return {"status": "ok"} diff --git a/requirements.txt b/requirements.txt index 9e51ed98..aa4384da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,9 @@ requests tqdm pytz redis -chainlit +fastapi +uvicorn[standard] +sse-starlette rich typer questionary