TradingAgents/web_dashboard/backend/main.py

1012 lines
34 KiB
Python

"""
TradingAgents Web Dashboard Backend
FastAPI REST API + WebSocket for real-time analysis progress
"""
import asyncio
import fcntl
import json
import os
import subprocess
import sys
import tempfile
import time
import traceback
from datetime import datetime
from pathlib import Path
from typing import Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi.responses import Response
# Path to TradingAgents repo root
REPO_ROOT = Path(__file__).parent.parent.parent
# Use the currently running Python interpreter
ANALYSIS_PYTHON = Path(sys.executable)
# Task state persistence directory
TASK_STATUS_DIR = Path(__file__).parent / "data" / "task_status"
# ============== Lifespan ==============
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize in-memory app state on startup and restore persisted tasks."""
    # Per-task WebSocket subscribers, task state dicts, and running asyncio tasks.
    app.state.active_connections: dict[str, list[WebSocket]] = {}
    app.state.task_results: dict[str, dict] = {}
    app.state.analysis_tasks: dict[str, asyncio.Task] = {}
    # Reload any task states persisted by a previous server process.
    TASK_STATUS_DIR.mkdir(parents=True, exist_ok=True)
    for status_file in TASK_STATUS_DIR.glob("*.json"):
        try:
            record = json.loads(status_file.read_text())
            app.state.task_results[record["task_id"]] = record
        except Exception:
            # Best-effort restore: skip unreadable or partial files.
            pass
    yield
# ============== App ==============
# FastAPI application wired to the lifespan handler above.
app = FastAPI(
    title="TradingAgents Web Dashboard API",
    version="0.1.0",
    lifespan=lifespan
)
# NOTE(review): CORS is wide open (any origin/method/header) — fine for a local
# dashboard, but tighten before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============== Pydantic Models ==============
class AnalysisRequest(BaseModel):
    """Request body for starting a single-ticker analysis."""
    ticker: str
    # Analysis date "YYYY-MM-DD"; defaults to today when omitted.
    date: Optional[str] = None
class ScreenRequest(BaseModel):
    """Request body for SEPA stock screening."""
    # Screening mode forwarded to sepa_screener.screen_all.
    mode: str = "china_strict"
# ============== Cache Helpers ==============
# On-disk cache for screening results, keyed by mode.
CACHE_DIR = Path(__file__).parent.parent / "cache"
CACHE_TTL_SECONDS = 300  # 5 minutes
MAX_RETRY_COUNT = 2  # extra attempts per ticker during batch analysis
RETRY_BASE_DELAY_SECS = 1  # base delay between retry attempts
# NOTE(review): MAX_CONCURRENT_YFINANCE is not referenced in this file — confirm
# whether it is still used elsewhere or is dead configuration.
MAX_CONCURRENT_YFINANCE = 5
def _get_cache_path(mode: str) -> Path:
    """Return the cache file path for the given screening mode."""
    filename = f"screen_{mode}.json"
    return CACHE_DIR / filename
def _load_from_cache(mode: str) -> Optional[dict]:
    """Return cached screening data for *mode*, or None when missing/stale/corrupt."""
    path = _get_cache_path(mode)
    if not path.exists():
        return None
    try:
        if time.time() - path.stat().st_mtime > CACHE_TTL_SECONDS:
            return None  # entry older than the TTL counts as a miss
        return json.loads(path.read_text())
    except Exception:
        # Any stat/read/parse failure is treated as a cache miss.
        return None
def _save_to_cache(mode: str, data: dict):
    """Best-effort write of screening results to the on-disk cache."""
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        _get_cache_path(mode).write_text(json.dumps(data))
    except Exception:
        # Caching is optional; never let a disk error surface to the caller.
        pass
def _save_task_status(task_id: str, data: dict):
    """Best-effort persistence of a task's state dict to disk."""
    try:
        TASK_STATUS_DIR.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(data, ensure_ascii=False)
        (TASK_STATUS_DIR / f"{task_id}.json").write_text(payload)
    except Exception:
        # Persistence is advisory; never break the request path over it.
        pass
def _delete_task_status(task_id: str):
    """Remove a task's persisted state file; all errors are ignored."""
    status_file = TASK_STATUS_DIR / f"{task_id}.json"
    try:
        status_file.unlink(missing_ok=True)
    except Exception:
        pass
