# TradingAgents/web_dashboard/backend/main.py
# (viewer metadata — 1041 lines, 35 KiB, Python — kept as a comment so the
# pasted header does not break the module)

"""
TradingAgents Web Dashboard Backend
FastAPI REST API + WebSocket for real-time analysis progress
"""
import asyncio
import hmac
import json
import os
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Optional
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect, Query, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from tradingagents.default_config import get_default_config, normalize_runtime_llm_config
from services import (
AnalysisService,
JobService,
ResultStore,
TaskCommandService,
TaskQueryService,
build_request_context,
clone_request_context,
load_migration_flags,
)
from services.executor import LegacySubprocessAnalysisExecutor
# Path to TradingAgents repo root
REPO_ROOT = Path(__file__).parent.parent.parent
# TRADINGAGENTS_ENV_FILE selects an alternate dotenv file; setting it to the
# EMPTY STRING disables dotenv loading entirely, while leaving it unset falls
# back to the repo-root .env.
_env_file = os.environ.get("TRADINGAGENTS_ENV_FILE")
if _env_file != "":
    load_dotenv(Path(_env_file) if _env_file else REPO_ROOT / ".env", override=True)
# Use the currently running Python interpreter
ANALYSIS_PYTHON = Path(sys.executable)
# Task state persistence directory
TASK_STATUS_DIR = Path(__file__).parent / "data" / "task_status"
# Persisted settings/API keys written by _persist_analysis_api_key
CONFIG_PATH = Path(__file__).parent / "data" / "config.json"
# ============== Lifespan ==============
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Wire all services onto ``app.state`` at startup.

    State created here lives for the whole application lifetime; persisted
    task state is restored from disk before the app starts serving. There is
    no explicit teardown after ``yield`` — per-task cleanup is handled by the
    services themselves.
    """
    # In-memory runtime state, all keyed by task_id.
    app.state.active_connections: dict[str, list[WebSocket]] = {}
    app.state.task_results: dict[str, dict] = {}
    app.state.analysis_tasks: dict[str, asyncio.Task] = {}
    app.state.processes: dict[str, asyncio.subprocess.Process | None] = {}
    app.state.migration_flags = load_migration_flags()
    # ResultStore persists task status and bridges to the legacy portfolio layer.
    portfolio_gateway = create_legacy_portfolio_gateway()
    app.state.result_store = ResultStore(TASK_STATUS_DIR, portfolio_gateway)
    app.state.job_service = JobService(
        task_results=app.state.task_results,
        analysis_tasks=app.state.analysis_tasks,
        processes=app.state.processes,
        persist_task=app.state.result_store.save_task_status,
        delete_task=app.state.result_store.delete_task_status,
    )
    # Analyses run in a subprocess driven by the current interpreter.
    app.state.analysis_service = AnalysisService(
        executor=LegacySubprocessAnalysisExecutor(
            analysis_python=ANALYSIS_PYTHON,
            repo_root=REPO_ROOT,
            api_key_resolver=_get_analysis_provider_api_key,
            process_registry=app.state.job_service.register_process,
        ),
        result_store=app.state.result_store,
        job_service=app.state.job_service,
        retry_count=MAX_RETRY_COUNT,
        retry_base_delay_secs=RETRY_BASE_DELAY_SECS,
    )
    app.state.task_query_service = TaskQueryService(
        task_results=app.state.task_results,
        result_store=app.state.result_store,
        job_service=app.state.job_service,
    )
    app.state.task_command_service = TaskCommandService(
        task_results=app.state.task_results,
        analysis_tasks=app.state.analysis_tasks,
        processes=app.state.processes,
        result_store=app.state.result_store,
        job_service=app.state.job_service,
    )
    # Restore persisted task states from disk
    app.state.job_service.restore_task_results(app.state.result_store.restore_task_results())
    yield
# ============== App ==============
# FastAPI application; all runtime services are attached in `lifespan`.
app = FastAPI(
    title="TradingAgents Web Dashboard API",
    version="0.1.0",
    lifespan=lifespan
)
# CORS: allow all if CORS_ORIGINS is not set (development), otherwise comma-separated list
_cors_origins = os.environ.get("CORS_ORIGINS", "*")
_cors_origins_list = ["*"] if _cors_origins == "*" else [o.strip() for o in _cors_origins.split(",")]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins_list,
    allow_methods=["*"],
    allow_headers=["*"],
    # NOTE(review): allow_credentials is intentionally not enabled — wildcard
    # origins plus credentials would be rejected by browsers; confirm before changing.
)
# ============== Pydantic Models ==============
class AnalysisRequest(BaseModel):
    """Request body for POST /api/analysis/start."""
    ticker: str  # stock symbol to analyze
    date: Optional[str] = None  # YYYY-MM-DD; defaults to today when omitted
    portfolio_context: Optional[str] = None  # caller-supplied portfolio notes, forwarded in metadata
    peer_context: Optional[str] = None  # caller-supplied peer-comparison context
    peer_context_mode: Optional[str] = None  # defaults to CALLER_PROVIDED when peer_context is set
class ScreenRequest(BaseModel):
    """Request body for SEPA screening.

    NOTE(review): not referenced by the visible endpoints — /api/stocks/screen
    reads ``mode`` from a query parameter instead; confirm before removing.
    """
    mode: str = "china_strict"  # screening mode forwarded to sepa_screener
# ============== Config Commands (Tauri IPC) ==============
@app.get("/api/config/check")
async def check_config():
    """Check if the analysis provider is configured with a callable API key."""
    settings = _resolve_analysis_runtime_settings()
    return {"configured": bool(settings.get("provider_api_key"))}
@app.post("/api/config/apikey")
async def save_apikey(request: Request, body: dict = None, api_key: Optional[str] = Header(None)):
    """Persist an analysis-provider API key for local desktop/backend use.

    When a dashboard API key is configured the caller must present it;
    otherwise the endpoint is restricted to localhost so remote clients
    cannot plant keys. Returns {"ok": True, "saved": True} on success;
    400 for a missing/empty/non-string key, 500 on persistence failure.
    """
    if _get_api_key():
        if not _check_api_key(api_key):
            _auth_error()
    elif not _is_local_request(request):
        raise HTTPException(status_code=403, detail="API key setup is only allowed from localhost")
    raw_key = (body or {}).get("api_key")
    if not isinstance(raw_key, str):
        # Covers both a missing field and a non-string value (previously a
        # non-string raised AttributeError on .strip() and surfaced as a 500).
        raise HTTPException(status_code=400, detail="api_key is required")
    apikey = raw_key.strip()
    if not apikey:
        raise HTTPException(status_code=400, detail="api_key cannot be empty")
    try:
        runtime_provider = _resolve_analysis_runtime_settings().get("llm_provider", "anthropic")
        _persist_analysis_api_key(apikey, provider=str(runtime_provider).lower())
        return {"ok": True, "saved": True}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save API key: {e}") from e
# ============== Cache Helpers ==============
CACHE_DIR = Path(__file__).parent.parent / "cache"  # screening-result cache location
CACHE_TTL_SECONDS = 300  # 5 minutes
MAX_RETRY_COUNT = 2  # analysis retry budget handed to AnalysisService
RETRY_BASE_DELAY_SECS = 1  # base delay for AnalysisService retry backoff
MAX_CONCURRENT_YFINANCE = 5  # NOTE(review): not referenced in this file — confirm external use before removing
# Pagination defaults
DEFAULT_PAGE_SIZE = 50
MAX_PAGE_SIZE = 500
# Auth — set DASHBOARD_API_KEY env var to enable API key authentication
_api_key: Optional[str] = None  # cached key; None means "not yet loaded"
def _get_api_key() -> Optional[str]:
    """Return the dashboard API key, caching the first successful env lookup."""
    global _api_key
    if _api_key is not None:
        return _api_key
    _api_key = os.environ.get("DASHBOARD_API_KEY")
    return _api_key
def _check_api_key(api_key: Optional[str]) -> bool:
    """True when auth is disabled, or the supplied key matches the configured one."""
    required = _get_api_key()
    if not required:
        # No DASHBOARD_API_KEY configured -> authentication is off.
        return True
    # compare_digest gives a constant-time comparison (avoids timing leaks).
    return bool(api_key) and hmac.compare_digest(api_key, required)
def _auth_error():
    """Abort the request with HTTP 401 (always raises; never returns)."""
    raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required")
def _load_saved_config() -> dict:
    """Best-effort read of the persisted config JSON; {} on any failure."""
    try:
        if not CONFIG_PATH.exists():
            return {}
        return json.loads(CONFIG_PATH.read_text())
    except Exception:
        # A corrupt or unreadable config degrades to "nothing saved".
        return {}
def _persist_analysis_api_key(api_key_value: str, *, provider: str):
    """Persist a provider API key into CONFIG_PATH with 0600 permissions.

    Keys are stored per-provider under "api_keys"; the flat "api_key" /
    "api_key_provider" fields are kept for legacy readers. The cached
    dashboard auth key is invalidated so the next request re-reads it.
    """
    global _api_key
    existing = _load_saved_config()
    api_keys = dict(existing.get("api_keys") or {})
    api_keys[provider] = api_key_value
    payload = dict(existing)
    payload["api_keys"] = api_keys
    payload["api_key_provider"] = provider
    payload["api_key"] = api_key_value
    CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
    # Create/truncate with 0o600 up front so the secret is never world-readable
    # (the old write_text-then-chmod sequence left a brief exposure window),
    # then chmod to repair permissions of any pre-existing file.
    fd = os.open(CONFIG_PATH, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as fh:
        fh.write(json.dumps(payload, ensure_ascii=False))
    os.chmod(CONFIG_PATH, 0o600)
    _api_key = None
def _get_analysis_provider_api_key(provider: str, saved_config: Optional[dict] = None) -> Optional[str]:
    """Resolve the LLM provider API key: environment first, then saved config.

    Checks the provider's known env vars in order, then the per-provider
    "api_keys" mapping of the saved config, then the legacy flat
    "api_key_provider"/"api_key" fields. Returns None when nothing matches.
    """
    key = provider.lower()
    env_candidates = {
        "anthropic": ("ANTHROPIC_API_KEY", "MINIMAX_API_KEY"),
        "openai": ("OPENAI_API_KEY",),
        "openrouter": ("OPENROUTER_API_KEY",),
        "xai": ("XAI_API_KEY",),
        "google": ("GOOGLE_API_KEY",),
        "ollama": (),  # local runtime needs no key
    }
    for env_name in env_candidates.get(key, ()):
        candidate = os.environ.get(env_name)
        if candidate:
            return candidate
    saved = dict(saved_config or {})
    per_provider = saved.get("api_keys")
    if isinstance(per_provider, dict) and per_provider.get(key):
        return per_provider[key]
    # Legacy single-key layout: {"api_key_provider": ..., "api_key": ...}
    if str(saved.get("api_key_provider") or "").lower() == key and saved.get("api_key"):
        return saved["api_key"]
    return None
def _resolve_analysis_runtime_settings() -> dict:
    """Resolve effective LLM runtime settings from env vars, saved config, and defaults.

    Precedence per field: explicit TRADINGAGENTS_* env var, then provider
    base-URL hints (ANTHROPIC_BASE_URL / OPENAI_BASE_URL), then package
    defaults from get_default_config(). The merged dict is normalized via
    normalize_runtime_llm_config() before being returned.
    """
    saved = _load_saved_config()
    defaults = get_default_config()
    provider = os.environ.get("TRADINGAGENTS_LLM_PROVIDER")
    if not provider:
        # Infer the provider from whichever base-URL override is present.
        if os.environ.get("ANTHROPIC_BASE_URL"):
            provider = "anthropic"
        elif os.environ.get("OPENAI_BASE_URL"):
            provider = "openai"
        else:
            provider = defaults.get("llm_provider", "anthropic")
    backend_url = (
        os.environ.get("TRADINGAGENTS_BACKEND_URL")
        or os.environ.get("ANTHROPIC_BASE_URL")
        or os.environ.get("OPENAI_BASE_URL")
        or defaults.get("backend_url")
    )
    # TRADINGAGENTS_MODEL acts as a shared fallback for both model tiers.
    deep_model = (
        os.environ.get("TRADINGAGENTS_DEEP_MODEL")
        or os.environ.get("TRADINGAGENTS_MODEL")
        or defaults.get("deep_think_llm")
    )
    quick_model = (
        os.environ.get("TRADINGAGENTS_QUICK_MODEL")
        or os.environ.get("TRADINGAGENTS_MODEL")
        or defaults.get("quick_think_llm")
    )
    # Comma-separated analyst list; empty items are dropped.
    selected_analysts_raw = os.environ.get("TRADINGAGENTS_SELECTED_ANALYSTS", "market")
    selected_analysts = [item.strip() for item in selected_analysts_raw.split(",") if item.strip()]
    analysis_prompt_style = os.environ.get("TRADINGAGENTS_ANALYSIS_PROMPT_STYLE", "compact")
    llm_timeout = float(
        os.environ.get(
            "TRADINGAGENTS_LLM_TIMEOUT",
            str(defaults.get("llm_timeout", 45)),
        )
    )
    llm_max_retries = int(
        os.environ.get(
            "TRADINGAGENTS_LLM_MAX_RETRIES",
            str(defaults.get("llm_max_retries", 0)),
        )
    )
    settings = {
        "llm_provider": provider,
        "backend_url": backend_url,
        "deep_think_llm": deep_model,
        "quick_think_llm": quick_model,
        "selected_analysts": selected_analysts,
        "analysis_prompt_style": analysis_prompt_style,
        "llm_timeout": llm_timeout,
        "llm_max_retries": llm_max_retries,
        "provider_api_key": _get_analysis_provider_api_key(provider, saved),
    }
    return normalize_runtime_llm_config(settings)
def _build_analysis_request_context(request: Request, auth_key: Optional[str]):
    """Assemble the request context handed to analysis runs.

    Forwards the resolved runtime LLM settings to build_request_context and
    derives watchdog/timeout metadata as multiples of the LLM timeout, each
    clamped to a minimum floor.
    """
    settings = _resolve_analysis_runtime_settings()
    return build_request_context(
        request,
        auth_key=auth_key,
        provider_api_key=settings["provider_api_key"],
        llm_provider=settings["llm_provider"],
        backend_url=settings["backend_url"],
        deep_think_llm=settings["deep_think_llm"],
        quick_think_llm=settings["quick_think_llm"],
        selected_analysts=settings["selected_analysts"],
        analysis_prompt_style=settings["analysis_prompt_style"],
        llm_timeout=settings["llm_timeout"],
        llm_max_retries=settings["llm_max_retries"],
        metadata={
            # Timeouts scale with the per-call LLM timeout but never drop below the floor.
            "stdout_timeout_secs": max(float(settings["llm_timeout"]) * 4.0, 120.0),
            "total_timeout_secs": max(float(settings["llm_timeout"]) * 12.0, 900.0),
            "heartbeat_interval_secs": 10.0,
            "local_recovery_timeout_secs": max(float(settings["llm_timeout"]) * 2.5, 90.0),
            "provider_probe_timeout_secs": max(float(settings["llm_timeout"]) * 1.5, 60.0),
            "local_recovery_cost_cap": 1.0,
            "provider_probe_cost_cap": 1.0,
        },
    )
def _is_local_request(request: Request) -> bool:
    """True when the request originates from this machine (or the test client)."""
    if request.client is None:
        return False
    local_hosts = {"127.0.0.1", "::1", "localhost", "testclient"}
    return request.client.host in local_hosts
def _get_cache_path(mode: str) -> Path:
    """Cache file for one screening mode."""
    return CACHE_DIR / f"screen_{mode}.json"
def _load_from_cache(mode: str) -> Optional[dict]:
    """Return the cached screening payload, or None when absent, stale, or unreadable."""
    path = _get_cache_path(mode)
    if not path.exists():
        return None
    try:
        if time.time() - path.stat().st_mtime > CACHE_TTL_SECONDS:
            return None  # entry expired
        with open(path) as fh:
            return json.load(fh)
    except Exception:
        # Any I/O or parse problem is treated as a cache miss.
        return None
def _save_to_cache(mode: str, data: dict):
    """Best-effort write of a screening payload; failures are silently ignored."""
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        with open(_get_cache_path(mode), "w") as fh:
            json.dump(data, fh)
    except Exception:
        pass  # caching is purely an optimization
# ============== SEPA Screening ==============
def _run_sepa_screening(mode: str) -> dict:
    """Run SEPA screening synchronously (intended to be called from a worker thread).

    Returns a dict with keys "mode", "total_stocks", "passed", "results".
    """
    # Guard the insert so repeated calls don't grow sys.path without bound
    # (the old unconditional insert added a duplicate entry on every call).
    repo_root = str(REPO_ROOT)
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    from sepa_screener import screen_all, china_stocks
    results = screen_all(mode=mode, max_workers=5)
    return {
        "mode": mode,
        "total_stocks": len(china_stocks),
        "passed": len(results),
        "results": results,
    }
@app.get("/api/stocks/screen")
async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query(False), api_key: Optional[str] = Header(None)):
    """Screen stocks using SEPA criteria, serving a 5-minute cache unless refresh=true.

    Returns the screening payload plus a "cached" flag.
    """
    if not _check_api_key(api_key):
        _auth_error()
    if not refresh:
        cached = _load_from_cache(mode)
        if cached:
            return {**cached, "cached": True}
    # Screening is blocking, so run it in the default thread pool.
    # asyncio.to_thread replaces the deprecated get_event_loop()/run_in_executor
    # pattern inside a running coroutine.
    result = await asyncio.to_thread(_run_sepa_screening, mode)
    _save_to_cache(mode, result)
    return {**result, "cached": False}
# ============== Analysis Execution ==============
@app.post("/api/analysis/start")
async def start_analysis(
    payload: AnalysisRequest,
    http_request: Request,
    api_key: Optional[str] = Header(None),
):
    """Start a new analysis task.

    Builds a unique task id (ticker + timestamp + random suffix), resolves the
    runtime LLM settings into a request context, and hands off to the
    AnalysisService. Returns the service's task payload; 400 for validation
    errors, 500 for execution failures.
    """
    import uuid
    if not _check_api_key(api_key):
        _auth_error()
    task_id = f"{payload.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
    date = payload.date or datetime.now().strftime("%Y-%m-%d")
    request_context = _build_analysis_request_context(http_request, api_key)
    if payload.portfolio_context is not None or payload.peer_context is not None:
        # Caller-supplied context rides along in metadata; peer_context_mode
        # defaults to CALLER_PROVIDED when peer context arrives without a mode.
        request_context = clone_request_context(
            request_context,
            metadata_updates={
                "portfolio_context": payload.portfolio_context,
                "peer_context": payload.peer_context,
                "peer_context_mode": payload.peer_context_mode or ("CALLER_PROVIDED" if payload.peer_context else None),
            },
        )
    try:
        return await app.state.analysis_service.start_analysis(
            task_id=task_id,
            ticker=payload.ticker,
            date=date,
            request_context=request_context,
            broadcast_progress=broadcast_progress,
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except RuntimeError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.get("/api/analysis/status/{task_id}")
async def get_task_status(task_id: str, api_key: Optional[str] = Header(None)):
    """Return the public status payload for one task (404 when unknown)."""
    if not _check_api_key(api_key):
        _auth_error()
    result = app.state.task_query_service.public_task_payload(task_id)
    if result is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return result
@app.get("/api/analysis/tasks")
async def list_tasks(api_key: Optional[str] = Header(None)):
    """List summaries of all known tasks (active and recent)."""
    if not _check_api_key(api_key):
        _auth_error()
    summaries = app.state.task_query_service.list_task_summaries()
    return summaries
@app.delete("/api/analysis/cancel/{task_id}")
async def cancel_task(task_id: str, api_key: Optional[str] = Header(None)):
    """Cancel a running task; the terminal state is broadcast to WS subscribers."""
    if not _check_api_key(api_key):
        _auth_error()
    result = await app.state.task_command_service.cancel_task(
        task_id,
        broadcast_progress=broadcast_progress,
    )
    if result is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return result
# ============== WebSocket ==============
@app.websocket("/ws/analysis/{task_id}")
async def websocket_analysis(websocket: WebSocket, task_id: str):
    """WebSocket for real-time analysis progress. Auth via ?api_key= query param.

    The server pushes {"type": "progress", ...} frames; clients may send
    {"type": "ping"} keepalives and receive {"type": "pong"} replies.
    """
    # Reject before accepting so unauthenticated clients see close code 4001.
    ws_api_key = websocket.query_params.get("api_key")
    if not _check_api_key(ws_api_key):
        await websocket.close(code=4001, reason="Unauthorized")
        return
    await websocket.accept()
    connections = app.state.active_connections.setdefault(task_id, [])
    connections.append(websocket)
    try:
        # Send current state immediately so late subscribers catch up.
        if task_id in app.state.task_results:
            await websocket.send_text(json.dumps({
                "type": "progress",
                **_public_task_payload(task_id)
            }))
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                # Ignore malformed client frames instead of killing the socket
                # (previously this raised and left the connection registered).
                continue
            if message.get("type") == "ping":
                await websocket.send_text(json.dumps({"type": "pong"}))
    except WebSocketDisconnect:
        pass
    finally:
        # Always deregister — the old code only removed the connection on
        # WebSocketDisconnect, leaking dead sockets into broadcast_progress
        # when any other exception escaped.
        if websocket in app.state.active_connections.get(task_id, []):
            app.state.active_connections[task_id].remove(websocket)
async def broadcast_progress(task_id: str, progress: dict):
    """Push a progress frame to every WebSocket subscribed to task_id, pruning dead ones."""
    if task_id not in app.state.active_connections:
        return
    payload = _public_task_payload(task_id, state_override=progress)
    frame = json.dumps({"type": "progress", **payload})
    stale = []
    for ws in app.state.active_connections[task_id]:
        try:
            await ws.send_text(frame)
        except Exception:
            # Client vanished mid-send; drop it after the loop.
            stale.append(ws)
    for ws in stale:
        app.state.active_connections[task_id].remove(ws)
def _load_task_contract(task_id: str, state: Optional[dict] = None) -> Optional[dict]:
    """Delegate to TaskQueryService.load_task_contract with an optional state override."""
    return app.state.task_query_service.load_task_contract(task_id, state_override=state)
def _public_task_payload(task_id: str, state_override: Optional[dict] = None) -> dict:
    """Public-facing payload for a task; raises KeyError for unknown tasks."""
    result = app.state.task_query_service.public_task_payload(task_id, state_override=state_override)
    if result is None:
        raise KeyError(task_id)
    return result
def _public_task_summary(task_id: str, state_override: Optional[dict] = None) -> dict:
    """Short-form task summary; raises KeyError for unknown tasks."""
    result = app.state.task_query_service.public_task_summary(task_id, state_override=state_override)
    if result is None:
        raise KeyError(task_id)
    return result
# ============== Reports ==============
def get_results_dir() -> Path:
    """Directory where analysis results are written (repo_root/results)."""
    here = Path(__file__)
    return here.parent.parent.parent / "results"
def get_reports_list():
    """Collect historical reports as {ticker, date, path} dicts, newest first."""
    results_dir = get_results_dir()
    if not results_dir.exists():
        return []
    reports = []
    for ticker_dir in results_dir.iterdir():
        if not ticker_dir.is_dir() or ticker_dir.name == "TradingAgentsStrategy_logs":
            continue
        for date_dir in ticker_dir.iterdir():
            # Only date-like subdirectories (names starting with "20");
            # this skips non-date folders such as TradingAgentsStrategy_logs.
            if date_dir.is_dir() and date_dir.name.startswith("20"):
                reports.append({
                    "ticker": ticker_dir.name,
                    "date": date_dir.name,
                    "path": str(date_dir),
                })
    return sorted(reports, key=lambda item: item["date"], reverse=True)
def get_report_content(ticker: str, date: str) -> Optional[dict]:
    """Load all markdown sections of one report; None when missing or path is unsafe."""
    # Reject path-traversal attempts in either component.
    for component in (ticker, date):
        if ".." in component or "/" in component or "\\" in component:
            return None
    report_dir = get_results_dir() / ticker / date
    # Defense in depth: the resolved path must remain inside the results root.
    try:
        report_dir.resolve().relative_to(get_results_dir().resolve())
    except ValueError:
        return None
    if not report_dir.exists():
        return None
    content = {}
    complete = report_dir / "complete_report.md"
    if complete.exists():
        content["report"] = complete.read_text()
    for stage in ("1_analysts", "2_research", "3_trading", "4_risk", "5_portfolio"):
        stage_dir = report_dir / "reports" / stage
        if not stage_dir.exists():
            continue
        for md_file in stage_dir.glob("*.md"):
            content[md_file.name] = md_file.read_text()
    return content
@app.get("/api/reports/list")
async def list_reports(
    limit: int = Query(DEFAULT_PAGE_SIZE, ge=1, le=MAX_PAGE_SIZE),
    offset: int = Query(0, ge=0),
    api_key: Optional[str] = Header(None),
):
    """Paginated report index, newest first."""
    if not _check_api_key(api_key):
        _auth_error()
    # get_reports_list() already returns reports sorted by date descending,
    # so slicing directly avoids the redundant second sort the old code did.
    reports = get_reports_list()
    return {
        "reports": reports[offset : offset + limit],
        "total": len(reports),
        "limit": limit,
        "offset": offset,
    }
@app.get("/api/reports/{ticker}/{date}")
async def get_report(ticker: str, date: str, api_key: Optional[str] = Header(None)):
    """Return all markdown sections of one report (404 when absent or empty)."""
    if not _check_api_key(api_key):
        _auth_error()
    report = get_report_content(ticker, date)
    if not report:
        raise HTTPException(status_code=404, detail="Report not found")
    return report
# ============== Report Export ==============
import csv
import io
import re
from fpdf import FPDF
def _extract_decision(markdown_text: str) -> str:
    """Extract BUY/OVERWEIGHT/SELL/UNDERWEIGHT/HOLD from markdown bold text.

    Returns the first bolded decision keyword, or 'UNKNOWN' when none appears.
    """
    found = re.search(r'\*\*(BUY|SELL|HOLD|OVERWEIGHT|UNDERWEIGHT)\*\*', markdown_text)
    if found is None:
        return 'UNKNOWN'
    return found.group(1)
def _extract_summary(markdown_text: str) -> str:
    """Return up to 200 plain-text chars following the '## 分析摘要' heading ('' if absent)."""
    found = re.search(r'## 分析摘要\s*\n+(.{0,300}?)(?=\n##|\Z)', markdown_text, re.DOTALL)
    if not found:
        return ''
    text = found.group(1).strip()
    # Strip markdown emphasis, then collapse headings/newlines into spaces.
    for pattern, repl in ((r'\*\*(.*?)\*\*', r'\1'), (r'\*(.*?)\*', r'\1'), (r'[#\n]+', ' ')):
        text = re.sub(pattern, repl, text)
    return text[:200].strip()
@app.get("/api/reports/export")
async def export_reports_csv(
    api_key: Optional[str] = Header(None),
):
    """Export all reports as CSV: ticker,date,decision,summary."""
    if not _check_api_key(api_key):
        _auth_error()
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=["ticker", "date", "decision", "summary"])
    writer.writeheader()
    for report in get_reports_list():
        content = get_report_content(report["ticker"], report["date"])
        markdown = (content or {}).get("report")
        # Reports without a complete_report.md get placeholder columns.
        writer.writerow({
            "ticker": report["ticker"],
            "date": report["date"],
            "decision": _extract_decision(markdown) if markdown else "UNKNOWN",
            "summary": _extract_summary(markdown) if markdown else "",
        })
    return Response(
        content=buffer.getvalue(),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=tradingagents_reports.csv"},
    )
@app.get("/api/reports/{ticker}/{date}/pdf")
async def export_report_pdf(ticker: str, date: str, api_key: Optional[str] = Header(None)):
    """Export a single report as PDF.

    Renders a title, a color-coded decision line, the summary, and the full
    markdown body with emphasis/heading markers stripped. Falls back to
    Helvetica when no DejaVu font is found — NOTE(review): fpdf core fonts are
    Latin-1 only, so the CJK strings below may not render in that fallback;
    confirm on a host without DejaVu installed.
    """
    if not _check_api_key(api_key):
        _auth_error()
    content = get_report_content(ticker, date)
    if not content or not content.get("report"):
        raise HTTPException(status_code=404, detail="Report not found")
    markdown_text = content["report"]
    decision = _extract_decision(markdown_text)
    summary = _extract_summary(markdown_text)
    pdf = FPDF()
    pdf.set_auto_page_break(auto=True, margin=20)
    # Try multiple font paths for cross-platform support
    font_paths = [
        "/System/Library/Fonts/Supplemental/DejaVuSans.ttf",
        "/System/Library/Fonts/Supplemental/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "/usr/share/fonts/dejavu/DejaVuSans.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/dejavu/DejaVuSans-Bold.ttf",
        str(Path.home() / ".local/share/fonts/DejaVuSans.ttf"),
        str(Path.home() / ".fonts/DejaVuSans.ttf"),
    ]
    regular_font = None
    bold_font = None
    # First existing hit wins for each of regular/bold.
    for p in font_paths:
        if Path(p).exists():
            if "Bold" in p and bold_font is None:
                bold_font = p
            elif regular_font is None and "Bold" not in p:
                regular_font = p
    # Only use DejaVu when BOTH weights were found.
    use_dejavu = bool(regular_font and bold_font)
    if use_dejavu:
        # NOTE(review): the `unicode=True` kwarg is legacy PyFPDF-style; fpdf2
        # deprecates it — confirm which fpdf package is pinned.
        pdf.add_font("DejaVu", "", regular_font, unicode=True)
        pdf.add_font("DejaVu", "B", bold_font, unicode=True)
        font_regular = "DejaVu"
        font_bold = "DejaVu"
    else:
        font_regular = "Helvetica"
        font_bold = "Helvetica"
    pdf.add_page()
    # Title
    pdf.set_font(font_bold, "B", 18)
    pdf.cell(0, 12, f"TradingAgents 分析报告", ln=True, align="C")
    pdf.ln(5)
    # Ticker/date metadata line
    pdf.set_font(font_regular, "", 11)
    pdf.cell(0, 8, f"股票: {ticker} 日期: {date}", ln=True)
    pdf.ln(3)
    # Decision badge
    pdf.set_font(font_bold, "B", 14)
    if decision == "BUY":
        pdf.set_text_color(34, 197, 94)     # green
    elif decision == "OVERWEIGHT":
        pdf.set_text_color(134, 239, 172)   # light green
    elif decision == "SELL":
        pdf.set_text_color(220, 38, 38)     # red
    elif decision == "UNDERWEIGHT":
        pdf.set_text_color(252, 165, 165)   # light red
    else:
        pdf.set_text_color(245, 158, 11)    # amber (HOLD / UNKNOWN)
    pdf.cell(0, 10, f"决策: {decision}", ln=True)
    pdf.set_text_color(0, 0, 0)  # reset to black for body text
    pdf.ln(5)
    # Summary
    pdf.set_font(font_bold, "B", 12)
    pdf.cell(0, 8, "分析摘要", ln=True)
    pdf.set_font(font_regular, "", 10)
    pdf.multi_cell(0, 6, summary or "")
    pdf.ln(5)
    # Full report text (stripped of heavy markdown)
    pdf.set_font(font_bold, "B", 12)
    pdf.cell(0, 8, "完整报告", ln=True)
    pdf.set_font(font_regular, "", 9)
    # Split into lines, filter out very long lines
    for line in markdown_text.splitlines():
        # Strip bold, italics, and heading markers line by line.
        line = re.sub(r'\*\*(.*?)\*\*', r'\1', line)
        line = re.sub(r'\*(.*?)\*', r'\1', line)
        line = re.sub(r'#{1,6} ', '', line)
        line = line.strip()
        if not line:
            pdf.ln(2)
            continue
        if len(line) > 120:
            line = line[:120] + "..."
        try:
            pdf.multi_cell(0, 5, line)
        except Exception:
            # Skip lines fpdf can't render (e.g. glyphs missing from the font).
            pass
    return Response(
        content=pdf.output(),
        media_type="application/pdf",
        headers={"Content-Disposition": f"attachment; filename={ticker}_{date}_report.pdf"},
    )
# ============== Portfolio ==============
import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from api.portfolio import (
create_legacy_portfolio_gateway,
get_watchlist, add_to_watchlist, remove_from_watchlist,
get_positions, add_position, remove_position,
get_accounts, create_account, delete_account,
)
# --- Watchlist ---
@app.get("/api/portfolio/watchlist")
async def list_watchlist(api_key: Optional[str] = Header(None)):
    """Return the full watchlist."""
    if not _check_api_key(api_key):
        _auth_error()
    entries = get_watchlist()
    return {"watchlist": entries}
@app.post("/api/portfolio/watchlist")
async def create_watchlist_entry(body: dict, api_key: Optional[str] = Header(None)):
    """Add a ticker to the watchlist; 400 on missing ticker or duplicate."""
    if not _check_api_key(api_key):
        _auth_error()
    ticker = (body or {}).get("ticker")
    if not ticker:
        # Previously a missing "ticker" field raised KeyError -> opaque 500.
        raise HTTPException(status_code=400, detail="ticker is required")
    try:
        return add_to_watchlist(ticker, body.get("name", ticker))
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.delete("/api/portfolio/watchlist/{ticker}")
async def delete_watchlist_entry(ticker: str, api_key: Optional[str] = Header(None)):
    """Remove a ticker from the watchlist; 404 when it is not present."""
    if not _check_api_key(api_key):
        _auth_error()
    if not remove_from_watchlist(ticker):
        raise HTTPException(status_code=404, detail="Ticker not found in watchlist")
    return {"ok": True}
# --- Accounts ---
@app.get("/api/portfolio/accounts")
async def list_accounts(api_key: Optional[str] = Header(None)):
    """Return only the account names."""
    if not _check_api_key(api_key):
        _auth_error()
    data = get_accounts()
    names = list(data.get("accounts", {}).keys())
    return {"accounts": names}
@app.post("/api/portfolio/accounts")
async def create_account_endpoint(body: dict, api_key: Optional[str] = Header(None)):
    """Create a new account; 400 on missing name or duplicate."""
    if not _check_api_key(api_key):
        _auth_error()
    account_name = (body or {}).get("account_name")
    if not account_name:
        # Previously a missing "account_name" field raised KeyError -> opaque 500.
        raise HTTPException(status_code=400, detail="account_name is required")
    try:
        return create_account(account_name)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.delete("/api/portfolio/accounts/{account_name}")
async def delete_account_endpoint(account_name: str, api_key: Optional[str] = Header(None)):
    """Delete an account by name; 404 when it does not exist."""
    if not _check_api_key(api_key):
        _auth_error()
    if not delete_account(account_name):
        raise HTTPException(status_code=404, detail="Account not found")
    return {"ok": True}
# --- Positions ---
@app.get("/api/portfolio/positions")
async def list_positions(account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
    """List positions, optionally filtered to one account."""
    if not _check_api_key(api_key):
        _auth_error()
    # Migration flag: read through the new ResultStore when enabled.
    if app.state.migration_flags.use_result_store:
        positions = await app.state.result_store.get_positions(account)
    else:
        positions = await get_positions(account)
    return {"positions": positions}
@app.post("/api/portfolio/positions")
async def create_position(body: dict, api_key: Optional[str] = Header(None)):
    """Create a position. Requires ticker, shares and cost_price; 400 on bad input."""
    if not _check_api_key(api_key):
        _auth_error()
    # Validate up front so a missing field yields a clear 400 instead of the
    # cryptic str(KeyError) the old code produced.
    missing = [k for k in ("ticker", "shares", "cost_price") if k not in (body or {})]
    if missing:
        raise HTTPException(status_code=400, detail=f"missing required fields: {', '.join(missing)}")
    try:
        return add_position(
            ticker=body["ticker"],
            shares=body["shares"],
            cost_price=body["cost_price"],
            purchase_date=body.get("purchase_date"),
            notes=body.get("notes", ""),
            account=body.get("account", "默认账户"),
        )
    except (ValueError, TypeError) as e:
        # Narrowed from bare `except Exception` so genuine server bugs surface
        # as 500s instead of being masked as client errors.
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.delete("/api/portfolio/positions/{ticker}")
async def delete_position(ticker: str, position_id: Optional[str] = Query(None), account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
    """Delete a position by ticker, optionally disambiguated by position_id/account."""
    if not _check_api_key(api_key):
        _auth_error()
    if not remove_position(ticker, position_id or "", account):
        raise HTTPException(status_code=404, detail="Position not found")
    return {"ok": True}
@app.get("/api/portfolio/positions/export")
async def export_positions_csv(account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
    """Download positions as CSV, optionally scoped to one account."""
    if not _check_api_key(api_key):
        _auth_error()
    if app.state.migration_flags.use_result_store:
        positions = await app.state.result_store.get_positions(account)
    else:
        positions = await get_positions(account)
    # csv/io are already imported at module level; the old local re-imports were redundant.
    fields = ["ticker", "shares", "cost_price", "purchase_date", "notes", "account"]
    output = io.StringIO()
    writer = csv.DictWriter(output, fieldnames=fields)
    writer.writeheader()
    for p in positions:
        # .get() keeps the export working when a stored row lacks a field —
        # the old p[k] lookup turned one incomplete row into a 500 for the
        # whole export.
        writer.writerow({k: p.get(k, "") for k in fields})
    return Response(content=output.getvalue(), media_type="text/csv", headers={"Content-Disposition": "attachment; filename=positions.csv"})
# --- Recommendations ---
@app.get("/api/portfolio/recommendations")
async def list_recommendations(
    date: Optional[str] = Query(None),
    limit: int = Query(DEFAULT_PAGE_SIZE, ge=1, le=MAX_PAGE_SIZE),
    offset: int = Query(0, ge=0),
    api_key: Optional[str] = Header(None),
):
    """Paginated recommendation list, optionally scoped to one date."""
    if not _check_api_key(api_key):
        _auth_error()
    store = app.state.result_store
    return store.get_recommendations(date, limit, offset)
@app.get("/api/portfolio/recommendations/{date}/{ticker}")
async def get_recommendation_endpoint(date: str, ticker: str, api_key: Optional[str] = Header(None)):
    """Fetch one recommendation (404 when absent)."""
    if not _check_api_key(api_key):
        _auth_error()
    recommendation = app.state.result_store.get_recommendation(date, ticker)
    if not recommendation:
        raise HTTPException(status_code=404, detail="Recommendation not found")
    return recommendation
# --- Batch Analysis ---
@app.post("/api/portfolio/analyze")
async def start_portfolio_analysis(
    http_request: Request,
    api_key: Optional[str] = Header(None),
):
    """Trigger batch analysis for all watchlist tickers."""
    if not _check_api_key(api_key):
        _auth_error()
    import uuid
    today = datetime.now().strftime("%Y-%m-%d")
    task_id = f"port_{today}_{uuid.uuid4().hex[:6]}"
    context = _build_analysis_request_context(http_request, api_key)
    try:
        return await app.state.analysis_service.start_portfolio_analysis(
            task_id=task_id,
            date=today,
            request_context=context,
            broadcast_progress=broadcast_progress,
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except RuntimeError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.get("/")
async def root():
    """Serve the built React frontend when present; otherwise an API banner."""
    index_html = Path(__file__).parent.parent / "frontend" / "dist" / "index.html"
    if index_html.exists():
        return FileResponse(str(index_html))
    return {"message": "TradingAgents Web Dashboard API", "version": "0.1.0"}
@app.websocket("/ws/orchestrator")
async def ws_orchestrator(websocket: WebSocket, api_key: Optional[str] = None):
    """WebSocket endpoint for orchestrator live signals.

    The client sends {"tickers": [...], "date": ...} frames; the server replies
    with {"contract_version": "v1alpha1", "signals": [...]} per request.
    """
    # Auth check before accepting — reject unauthenticated connections.
    if not _check_api_key(api_key):
        await websocket.close(code=4401)
        return
    # Lazy imports keep orchestrator deps off the startup path. Guard the
    # sys.path insert so it no longer grows by one duplicate entry per
    # connection (the old local `import sys` shadowing the module-level one
    # is also gone).
    repo_root = str(REPO_ROOT)
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    from orchestrator.config import OrchestratorConfig
    from orchestrator.orchestrator import TradingOrchestrator
    from orchestrator.live_mode import LiveMode
    config = OrchestratorConfig(
        quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
    )
    live = LiveMode(TradingOrchestrator(config))
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            payload = json.loads(data)
            results = await live.run_once(payload.get("tickers", []), payload.get("date"))
            await websocket.send_text(json.dumps({
                "contract_version": "v1alpha1",
                "signals": results,
            }))
    except WebSocketDisconnect:
        pass
    except Exception as e:
        # Best-effort error report; the socket may already be closed.
        try:
            await websocket.send_text(json.dumps({"error": str(e)}))
        except Exception:
            pass
@app.get("/health")
async def health():
    """Liveness probe; always returns {"status": "ok"}."""
    return {"status": "ok"}
if __name__ == "__main__":
    import uvicorn
    # HOST/PORT env vars override the default 0.0.0.0:8000 bind.
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "8000"))
    # Production mode: serve the built React frontend assets when the build
    # output exists (index.html itself is served by the "/" route above).
    frontend_dist = Path(__file__).parent.parent / "frontend" / "dist"
    if frontend_dist.exists():
        app.mount("/assets", StaticFiles(directory=str(frontend_dist / "assets")), name="assets")
    uvicorn.run(app, host=host, port=port)