""" TradingAgents Web Dashboard Backend FastAPI REST API + WebSocket for real-time analysis progress """ import asyncio import fcntl import json import os import subprocess import sys import time import traceback from datetime import datetime from pathlib import Path from typing import Optional from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Query from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel # Path to TradingAgents repo root REPO_ROOT = Path(__file__).parent.parent.parent # Use the currently running Python interpreter ANALYSIS_PYTHON = Path(sys.executable) # Task state persistence directory TASK_STATUS_DIR = Path(__file__).parent / "data" / "task_status" # ============== Lifespan ============== @asynccontextmanager async def lifespan(app: FastAPI): """Startup and shutdown events""" app.state.active_connections: dict[str, list[WebSocket]] = {} app.state.task_results: dict[str, dict] = {} app.state.analysis_tasks: dict[str, asyncio.Task] = {} # Restore persisted task states from disk TASK_STATUS_DIR.mkdir(parents=True, exist_ok=True) for f in TASK_STATUS_DIR.glob("*.json"): try: data = json.loads(f.read_text()) app.state.task_results[data["task_id"]] = data except Exception: pass yield # ============== App ============== app = FastAPI( title="TradingAgents Web Dashboard API", version="0.1.0", lifespan=lifespan ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ============== Pydantic Models ============== class AnalysisRequest(BaseModel): ticker: str date: Optional[str] = None class ScreenRequest(BaseModel): mode: str = "china_strict" # ============== Cache Helpers ============== CACHE_DIR = Path(__file__).parent.parent / "cache" CACHE_TTL_SECONDS = 300 # 5 minutes def _get_cache_path(mode: str) -> Path: return CACHE_DIR / f"screen_{mode}.json" def _load_from_cache(mode: str) -> 
Optional[dict]: cache_path = _get_cache_path(mode) if not cache_path.exists(): return None try: age = time.time() - cache_path.stat().st_mtime if age > CACHE_TTL_SECONDS: return None with open(cache_path) as f: return json.load(f) except Exception: return None def _save_task_status(task_id: str, data: dict): """Persist task state to disk""" try: TASK_STATUS_DIR.mkdir(parents=True, exist_ok=True) (TASK_STATUS_DIR / f"{task_id}.json").write_text(json.dumps(data, ensure_ascii=False)) except Exception: pass def _delete_task_status(task_id: str): """Remove persisted task state from disk""" try: (TASK_STATUS_DIR / f"{task_id}.json").unlink(missing_ok=True) except Exception: pass try: CACHE_DIR.mkdir(parents=True, exist_ok=True) cache_path = _get_cache_path(mode) with open(cache_path, "w") as f: json.dump(data, f) except Exception: pass # ============== SEPA Screening ============== def _run_sepa_screening(mode: str) -> dict: """Run SEPA screening synchronously in thread""" sys.path.insert(0, str(REPO_ROOT)) from sepa_screener import screen_all, china_stocks results = screen_all(mode=mode, max_workers=5) total = len(china_stocks) return { "mode": mode, "total_stocks": total, "passed": len(results), "results": results, } @app.get("/api/stocks/screen") async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query(False)): """Screen stocks using SEPA criteria with caching""" if not refresh: cached = _load_from_cache(mode) if cached: return {**cached, "cached": True} # Run in thread pool (blocks thread but not event loop) loop = asyncio.get_event_loop() result = await loop.run_in_executor(None, lambda: _run_sepa_screening(mode)) _save_to_cache(mode, result) return {**result, "cached": False} # ============== Analysis Execution ============== # Script template for subprocess-based analysis # api_key is passed via environment variable (not CLI) for security ANALYSIS_SCRIPT_TEMPLATE = """ import sys import os ticker = sys.argv[1] date = sys.argv[2] repo_root = 
sys.argv[3] sys.path.insert(0, repo_root) os.environ["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic" import py_mini_racer sys.modules["mini_racer"] = py_mini_racer from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG from pathlib import Path print("STAGE:analysts", flush=True) config = DEFAULT_CONFIG.copy() config["llm_provider"] = "anthropic" config["deep_think_llm"] = "MiniMax-M2.7-highspeed" config["quick_think_llm"] = "MiniMax-M2.7-highspeed" config["backend_url"] = "https://api.minimaxi.com/anthropic" config["max_debate_rounds"] = 1 config["max_risk_discuss_rounds"] = 1 print("STAGE:research", flush=True) ta = TradingAgentsGraph(debug=False, config=config) print("STAGE:trading", flush=True) final_state, decision = ta.propagate(ticker, date) print("STAGE:risk", flush=True) results_dir = Path(repo_root) / "results" / ticker / date results_dir.mkdir(parents=True, exist_ok=True) signal = decision if isinstance(decision, str) else decision.get("signal", "HOLD") report_content = ( "# TradingAgents 分析报告\\n\\n" "**股票**: " + ticker + "\\n" "**日期**: " + date + "\\n\\n" "## 最终决策\\n\\n" "**" + signal + "**\\n\\n" "## 分析摘要\\n\\n" + final_state.get("market_report", "N/A") + "\\n\\n" "## 基本面\\n\\n" + final_state.get("fundamentals_report", "N/A") + "\\n" ) report_path = results_dir / "complete_report.md" report_path.write_text(report_content) print("STAGE:portfolio", flush=True) print("ANALYSIS_COMPLETE:" + signal, flush=True) """ @app.post("/api/analysis/start") async def start_analysis(request: AnalysisRequest): """Start a new analysis task""" import uuid task_id = f"{request.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" date = request.date or datetime.now().strftime("%Y-%m-%d") # Initialize task state app.state.task_results[task_id] = { "task_id": task_id, "ticker": request.ticker, "date": date, "status": "running", "progress": 0, "current_stage": "analysts", 
"elapsed": 0, "stages": [ {"status": "running", "completed_at": None}, {"status": "pending", "completed_at": None}, {"status": "pending", "completed_at": None}, {"status": "pending", "completed_at": None}, {"status": "pending", "completed_at": None}, ], "logs": [], "decision": None, "error": None, } # Get API key - fail fast before storing a running task api_key = os.environ.get("ANTHROPIC_API_KEY") if not api_key: raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set") await broadcast_progress(task_id, app.state.task_results[task_id]) # Write analysis script to temp file (avoids subprocess -c quoting issues) script_path = Path(f"/tmp/analysis_{task_id}.py") script_content = ANALYSIS_SCRIPT_TEMPLATE script_path.write_text(script_content) # Store process reference for cancellation app.state.processes = getattr(app.state, 'processes', {}) app.state.processes[task_id] = None # Cancellation event for the monitor coroutine cancel_event = asyncio.Event() # Stage name to index mapping STAGE_NAMES = ["analysts", "research", "trading", "risk", "portfolio"] def _update_task_stage(stage_name: str): """Update task state for a completed stage and mark next as running.""" try: idx = STAGE_NAMES.index(stage_name) except ValueError: return # Mark all previous stages as completed for i in range(idx): if app.state.task_results[task_id]["stages"][i]["status"] != "completed": app.state.task_results[task_id]["stages"][i]["status"] = "completed" app.state.task_results[task_id]["stages"][i]["completed_at"] = datetime.now().strftime("%H:%M:%S") # Mark current as completed if app.state.task_results[task_id]["stages"][idx]["status"] != "completed": app.state.task_results[task_id]["stages"][idx]["status"] = "completed" app.state.task_results[task_id]["stages"][idx]["completed_at"] = datetime.now().strftime("%H:%M:%S") # Mark next as running if idx + 1 < 5: if app.state.task_results[task_id]["stages"][idx + 1]["status"] == "pending": 
app.state.task_results[task_id]["stages"][idx + 1]["status"] = "running" # Update progress app.state.task_results[task_id]["progress"] = int((idx + 1) / 5 * 100) app.state.task_results[task_id]["current_stage"] = stage_name async def monitor_subprocess(task_id: str, proc: asyncio.subprocess.Process, cancel_evt: asyncio.Event): """Monitor subprocess stdout for stage markers and broadcast progress.""" # Set stdout to non-blocking fd = proc.stdout.fileno() fl = fcntl.fcntl(fd, fcntl.GETFL) fcntl.fcntl(fd, fcntl.SETFL, fl | os.O_NONBLOCK) while not cancel_evt.is_set(): if proc.returncode is not None: break await asyncio.sleep(5) if cancel_evt.is_set(): break try: chunk = os.read(fd, 32768) if chunk: for line in chunk.decode().splitlines(): if line.startswith("STAGE:"): stage = line.split(":", 1)[1].strip() _update_task_stage(stage) await broadcast_progress(task_id, app.state.task_results[task_id]) except (BlockingIOError, OSError): # No data available yet pass async def run_analysis(): """Run analysis subprocess and broadcast progress""" try: # Use clean environment - don't inherit parent env clean_env = {k: v for k, v in os.environ.items() if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))} clean_env["ANTHROPIC_API_KEY"] = api_key clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic" proc = await asyncio.create_subprocess_exec( str(ANALYSIS_PYTHON), str(script_path), request.ticker, date, str(REPO_ROOT), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, env=clean_env, ) app.state.processes[task_id] = proc # Start monitor coroutine alongside subprocess monitor_task = asyncio.create_task(monitor_subprocess(task_id, proc, cancel_event)) stdout, stderr = await proc.communicate() # Signal monitor to stop and wait for it cancel_event.set() try: await asyncio.wait_for(monitor_task, timeout=1.0) except asyncio.TimeoutError: monitor_task.cancel() # Clean up script file try: script_path.unlink() except Exception: pass if proc.returncode == 0: 
output = stdout.decode() decision = "HOLD" for line in output.splitlines(): if line.startswith("ANALYSIS_COMPLETE:"): decision = line.split(":", 1)[1].strip() app.state.task_results[task_id]["status"] = "completed" app.state.task_results[task_id]["progress"] = 100 app.state.task_results[task_id]["decision"] = decision app.state.task_results[task_id]["current_stage"] = "portfolio" for i in range(5): app.state.task_results[task_id]["stages"][i]["status"] = "completed" if not app.state.task_results[task_id]["stages"][i].get("completed_at"): app.state.task_results[task_id]["stages"][i]["completed_at"] = datetime.now().strftime("%H:%M:%S") else: error_msg = stderr.decode()[-1000:] if stderr else "Unknown error" app.state.task_results[task_id]["status"] = "failed" app.state.task_results[task_id]["error"] = error_msg _save_task_status(task_id, app.state.task_results[task_id]) except Exception as e: cancel_event.set() app.state.task_results[task_id]["status"] = "failed" app.state.task_results[task_id]["error"] = str(e) try: script_path.unlink() except Exception: pass _save_task_status(task_id, app.state.task_results[task_id]) await broadcast_progress(task_id, app.state.task_results[task_id]) task = asyncio.create_task(run_analysis()) app.state.analysis_tasks[task_id] = task return { "task_id": task_id, "ticker": request.ticker, "date": date, "status": "running", } @app.get("/api/analysis/status/{task_id}") async def get_task_status(task_id: str): """Get task status""" if task_id not in app.state.task_results: raise HTTPException(status_code=404, detail="Task not found") return app.state.task_results[task_id] @app.get("/api/analysis/tasks") async def list_tasks(): """List all tasks (active and recent)""" tasks = [] for task_id, state in app.state.task_results.items(): tasks.append({ "task_id": task_id, "ticker": state.get("ticker"), "date": state.get("date"), "status": state.get("status"), "progress": state.get("progress", 0), "decision": state.get("decision"), "error": 
state.get("error"), "created_at": state.get("stages", [{}])[0].get("completed_at") if state.get("stages") else None, }) # Sort by task_id (which includes timestamp) descending tasks.sort(key=lambda x: x["task_id"], reverse=True) return {"tasks": tasks, "total": len(tasks)} @app.delete("/api/analysis/cancel/{task_id}") async def cancel_task(task_id: str): """Cancel a running task""" if task_id not in app.state.task_results: raise HTTPException(status_code=404, detail="Task not found") # Kill the subprocess if it's still running proc = app.state.processes.get(task_id) if proc and proc.returncode is None: try: proc.kill() except Exception: pass # Cancel the asyncio task task = app.state.analysis_tasks.get(task_id) if task: task.cancel() app.state.task_results[task_id]["status"] = "failed" app.state.task_results[task_id]["error"] = "用户取消" await broadcast_progress(task_id, app.state.task_results[task_id]) # Clean up temp script script_path = Path(f"/tmp/analysis_{task_id}.py") try: script_path.unlink() except Exception: pass # Remove persisted task state _delete_task_status(task_id) return {"task_id": task_id, "status": "cancelled"} # ============== WebSocket ============== @app.websocket("/ws/analysis/{task_id}") async def websocket_analysis(websocket: WebSocket, task_id: str): """WebSocket for real-time analysis progress""" await websocket.accept() if task_id not in app.state.active_connections: app.state.active_connections[task_id] = [] app.state.active_connections[task_id].append(websocket) # Send current state immediately if available if task_id in app.state.task_results: await websocket.send_text(json.dumps({ "type": "progress", **app.state.task_results[task_id] })) try: while True: data = await websocket.receive_text() message = json.loads(data) if message.get("type") == "ping": await websocket.send_text(json.dumps({"type": "pong"})) except WebSocketDisconnect: if task_id in app.state.active_connections: app.state.active_connections[task_id].remove(websocket) 
async def broadcast_progress(task_id: str, progress: dict):
    """Push a progress snapshot to every WebSocket watching *task_id*."""
    listeners = app.state.active_connections.get(task_id)
    if not listeners:
        return
    payload = json.dumps({"type": "progress", **progress})
    stale = []
    for ws in listeners:
        try:
            await ws.send_text(payload)
        except Exception:
            # Remember broken sockets; prune after the send loop so we
            # never mutate the list while iterating it.
            stale.append(ws)
    for ws in stale:
        listeners.remove(ws)


# ============== Reports ==============

def get_results_dir() -> Path:
    """Absolute path of the repo-level results directory."""
    return Path(__file__).parent.parent.parent / "results"


def get_reports_list():
    """Return all historical reports, newest date first."""
    root = get_results_dir()
    if not root.exists():
        return []
    found = []
    for ticker_dir in root.iterdir():
        # Skip files and the strategy-log directory at the top level
        if not ticker_dir.is_dir() or ticker_dir.name == "TradingAgentsStrategy_logs":
            continue
        for date_dir in ticker_dir.iterdir():
            # Only date-shaped subdirectories (years 20xx); anything else
            # (e.g. stray log folders) is ignored.
            if date_dir.is_dir() and date_dir.name.startswith("20"):
                found.append({
                    "ticker": ticker_dir.name,
                    "date": date_dir.name,
                    "path": str(date_dir)
                })
    return sorted(found, key=lambda r: r["date"], reverse=True)


def get_report_content(ticker: str, date: str) -> Optional[dict]:
    """Collect report files for one ticker/date; None if the dir is missing."""
    report_dir = get_results_dir() / ticker / date
    if not report_dir.exists():
        return None

    content = {}
    complete = report_dir / "complete_report.md"
    if complete.exists():
        content["report"] = complete.read_text()

    # Per-stage markdown fragments, keyed by file name
    for stage in ["1_analysts", "2_research", "3_trading", "4_risk", "5_portfolio"]:
        stage_dir = report_dir / "reports" / stage
        if not stage_dir.exists():
            continue
        for md_file in stage_dir.glob("*.md"):
            content[md_file.name] = md_file.read_text()

    return content


@app.get("/api/reports/list")
async def list_reports():
    """List every available report."""
    return get_reports_list()


@app.get("/api/reports/{ticker}/{date}")
async def get_report(ticker: str, date: str):
    """Fetch one report; 404 when nothing exists for the ticker/date pair."""
    content = get_report_content(ticker, date)
    if not content:
        raise HTTPException(status_code=404, detail="Report not found")
    return content


@app.get("/")
async def root():
    """Service banner."""
    return {"message": "TradingAgents Web Dashboard API", "version": "0.1.0"}


if __name__ == "__main__":
    import uvicorn
    # Run with: cd web_dashboard && ../env312/bin/python -m uvicorn main:app --reload
    # Or: cd web_dashboard/backend && python3 main.py (requires env312 in PATH)
    uvicorn.run(app, host="0.0.0.0", port=8000)