# ============== SEPA Screening ==============
def _run_sepa_screening(mode: str) -> dict:
    """Run SEPA screening synchronously (intended for a worker thread).

    Returns a dict with the mode, universe size, pass count and results.
    """
    # Guard the insert: the old code prepended REPO_ROOT unconditionally,
    # growing sys.path on every request.
    repo_root = str(REPO_ROOT)
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    from sepa_screener import screen_all, china_stocks
    results = screen_all(mode=mode, max_workers=5)
    return {
        "mode": mode,
        "total_stocks": len(china_stocks),
        "passed": len(results),
        "results": results,
    }
@app.get("/api/stocks/screen")
async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query(False)):
    """Screen stocks using SEPA criteria, serving a 5-minute cache unless refresh=true."""
    if not refresh:
        cached = _load_from_cache(mode)
        if cached:
            return {**cached, "cached": True}
    # asyncio.get_event_loop() is deprecated inside coroutines (Python 3.10+);
    # get_running_loop() is the supported equivalent here.
    loop = asyncio.get_running_loop()
    # Screening is blocking; run it in the default thread pool so the event
    # loop stays responsive.
    result = await loop.run_in_executor(None, _run_sepa_screening, mode)
    _save_to_cache(mode, result)
    return {**result, "cached": False}
# ============== Analysis Execution ==============
# Standalone script template executed in a subprocess for one (ticker, date).
# The subprocess reports progress on stdout via "STAGE:<name>" marker lines and
# a final "ANALYSIS_COMPLETE:<signal>" line, which the parent process parses.
# api_key is passed via environment variable (not CLI) for security.
# NOTE(review): the research/trading/risk markers bracket graph construction and
# a single propagate() call, so stage timing is approximate rather than tied to
# distinct pipeline phases — confirm intent.
ANALYSIS_SCRIPT_TEMPLATE = """
import sys
import os
ticker = sys.argv[1]
date = sys.argv[2]
repo_root = sys.argv[3]
sys.path.insert(0, repo_root)
os.environ["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
import py_mini_racer
sys.modules["mini_racer"] = py_mini_racer
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from pathlib import Path
print("STAGE:analysts", flush=True)
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "anthropic"
config["deep_think_llm"] = "MiniMax-M2.7-highspeed"
config["quick_think_llm"] = "MiniMax-M2.7-highspeed"
config["backend_url"] = "https://api.minimaxi.com/anthropic"
config["max_debate_rounds"] = 1
config["max_risk_discuss_rounds"] = 1
print("STAGE:research", flush=True)
ta = TradingAgentsGraph(debug=False, config=config)
print("STAGE:trading", flush=True)
final_state, decision = ta.propagate(ticker, date)
print("STAGE:risk", flush=True)
results_dir = Path(repo_root) / "results" / ticker / date
results_dir.mkdir(parents=True, exist_ok=True)
signal = decision if isinstance(decision, str) else decision.get("signal", "HOLD")
report_content = (
"# TradingAgents 分析报告\\n\\n"
"**股票**: " + ticker + "\\n"
"**日期**: " + date + "\\n\\n"
"## 最终决策\\n\\n"
"**" + signal + "**\\n\\n"
"## 分析摘要\\n\\n"
+ final_state.get("market_report", "N/A") + "\\n\\n"
"## 基本面\\n\\n"
+ final_state.get("fundamentals_report", "N/A") + "\\n"
)
report_path = results_dir / "complete_report.md"
report_path.write_text(report_content)
print("STAGE:portfolio", flush=True)
print("ANALYSIS_COMPLETE:" + signal, flush=True)
"""
@app.post("/api/analysis/start")
async def start_analysis(request: AnalysisRequest):
    """Start a new analysis task in a subprocess and stream progress via WebSocket.

    Returns immediately with a task_id; progress and the final decision are
    pushed to /ws/analysis/{task_id} subscribers and persisted to disk.
    """
    import uuid
    task_id = f"{request.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
    date = request.date or datetime.now().strftime("%Y-%m-%d")
    # Validate API key before storing any task state
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set")
    # Initialize task state: five pipeline stages, first one already running.
    app.state.task_results[task_id] = {
        "task_id": task_id,
        "ticker": request.ticker,
        "date": date,
        "status": "running",
        "progress": 0,
        "current_stage": "analysts",
        "created_at": datetime.now().isoformat(),
        "elapsed": 0,
        "stages": [
            {"status": "running", "completed_at": None},
            {"status": "pending", "completed_at": None},
            {"status": "pending", "completed_at": None},
            {"status": "pending", "completed_at": None},
            {"status": "pending", "completed_at": None},
        ],
        "logs": [],
        "decision": None,
        "error": None,
    }
    await broadcast_progress(task_id, app.state.task_results[task_id])
    # Write analysis script to temp file with restrictive permissions
    # (avoids subprocess -c quoting issues)
    fd, script_path_str = tempfile.mkstemp(suffix=".py", prefix=f"analysis_{task_id}_")
    script_path = Path(script_path_str)
    os.chmod(script_path, 0o600)
    with os.fdopen(fd, "w") as f:
        f.write(ANALYSIS_SCRIPT_TEMPLATE)
    # Track the subprocess so /api/analysis/cancel can kill it.
    app.state.processes = getattr(app.state, 'processes', {})
    app.state.processes[task_id] = None
    # Pipeline stage names, in execution order (matches the stages list above).
    STAGE_NAMES = ["analysts", "research", "trading", "risk", "portfolio"]

    def _update_task_stage(stage_name: str):
        """Mark *stage_name* (and all earlier stages) completed; start the next one."""
        try:
            idx = STAGE_NAMES.index(stage_name)
        except ValueError:
            return  # unknown marker: ignore
        state = app.state.task_results[task_id]
        now = datetime.now().strftime("%H:%M:%S")
        # Mark this stage and any previous ones completed (markers can be missed).
        for i in range(idx + 1):
            stage = state["stages"][i]
            if stage["status"] != "completed":
                stage["status"] = "completed"
                stage["completed_at"] = now
        # Promote the following stage to running.
        if idx + 1 < len(STAGE_NAMES) and state["stages"][idx + 1]["status"] == "pending":
            state["stages"][idx + 1]["status"] = "running"
        state["progress"] = int((idx + 1) / len(STAGE_NAMES) * 100)
        state["current_stage"] = stage_name

    async def run_analysis():
        """Run the analysis subprocess, stream STAGE markers, record the result."""
        try:
            # Use a clean environment so PYTHON*/CONDA*/VIRTUAL* variables from
            # the server process don't leak into the analysis interpreter.
            clean_env = {k: v for k, v in os.environ.items()
                         if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))}
            clean_env["ANTHROPIC_API_KEY"] = api_key
            clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
            proc = await asyncio.create_subprocess_exec(
                str(ANALYSIS_PYTHON),
                str(script_path),
                request.ticker,
                date,
                str(REPO_ROOT),
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=clean_env,
            )
            app.state.processes[task_id] = proc
            # Drain stderr concurrently so a full stderr pipe can't deadlock
            # the stdout loop below.
            stderr_task = asyncio.create_task(proc.stderr.read())
            # BUG FIX: the previous monitor called os.read() on
            # proc.stdout.fileno(), but asyncio's StreamReader has no fileno()
            # (AttributeError at runtime) and it would also have raced
            # proc.communicate() for the same pipe. Read lines directly from
            # the StreamReader instead and react to STAGE markers as they arrive.
            stdout_lines: list[str] = []
            while True:
                raw = await proc.stdout.readline()
                if not raw:
                    break  # EOF: subprocess closed stdout
                line = raw.decode(errors="replace").rstrip("\n")
                stdout_lines.append(line)
                if line.startswith("STAGE:"):
                    _update_task_stage(line.split(":", 1)[1].strip())
                    await broadcast_progress(task_id, app.state.task_results[task_id])
            stderr_bytes = await stderr_task
            await proc.wait()
            # Clean up script file
            try:
                script_path.unlink()
            except Exception:
                pass
            state = app.state.task_results[task_id]
            if proc.returncode == 0:
                # Final decision is the last ANALYSIS_COMPLETE marker, HOLD by default.
                decision = "HOLD"
                for line in stdout_lines:
                    if line.startswith("ANALYSIS_COMPLETE:"):
                        decision = line.split(":", 1)[1].strip()
                state["status"] = "completed"
                state["progress"] = 100
                state["decision"] = decision
                state["current_stage"] = "portfolio"
                now = datetime.now().strftime("%H:%M:%S")
                for stage in state["stages"]:
                    stage["status"] = "completed"
                    if not stage.get("completed_at"):
                        stage["completed_at"] = now
            else:
                state["status"] = "failed"
                state["error"] = stderr_bytes.decode()[-1000:] if stderr_bytes else "Unknown error"
            _save_task_status(task_id, state)
        except Exception as e:
            app.state.task_results[task_id]["status"] = "failed"
            app.state.task_results[task_id]["error"] = str(e)
            try:
                script_path.unlink()
            except Exception:
                pass
            _save_task_status(task_id, app.state.task_results[task_id])
        await broadcast_progress(task_id, app.state.task_results[task_id])

    task = asyncio.create_task(run_analysis())
    app.state.analysis_tasks[task_id] = task
    return {
        "task_id": task_id,
        "ticker": request.ticker,
        "date": date,
        "status": "running",
    }
@app.get("/api/analysis/status/{task_id}")
async def get_task_status(task_id: str):
    """Return the full state dict for one task, or 404 when unknown."""
    state = app.state.task_results.get(task_id)
    if state is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return state
@app.get("/api/analysis/tasks")
async def list_tasks():
    """List all known tasks (active and recent), most recent first."""
    tasks = [
        {
            "task_id": task_id,
            "ticker": state.get("ticker"),
            "date": state.get("date"),
            "status": state.get("status"),
            "progress": state.get("progress", 0),
            "decision": state.get("decision"),
            "error": state.get("error"),
            "created_at": state.get("created_at"),
        }
        for task_id, state in app.state.task_results.items()
    ]
    # Newest first; tasks without a timestamp sort last.
    ordered = sorted(tasks, key=lambda t: t.get("created_at") or "", reverse=True)
    return {"tasks": ordered, "total": len(ordered)}
@app.delete("/api/analysis/cancel/{task_id}")
async def cancel_task(task_id: str):
    """Cancel a running task: kill its subprocess, cancel its asyncio task, drop state."""
    if task_id not in app.state.task_results:
        raise HTTPException(status_code=404, detail="Task not found")
    # 'processes' is created lazily by start_analysis; guard against it not
    # existing yet (e.g. cancelling a task restored from disk after a restart
    # previously raised AttributeError here).
    processes = getattr(app.state, "processes", {})
    proc = processes.get(task_id)
    if proc and proc.returncode is None:
        try:
            proc.kill()
        except Exception:
            pass
    # Cancel the asyncio task driving the subprocess.
    task = app.state.analysis_tasks.get(task_id)
    if task:
        task.cancel()
    app.state.task_results[task_id]["status"] = "failed"
    app.state.task_results[task_id]["error"] = "用户取消"
    _save_task_status(task_id, app.state.task_results[task_id])
    await broadcast_progress(task_id, app.state.task_results[task_id])
    # Clean up temp scripts. mkstemp honours TMPDIR, so search
    # tempfile.gettempdir() rather than a hard-coded /tmp.
    for p in Path(tempfile.gettempdir()).glob(f"analysis_{task_id}_*.py"):
        try:
            p.unlink()
        except Exception:
            pass
    # Remove persisted task state
    _delete_task_status(task_id)
    return {"task_id": task_id, "status": "cancelled"}
# ============== WebSocket ==============
@app.websocket("/ws/analysis/{task_id}")
async def websocket_analysis(websocket: WebSocket, task_id: str):
    """WebSocket endpoint streaming analysis progress; answers ping with pong."""
    await websocket.accept()
    connections = app.state.active_connections.setdefault(task_id, [])
    connections.append(websocket)
    try:
        # Send the current snapshot immediately so late subscribers catch up.
        if task_id in app.state.task_results:
            await websocket.send_text(json.dumps({
                "type": "progress",
                **app.state.task_results[task_id]
            }))
        while True:
            data = await websocket.receive_text()
            message = json.loads(data)
            if message.get("type") == "ping":
                await websocket.send_text(json.dumps({"type": "pong"}))
    except WebSocketDisconnect:
        pass
    finally:
        # BUG FIX: deregister on ANY exit path (the old code only cleaned up on
        # WebSocketDisconnect, so e.g. a malformed JSON message leaked the
        # connection and broadcast_progress kept writing to a dead socket).
        if websocket in connections:
            connections.remove(websocket)
async def broadcast_progress(task_id: str, progress: dict):
    """Send a progress update to every WebSocket subscribed to *task_id*."""
    connections = app.state.active_connections.get(task_id)
    if not connections:
        return
    payload = json.dumps({"type": "progress", **progress})
    stale = []
    for ws in connections:
        try:
            await ws.send_text(payload)
        except Exception:
            stale.append(ws)
    # Drop sockets that failed to receive the message.
    for ws in stale:
        connections.remove(ws)
# ============== Reports ==============
def get_results_dir() -> Path:
    """Return the repo-level results directory."""
    repo_root = Path(__file__).parent.parent.parent
    return repo_root / "results"
def get_reports_list():
    """Scan the results directory and return all reports, newest date first."""
    results_dir = get_results_dir()
    if not results_dir.exists():
        return []
    reports = []
    for ticker_dir in results_dir.iterdir():
        # Skip plain files and the strategy-log directory.
        if not ticker_dir.is_dir() or ticker_dir.name == "TradingAgentsStrategy_logs":
            continue
        for date_dir in ticker_dir.iterdir():
            # Keep only directories whose name looks like a date ("20…").
            if date_dir.is_dir() and date_dir.name.startswith("20"):
                reports.append({
                    "ticker": ticker_dir.name,
                    "date": date_dir.name,
                    "path": str(date_dir),
                })
    return sorted(reports, key=lambda r: r["date"], reverse=True)
def get_report_content(ticker: str, date: str) -> Optional[dict]:
    """Load report files for (ticker, date); None when missing or the path is unsafe."""
    # Reject path-traversal characters outright.
    for value in (ticker, date):
        if ".." in value or "/" in value or "\\" in value:
            return None
    report_dir = get_results_dir() / ticker / date
    # Defense in depth: the resolved path must remain inside the results dir.
    try:
        report_dir.resolve().relative_to(get_results_dir().resolve())
    except ValueError:
        return None
    if not report_dir.exists():
        return None
    content = {}
    complete_report = report_dir / "complete_report.md"
    if complete_report.exists():
        content["report"] = complete_report.read_text()
    # Collect per-stage markdown files keyed by file name.
    for stage in ("1_analysts", "2_research", "3_trading", "4_risk", "5_portfolio"):
        stage_dir = report_dir / "reports" / stage
        if stage_dir.exists():
            for md_file in stage_dir.glob("*.md"):
                content[md_file.name] = md_file.read_text()
    return content
@app.get("/api/reports/list")
async def list_reports():
    """Return every historical report found on disk (newest first)."""
    return get_reports_list()
@app.get("/api/reports/{ticker}/{date}")
async def get_report(ticker: str, date: str):
    """Return report content for one ticker/date, or 404 when absent."""
    report = get_report_content(ticker, date)
    if report:
        return report
    raise HTTPException(status_code=404, detail="Report not found")
# ============== Report Export ==============
import csv
import io
import re
from fpdf import FPDF
def _extract_decision(markdown_text: str) -> str:
"""Extract BUY/SELL/HOLD from markdown bold text."""
match = re.search(r'\*\*(BUY|SELL|HOLD)\*\*', markdown_text)
return match.group(1) if match else 'UNKNOWN'
def _extract_summary(markdown_text: str) -> str:
"""Extract first ~200 chars after '## 分析摘要'."""
match = re.search(r'## 分析摘要\s*\n+(.{0,300}?)(?=\n##|\Z)', markdown_text, re.DOTALL)
if match:
text = match.group(1).strip()
# Strip markdown formatting
text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
text = re.sub(r'\*(.*?)\*', r'\1', text)
text = re.sub(r'[#\n]+', ' ', text)
return text[:200].strip()
return ''
@app.get("/api/reports/export")
async def export_reports_csv():
    """Export all reports as CSV with columns ticker,date,decision,summary."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=["ticker", "date", "decision", "summary"])
    writer.writeheader()
    for report in get_reports_list():
        content = get_report_content(report["ticker"], report["date"])
        report_md = content.get("report") if content else None
        row = {"ticker": report["ticker"], "date": report["date"]}
        if report_md:
            row["decision"] = _extract_decision(report_md)
            row["summary"] = _extract_summary(report_md)
        else:
            # Unreadable or empty report: emit a placeholder row.
            row["decision"] = "UNKNOWN"
            row["summary"] = ""
        writer.writerow(row)
    return Response(
        content=buffer.getvalue(),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=tradingagents_reports.csv"},
    )
@app.get("/api/reports/{ticker}/{date}/pdf")
async def export_report_pdf(ticker: str, date: str):
    """Export a single report as PDF.

    Renders a title, a color-coded decision line, the extracted summary and the
    full report text with markdown markers stripped. 404 when the report is
    missing.
    """
    content = get_report_content(ticker, date)
    if not content or not content.get("report"):
        raise HTTPException(status_code=404, detail="Report not found")
    markdown_text = content["report"]
    decision = _extract_decision(markdown_text)
    summary = _extract_summary(markdown_text)
    pdf = FPDF()
    pdf.set_auto_page_break(auto=True, margin=20)
    # Candidate Unicode font files across macOS and Linux layouts; the first
    # existing regular/bold pair is used.
    # NOTE(review): DejaVu carries no CJK glyphs, so the Chinese headings below
    # may not render even when the font is found; with the Helvetica fallback,
    # non-Latin text can raise inside FPDF and only the body loop is guarded by
    # try/except — confirm against the installed fpdf/fpdf2 version (the
    # `ln=True`/`unicode=True` arguments match the legacy PyFPDF API).
    font_paths = [
        "/System/Library/Fonts/Supplemental/DejaVuSans.ttf",
        "/System/Library/Fonts/Supplemental/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "/usr/share/fonts/dejavu/DejaVuSans.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/dejavu/DejaVuSans-Bold.ttf",
        str(Path.home() / ".local/share/fonts/DejaVuSans.ttf"),
        str(Path.home() / ".fonts/DejaVuSans.ttf"),
    ]
    regular_font = None
    bold_font = None
    # First existing path wins for each of regular / bold.
    for p in font_paths:
        if Path(p).exists():
            if "Bold" in p and bold_font is None:
                bold_font = p
            elif regular_font is None and "Bold" not in p:
                regular_font = p
    use_dejavu = bool(regular_font and bold_font)
    if use_dejavu:
        pdf.add_font("DejaVu", "", regular_font, unicode=True)
        pdf.add_font("DejaVu", "B", bold_font, unicode=True)
        font_regular = "DejaVu"
        font_bold = "DejaVu"
    else:
        # Fall back to a core font when no DejaVu pair was found.
        font_regular = "Helvetica"
        font_bold = "Helvetica"
    pdf.add_page()
    # Title
    pdf.set_font(font_bold, "B", 18)
    pdf.cell(0, 12, f"TradingAgents 分析报告", ln=True, align="C")
    pdf.ln(5)
    pdf.set_font(font_regular, "", 11)
    pdf.cell(0, 8, f"股票: {ticker} 日期: {date}", ln=True)
    pdf.ln(3)
    # Decision badge: green for BUY, red for SELL, amber otherwise.
    pdf.set_font(font_bold, "B", 14)
    if decision == "BUY":
        pdf.set_text_color(34, 197, 94)
    elif decision == "SELL":
        pdf.set_text_color(220, 38, 38)
    else:
        pdf.set_text_color(245, 158, 11)
    pdf.cell(0, 10, f"决策: {decision}", ln=True)
    pdf.set_text_color(0, 0, 0)
    pdf.ln(5)
    # Summary
    pdf.set_font(font_bold, "B", 12)
    pdf.cell(0, 8, "分析摘要", ln=True)
    pdf.set_font(font_regular, "", 10)
    pdf.multi_cell(0, 6, summary or "")
    pdf.ln(5)
    # Full report text (stripped of heavy markdown)
    pdf.set_font(font_bold, "B", 12)
    pdf.cell(0, 8, "完整报告", ln=True)
    pdf.set_font(font_regular, "", 9)
    # Strip bold/italic/heading markers per line and truncate very long lines.
    for line in markdown_text.splitlines():
        line = re.sub(r'\*\*(.*?)\*\*', r'\1', line)
        line = re.sub(r'\*(.*?)\*', r'\1', line)
        line = re.sub(r'#{1,6} ', '', line)
        line = line.strip()
        if not line:
            pdf.ln(2)
            continue
        if len(line) > 120:
            line = line[:120] + "..."
        try:
            # Best effort: skip lines the active font cannot encode.
            pdf.multi_cell(0, 5, line)
        except Exception:
            pass
    return Response(
        content=pdf.output(),
        media_type="application/pdf",
        headers={"Content-Disposition": f"attachment; filename={ticker}_{date}_report.pdf"},
    )
# ============== Portfolio ==============
import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from api.portfolio import (
get_watchlist, add_to_watchlist, remove_from_watchlist,
get_positions, add_position, remove_position,
get_accounts, create_account, delete_account,
get_recommendations, get_recommendation, save_recommendation,
RECOMMENDATIONS_DIR,
)
# --- Watchlist ---
@app.get("/api/portfolio/watchlist")
async def list_watchlist():
    """Return the user's watchlist."""
    return {"watchlist": get_watchlist()}
@app.post("/api/portfolio/watchlist")
async def create_watchlist_entry(body: dict):
    """Add a ticker to the watchlist; 400 on a missing field or duplicate."""
    try:
        ticker = body["ticker"]
    except KeyError:
        # A missing 'ticker' previously surfaced as an unhandled KeyError (500).
        raise HTTPException(status_code=400, detail="'ticker' is required")
    try:
        return add_to_watchlist(ticker, body.get("name", ticker))
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
@app.delete("/api/portfolio/watchlist/{ticker}")
async def delete_watchlist_entry(ticker: str):
    """Remove a ticker from the watchlist, or 404 when it is not present."""
    removed = remove_from_watchlist(ticker)
    if not removed:
        raise HTTPException(status_code=404, detail="Ticker not found in watchlist")
    return {"ok": True}
# --- Accounts ---
@app.get("/api/portfolio/accounts")
async def list_accounts():
    """Return the names of all portfolio accounts."""
    data = get_accounts()
    names = list(data.get("accounts", {}).keys())
    return {"accounts": names}
@app.post("/api/portfolio/accounts")
async def create_account_endpoint(body: dict):
    """Create a new account; 400 on a missing field or duplicate name."""
    try:
        account_name = body["account_name"]
    except KeyError:
        # A missing 'account_name' previously surfaced as an unhandled KeyError (500).
        raise HTTPException(status_code=400, detail="'account_name' is required")
    try:
        return create_account(account_name)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
@app.delete("/api/portfolio/accounts/{account_name}")
async def delete_account_endpoint(account_name: str):
    """Delete an account by name, or 404 when it does not exist."""
    if not delete_account(account_name):
        raise HTTPException(status_code=404, detail="Account not found")
    return {"ok": True}
# --- Positions ---
@app.get("/api/portfolio/positions")
async def list_positions(account: Optional[str] = Query(None)):
    """Return positions, optionally filtered to one account."""
    return {"positions": get_positions(account)}
@app.post("/api/portfolio/positions")
async def create_position(body: dict):
    """Add a position; any missing or invalid field maps to a 400."""
    try:
        return add_position(
            ticker=body["ticker"],
            shares=body["shares"],
            cost_price=body["cost_price"],
            purchase_date=body.get("purchase_date"),
            notes=body.get("notes", ""),
            account=body.get("account", "默认账户"),
        )
    except Exception as e:
        # KeyError (missing field) and validation errors alike become 400s.
        raise HTTPException(status_code=400, detail=str(e))
@app.delete("/api/portfolio/positions/{ticker}")
async def delete_position(ticker: str, position_id: Optional[str] = Query(None), account: Optional[str] = Query(None)):
    """Remove a position (optionally narrowed by id/account), or 404 when not found."""
    if remove_position(ticker, position_id or "", account):
        return {"ok": True}
    raise HTTPException(status_code=404, detail="Position not found")
@app.get("/api/portfolio/positions/export")
async def export_positions_csv(account: Optional[str] = Query(None)):
    """Export positions (optionally one account) as a CSV attachment."""
    fieldnames = ["ticker", "shares", "cost_price", "purchase_date", "notes", "account"]
    # csv/io are already imported at module level; the old duplicate local
    # imports are dropped.
    output = io.StringIO()
    writer = csv.DictWriter(output, fieldnames=fieldnames)
    writer.writeheader()
    for position in get_positions(account):
        # .get tolerates legacy rows missing optional fields (the old p[k]
        # raised KeyError and turned the whole export into a 500).
        writer.writerow({key: position.get(key, "") for key in fieldnames})
    return Response(
        content=output.getvalue(),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=positions.csv"},
    )
# --- Recommendations ---
@app.get("/api/portfolio/recommendations")
async def list_recommendations(date: Optional[str] = Query(None)):
    """Return recommendations, optionally filtered by date."""
    return {"recommendations": get_recommendations(date)}
@app.get("/api/portfolio/recommendations/{date}/{ticker}")
async def get_recommendation_endpoint(date: str, ticker: str):
    """Return a single recommendation, or 404 when absent."""
    recommendation = get_recommendation(date, ticker)
    if recommendation:
        return recommendation
    raise HTTPException(status_code=404, detail="Recommendation not found")
# --- Batch Analysis ---
@app.post("/api/portfolio/analyze")
async def start_portfolio_analysis():
    """
    Trigger batch analysis for all watchlist tickers.
    Runs serially, streaming progress via WebSocket (task_id prefixed with 'port_').
    """
    import uuid
    date = datetime.now().strftime("%Y-%m-%d")
    task_id = f"port_{date}_{uuid.uuid4().hex[:6]}"
    watchlist = get_watchlist()
    if not watchlist:
        raise HTTPException(status_code=400, detail="自选股为空,请先添加股票")
    # Validate the API key BEFORE registering task state (consistent with
    # /api/analysis/start); previously a missing key raised after the state was
    # stored, leaving an orphaned "running" task behind.
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set")
    total = len(watchlist)
    app.state.task_results[task_id] = {
        "task_id": task_id,
        "type": "portfolio",
        "status": "running",
        "total": total,
        "completed": 0,
        "failed": 0,
        "current_ticker": None,
        "results": [],
        "error": None,
    }
    await broadcast_progress(task_id, app.state.task_results[task_id])
    # Ensure the process registry exists even when no single analysis ran first
    # (it was previously created only by /api/analysis/start, so indexing it
    # below could raise AttributeError).
    app.state.processes = getattr(app.state, 'processes', {})

    async def run_portfolio_analysis():
        """Serially analyze each watchlist ticker with bounded retries."""

        async def run_single_analysis(ticker: str, stock: dict) -> tuple[bool, str, dict | None]:
            """Run analysis for one ticker. Returns (success, decision, rec_or_None)."""
            last_error = None  # kept for future surfacing/logging of the failure reason
            for attempt in range(MAX_RETRY_COUNT + 1):
                script_path = None
                try:
                    fd, script_path_str = tempfile.mkstemp(
                        suffix=".py", prefix=f"analysis_{task_id}_{stock['_idx']}_")
                    script_path = Path(script_path_str)
                    os.chmod(script_path, 0o600)
                    with os.fdopen(fd, "w") as f:
                        f.write(ANALYSIS_SCRIPT_TEMPLATE)
                    # Clean environment: drop PYTHON*/CONDA*/VIRTUAL* variables.
                    clean_env = {k: v for k, v in os.environ.items()
                                 if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))}
                    clean_env["ANTHROPIC_API_KEY"] = api_key
                    clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
                    proc = await asyncio.create_subprocess_exec(
                        str(ANALYSIS_PYTHON), str(script_path), ticker, date, str(REPO_ROOT),
                        stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
                        env=clean_env,
                    )
                    app.state.processes[task_id] = proc
                    stdout, stderr = await proc.communicate()
                    if proc.returncode == 0:
                        decision = "HOLD"
                        for line in stdout.decode().splitlines():
                            if line.startswith("ANALYSIS_COMPLETE:"):
                                decision = line.split(":", 1)[1].strip()
                        rec = {
                            "ticker": ticker,
                            "name": stock.get("name", ticker),
                            "analysis_date": date,
                            "decision": decision,
                            "created_at": datetime.now().isoformat(),
                        }
                        save_recommendation(date, ticker, rec)
                        return True, decision, rec
                    last_error = stderr.decode()[-500:] if stderr else f"exit {proc.returncode}"
                except Exception as e:
                    last_error = str(e)
                finally:
                    # Single cleanup point for the temp script (the old code
                    # unlinked both inline and again here).
                    if script_path:
                        try:
                            script_path.unlink()
                        except Exception:
                            pass
                if attempt < MAX_RETRY_COUNT:
                    # True exponential backoff (1s, 2s, ...). The previous
                    # expression RETRY_BASE_DELAY_SECS ** attempt is constant 1
                    # for base 1, contradicting its own comment.
                    await asyncio.sleep(RETRY_BASE_DELAY_SECS * (2 ** attempt))
            return False, "HOLD", None

        try:
            for idx, stock in enumerate(watchlist):
                stock["_idx"] = idx  # used in the temp-file name
                ticker = stock["ticker"]
                state = app.state.task_results[task_id]
                state["current_ticker"] = ticker
                state["status"] = "running"
                state["completed"] = idx
                await broadcast_progress(task_id, state)
                success, decision, rec = await run_single_analysis(ticker, stock)
                if success:
                    state["completed"] = idx + 1
                    state["results"].append(rec)
                else:
                    state["failed"] += 1
                await broadcast_progress(task_id, state)
            app.state.task_results[task_id]["status"] = "completed"
            app.state.task_results[task_id]["current_ticker"] = None
            _save_task_status(task_id, app.state.task_results[task_id])
        except Exception as e:
            app.state.task_results[task_id]["status"] = "failed"
            app.state.task_results[task_id]["error"] = str(e)
            _save_task_status(task_id, app.state.task_results[task_id])
        await broadcast_progress(task_id, app.state.task_results[task_id])

    task = asyncio.create_task(run_portfolio_analysis())
    app.state.analysis_tasks[task_id] = task
    return {
        "task_id": task_id,
        "total": total,
        "status": "running",
    }
@app.get("/")
async def root():
    """Health/info endpoint."""
    return {"message": "TradingAgents Web Dashboard API", "version": "0.1.0"}
if __name__ == "__main__":
    # Dev entry point; in production run via uvicorn directly.
    import uvicorn
    # Run with: cd web_dashboard && ../env312/bin/python -m uvicorn main:app --reload
    # Or: cd web_dashboard/backend && python3 main.py (requires env312 in PATH)
    uvicorn.run(app, host="0.0.0.0", port=8000)