1041 lines
35 KiB
Python
1041 lines
35 KiB
Python
"""
|
|
TradingAgents Web Dashboard Backend
|
|
FastAPI REST API + WebSocket for real-time analysis progress
|
|
"""
|
|
import asyncio
|
|
import hmac
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
from contextlib import asynccontextmanager
|
|
|
|
from dotenv import load_dotenv
|
|
from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect, Query, Header
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import Response, FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
from tradingagents.default_config import get_default_config, normalize_runtime_llm_config
|
|
|
|
from services import (
|
|
AnalysisService,
|
|
JobService,
|
|
ResultStore,
|
|
TaskCommandService,
|
|
TaskQueryService,
|
|
build_request_context,
|
|
clone_request_context,
|
|
load_migration_flags,
|
|
)
|
|
from services.executor import LegacySubprocessAnalysisExecutor
|
|
|
|
# Repository root: this module sits three `.parent` steps below it.
REPO_ROOT = Path(__file__).parent.parent.parent

# Environment file handling (values from the file override the process env):
#   TRADINGAGENTS_ENV_FILE unset    -> load <repo root>/.env
#   TRADINGAGENTS_ENV_FILE="<path>" -> load that file instead
#   TRADINGAGENTS_ENV_FILE=""       -> skip .env loading entirely
_env_file = os.environ.get("TRADINGAGENTS_ENV_FILE")
if _env_file is None:
    load_dotenv(REPO_ROOT / ".env", override=True)
elif _env_file:
    load_dotenv(Path(_env_file), override=True)

# Analysis subprocesses reuse the interpreter running this server.
ANALYSIS_PYTHON = Path(sys.executable)

# On-disk state kept next to this module.
_DATA_DIR = Path(__file__).parent / "data"
TASK_STATUS_DIR = _DATA_DIR / "task_status"  # persisted task states
CONFIG_PATH = _DATA_DIR / "config.json"  # saved provider API keys / settings
|
|
|
|
|
|
# ============== Lifespan ==============
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared state and services when the application starts.

    Everything is attached to ``app.state`` so request handlers can reach it.
    Construction order matters: the result store comes first, the job service
    persists through it, and the analysis/query/command services share both.
    Task states saved by a previous run are restored before serving begins.

    NOTE(review): nothing runs after ``yield``, so this function does not cancel
    in-flight analysis tasks or subprocesses at shutdown — confirm that is intended.
    """
    state = app.state

    # In-memory registries, shared by reference with the services below.
    state.active_connections = {}  # task_id -> list of subscribed WebSockets
    state.task_results = {}  # task_id -> task state dict
    state.analysis_tasks = {}  # task_id -> asyncio.Task
    state.processes = {}  # task_id -> asyncio subprocess (or None)
    state.migration_flags = load_migration_flags()

    store = ResultStore(TASK_STATUS_DIR, create_legacy_portfolio_gateway())
    state.result_store = store

    jobs = JobService(
        task_results=state.task_results,
        analysis_tasks=state.analysis_tasks,
        processes=state.processes,
        persist_task=store.save_task_status,
        delete_task=store.delete_task_status,
    )
    state.job_service = jobs

    executor = LegacySubprocessAnalysisExecutor(
        analysis_python=ANALYSIS_PYTHON,
        repo_root=REPO_ROOT,
        api_key_resolver=_get_analysis_provider_api_key,
        process_registry=jobs.register_process,
    )
    state.analysis_service = AnalysisService(
        executor=executor,
        result_store=store,
        job_service=jobs,
        retry_count=MAX_RETRY_COUNT,
        retry_base_delay_secs=RETRY_BASE_DELAY_SECS,
    )
    state.task_query_service = TaskQueryService(
        task_results=state.task_results,
        result_store=store,
        job_service=jobs,
    )
    state.task_command_service = TaskCommandService(
        task_results=state.task_results,
        analysis_tasks=state.analysis_tasks,
        processes=state.processes,
        result_store=store,
        job_service=jobs,
    )

    # Bring back task states persisted on disk by an earlier run.
    jobs.restore_task_results(store.restore_task_results())

    yield
|
|
|
|
|
|
# ============== App ==============
|
|
|
|
app = FastAPI(
    title="TradingAgents Web Dashboard API",
    version="0.1.0",
    lifespan=lifespan,
)

# CORS: CORS_ORIGINS is a comma-separated list of allowed origins.  When it is
# unset (or exactly "*"), every origin is allowed — handy for local
# development; set it explicitly for deployments.
_cors_origins = os.environ.get("CORS_ORIGINS", "*")
if _cors_origins == "*":
    _cors_origins_list = ["*"]
else:
    _cors_origins_list = [origin.strip() for origin in _cors_origins.split(",")]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins_list,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# ============== Pydantic Models ==============
|
|
|
|
class AnalysisRequest(BaseModel):
    # JSON body of POST /api/analysis/start.
    # Documented with comments rather than a docstring so the generated
    # OpenAPI schema stays exactly as before.
    ticker: str  # symbol to analyse; also used as the prefix of the generated task_id
    date: Optional[str] = None  # analysis date; start_analysis uses today ("%Y-%m-%d") when omitted
    portfolio_context: Optional[str] = None  # copied into the request-context metadata when given
    peer_context: Optional[str] = None  # copied into the request-context metadata when given
    peer_context_mode: Optional[str] = None  # if unset and peer_context is non-empty, "CALLER_PROVIDED" is used
|
|
|
|
class ScreenRequest(BaseModel):
    # Screening options (a single `mode`, default "china_strict").
    # NOTE(review): GET /api/stocks/screen reads `mode` from the query string and
    # does not reference this model in the visible code — confirm it is still used.
    mode: str = "china_strict"
|
|
|
|
|
|
# ============== Config Commands (Tauri IPC) ==============
|
|
|
|
@app.get("/api/config/check")
async def check_config():
    """Report whether the analysis LLM provider has a usable API key.

    Only a boolean is returned (``{"configured": ...}``); the key is never echoed.
    """
    settings = _resolve_analysis_runtime_settings()
    has_key = bool(settings.get("provider_api_key"))
    return {"configured": has_key}
|
|
|
|
|
|
@app.post("/api/config/apikey")
async def save_apikey(request: Request, body: dict = None, api_key: Optional[str] = Header(None)):
    """Persist the LLM provider API key for local desktop/backend use.

    Access control:
      * if DASHBOARD_API_KEY is configured, the request must carry a matching
        dashboard key (FastAPI reads it from the ``api-key`` header, derived
        from the ``api_key`` parameter name);
      * otherwise the request must come from the local machine.

    Body: ``{"api_key": "<provider key>"}``.  The key is saved under the
    provider currently selected by the runtime settings.

    Raises:
        HTTPException(401/403): caller is not allowed to set the key.
        HTTPException(400): ``api_key`` missing, not a string, or blank.
        HTTPException(500): the key could not be written.
    """
    if _get_api_key():
        if not _check_api_key(api_key):
            _auth_error()
    elif not _is_local_request(request):
        raise HTTPException(status_code=403, detail="API key setup is only allowed from localhost")

    if not body or "api_key" not in body:
        raise HTTPException(status_code=400, detail="api_key is required")

    raw_key = body["api_key"]
    # A non-string value (number, list, null) used to crash in .strip() and
    # surface as an HTTP 500; it is a client error.
    if not isinstance(raw_key, str):
        raise HTTPException(status_code=400, detail="api_key must be a string")
    apikey = raw_key.strip()
    if not apikey:
        raise HTTPException(status_code=400, detail="api_key cannot be empty")

    try:
        runtime_provider = _resolve_analysis_runtime_settings().get("llm_provider", "anthropic")
        _persist_analysis_api_key(apikey, provider=str(runtime_provider).lower())
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save API key: {e}") from e
    return {"ok": True, "saved": True}
|
|
|
|
|
|
# ============== Cache Helpers ==============
|
|
|
|
# Screening-result cache: JSON files in ../cache (relative to this module's
# directory), considered fresh for five minutes.
CACHE_DIR = Path(__file__).parent.parent / "cache"
CACHE_TTL_SECONDS = 5 * 60

# Retry policy handed to AnalysisService in lifespan().
MAX_RETRY_COUNT = 2
RETRY_BASE_DELAY_SECS = 1

MAX_CONCURRENT_YFINANCE = 5  # NOTE(review): not referenced in the visible code

# Bounds for the `limit` query parameter of paginated list endpoints.
DEFAULT_PAGE_SIZE = 50
MAX_PAGE_SIZE = 500

# Auth: set the DASHBOARD_API_KEY env var to require an API key on endpoints.
# Cached by _get_api_key(); None means "not read yet, or not set".
_api_key: Optional[str] = None
|
|
|
|
def _get_api_key() -> Optional[str]:
    """Return the dashboard API key from DASHBOARD_API_KEY, caching it once read.

    While the variable is unset the cache stays None, so the environment is
    consulted again on the next call; once a value is cached it is reused.
    """
    global _api_key
    if _api_key is not None:
        return _api_key
    _api_key = os.environ.get("DASHBOARD_API_KEY")
    return _api_key
|
|
|
|
def _check_api_key(api_key: Optional[str]) -> bool:
    """Return True if no key is required, or if the provided key matches.

    The expected key comes from ``_get_api_key()`` (DASHBOARD_API_KEY).  The
    comparison is constant-time over UTF-8 bytes: ``hmac.compare_digest``
    raises TypeError when given ``str`` values containing non-ASCII
    characters, which a client could trigger with a crafted header and turn
    into an HTTP 500 instead of a 401.
    """
    required = _get_api_key()
    if not required:
        return True
    if not api_key:
        return False
    # surrogateescape keeps undecodable bytes from os.environ encodable.
    provided_bytes = api_key.encode("utf-8", "surrogateescape")
    required_bytes = required.encode("utf-8", "surrogateescape")
    return hmac.compare_digest(provided_bytes, required_bytes)
|
|
|
|
def _auth_error():
    """Abort the current request with HTTP 401; this never returns normally."""
    # NOTE(review): endpoints read the key via `Header(None)` on a parameter
    # named `api_key`, which FastAPI maps to an `api-key` request header rather
    # than the `X-API-Key` named in this message — confirm which name clients use.
    raise HTTPException(
        status_code=401,
        detail="Unauthorized: valid X-API-Key header required",
    )
|
|
|
|
|
|
def _load_saved_config() -> dict:
    """Return the saved dashboard config stored at CONFIG_PATH as a dict.

    Best effort: a missing, unreadable or malformed file — or one whose
    top-level JSON value is not an object — yields ``{}``, so callers can
    always use ``.get`` on the result.  The file is decoded as UTF-8
    explicitly because it is written with ``ensure_ascii=False``.
    """
    try:
        if not CONFIG_PATH.exists():
            return {}
        data = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {}
    return data if isinstance(data, dict) else {}
|
|
|
|
|
|
def _persist_analysis_api_key(api_key_value: str, *, provider: str):
    """Save ``api_key_value`` as the API key for ``provider`` in CONFIG_PATH.

    The key is stored under ``api_keys[provider]`` and mirrored into the
    legacy ``api_key`` / ``api_key_provider`` fields; other saved settings are
    kept.  The file is written atomically (temp file + ``os.replace``) and is
    created owner-only (0600).  The old code wrote the file first and
    chmod-ed it afterwards, leaving the secret readable by others in between.
    Finally the cached dashboard key (``_api_key``) is reset so it is re-read
    from the environment on next use.
    """
    import tempfile

    global _api_key
    existing = _load_saved_config()
    if not isinstance(existing, dict):
        existing = {}
    saved_keys = existing.get("api_keys")
    api_keys = dict(saved_keys) if isinstance(saved_keys, dict) else {}
    api_keys[provider] = api_key_value

    payload = dict(existing)
    payload["api_keys"] = api_keys
    payload["api_key_provider"] = provider
    payload["api_key"] = api_key_value

    CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
    # mkstemp creates the file with mode 0600 in the target directory, so the
    # final os.replace is an atomic rename on the same filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=str(CONFIG_PATH.parent), prefix=".config-", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(json.dumps(payload, ensure_ascii=False))
        os.chmod(tmp_name, 0o600)
        os.replace(tmp_name, CONFIG_PATH)
    except BaseException:
        try:
            os.unlink(tmp_name)
        except OSError:
            pass  # cleanup is best effort; re-raise the original error below
        raise
    _api_key = None
|
|
|
|
|
|
# Environment variables consulted, in order, for each provider's API key.
_PROVIDER_API_KEY_ENV_VARS = {
    "anthropic": ("ANTHROPIC_API_KEY", "MINIMAX_API_KEY"),
    "openai": ("OPENAI_API_KEY",),
    "openrouter": ("OPENROUTER_API_KEY",),
    "xai": ("XAI_API_KEY",),
    "google": ("GOOGLE_API_KEY",),
    "ollama": (),
}


def _get_analysis_provider_api_key(provider: str, saved_config: Optional[dict] = None) -> Optional[str]:
    """Find the API key for ``provider`` (matched case-insensitively).

    Lookup order, first non-empty value wins:
      1. the provider's environment variables;
      2. ``saved_config["api_keys"][provider]``;
      3. the legacy ``saved_config["api_key"]`` when
         ``saved_config["api_key_provider"]`` names the same provider.
    Returns None when nothing is found.
    """
    wanted = provider.lower()

    env_names = _PROVIDER_API_KEY_ENV_VARS.get(wanted, ())
    from_env = next((value for name in env_names if (value := os.environ.get(name))), None)
    if from_env:
        return from_env

    saved = dict(saved_config or {})

    per_provider = saved.get("api_keys")
    if isinstance(per_provider, dict) and per_provider.get(wanted):
        return per_provider.get(wanted)

    legacy_matches = str(saved.get("api_key_provider") or "").lower() == wanted
    legacy_key = saved.get("api_key")
    return legacy_key if legacy_matches and legacy_key else None
|
|
|
|
|
|
def _resolve_analysis_runtime_settings() -> dict:
    """Resolve the LLM runtime settings used for analysis runs.

    Sources, highest priority first (for provider, URL and models an empty
    environment value counts as unset):

    * llm_provider: TRADINGAGENTS_LLM_PROVIDER; else "anthropic" when
      ANTHROPIC_BASE_URL is set, "openai" when OPENAI_BASE_URL is set; else
      the default config's value (or "anthropic").
    * backend_url: TRADINGAGENTS_BACKEND_URL, ANTHROPIC_BASE_URL,
      OPENAI_BASE_URL, then the default config.
    * deep/quick model: TRADINGAGENTS_DEEP_MODEL / TRADINGAGENTS_QUICK_MODEL,
      then TRADINGAGENTS_MODEL, then the default config.
    * selected_analysts: comma-separated TRADINGAGENTS_SELECTED_ANALYSTS
      (default "market"); analysis_prompt_style:
      TRADINGAGENTS_ANALYSIS_PROMPT_STYLE (default "compact").
    * llm_timeout / llm_max_retries: TRADINGAGENTS_LLM_TIMEOUT /
      TRADINGAGENTS_LLM_MAX_RETRIES, else the default config (45 / 0).
    * provider_api_key: resolved for the chosen provider from the environment
      or the saved config.

    The assembled dict is passed through ``normalize_runtime_llm_config``.

    Raises:
        ValueError: if the timeout or max-retries variable is not numeric.
    """
    saved = _load_saved_config()
    defaults = get_default_config()

    def first_env(*names, fallback=None):
        # First non-empty environment value among `names`, else `fallback`.
        for name in names:
            value = os.environ.get(name)
            if value:
                return value
        return fallback

    provider = first_env("TRADINGAGENTS_LLM_PROVIDER")
    if not provider:
        # Infer the provider from whichever base-URL override is present.
        if first_env("ANTHROPIC_BASE_URL"):
            provider = "anthropic"
        elif first_env("OPENAI_BASE_URL"):
            provider = "openai"
        else:
            provider = defaults.get("llm_provider", "anthropic")

    backend_url = first_env(
        "TRADINGAGENTS_BACKEND_URL",
        "ANTHROPIC_BASE_URL",
        "OPENAI_BASE_URL",
        fallback=defaults.get("backend_url"),
    )
    deep_model = first_env(
        "TRADINGAGENTS_DEEP_MODEL",
        "TRADINGAGENTS_MODEL",
        fallback=defaults.get("deep_think_llm"),
    )
    quick_model = first_env(
        "TRADINGAGENTS_QUICK_MODEL",
        "TRADINGAGENTS_MODEL",
        fallback=defaults.get("quick_think_llm"),
    )

    analysts_raw = os.environ.get("TRADINGAGENTS_SELECTED_ANALYSTS", "market")
    selected_analysts = list(filter(None, (part.strip() for part in analysts_raw.split(","))))
    prompt_style = os.environ.get("TRADINGAGENTS_ANALYSIS_PROMPT_STYLE", "compact")

    timeout_raw = os.environ.get("TRADINGAGENTS_LLM_TIMEOUT", str(defaults.get("llm_timeout", 45)))
    llm_timeout = float(timeout_raw)
    retries_raw = os.environ.get("TRADINGAGENTS_LLM_MAX_RETRIES", str(defaults.get("llm_max_retries", 0)))
    llm_max_retries = int(retries_raw)

    return normalize_runtime_llm_config({
        "llm_provider": provider,
        "backend_url": backend_url,
        "deep_think_llm": deep_model,
        "quick_think_llm": quick_model,
        "selected_analysts": selected_analysts,
        "analysis_prompt_style": prompt_style,
        "llm_timeout": llm_timeout,
        "llm_max_retries": llm_max_retries,
        "provider_api_key": _get_analysis_provider_api_key(provider, saved),
    })
|
|
|
|
|
|
def _build_analysis_request_context(request: Request, auth_key: Optional[str]):
    """Build the request context for an analysis run from the current runtime settings.

    Besides the LLM settings, the metadata carries time budgets derived from
    ``llm_timeout``, each with a floor: stdout 4x (>= 120 s), total 12x
    (>= 900 s), local recovery 2.5x (>= 90 s), provider probe 1.5x (>= 60 s).
    It also carries a 10 s heartbeat interval and cost caps of 1.0.
    """
    settings = _resolve_analysis_runtime_settings()
    timeout = float(settings["llm_timeout"])

    def budget(factor, floor):
        return max(timeout * factor, floor)

    forwarded = {
        name: settings[name]
        for name in (
            "provider_api_key",
            "llm_provider",
            "backend_url",
            "deep_think_llm",
            "quick_think_llm",
            "selected_analysts",
            "analysis_prompt_style",
            "llm_timeout",
            "llm_max_retries",
        )
    }
    metadata = {
        "stdout_timeout_secs": budget(4.0, 120.0),
        "total_timeout_secs": budget(12.0, 900.0),
        "heartbeat_interval_secs": 10.0,
        "local_recovery_timeout_secs": budget(2.5, 90.0),
        "provider_probe_timeout_secs": budget(1.5, 60.0),
        "local_recovery_cost_cap": 1.0,
        "provider_probe_cost_cap": 1.0,
    }
    return build_request_context(request, auth_key=auth_key, **forwarded, metadata=metadata)
|
|
|
|
|
|
def _is_local_request(request: "Request") -> bool:
    """Return True when the request originates from the local machine.

    Accepts every loopback address (127.0.0.0/8 and ::1), including the
    IPv4-mapped IPv6 form (e.g. ``::ffff:127.0.0.1``) that dual-stack servers
    report for IPv4 clients; the old exact-string check rejected those.  The
    literal hosts ``"localhost"`` and ``"testclient"`` (Starlette's TestClient)
    are accepted as before.  Requests without client info are rejected.
    """
    import ipaddress

    client = request.client
    if client is None:
        return False
    host = client.host
    if host in {"localhost", "testclient"}:
        return True
    try:
        address = ipaddress.ip_address(host)
    except ValueError:
        return False  # some other host name: not provably local
    mapped = getattr(address, "ipv4_mapped", None)
    if mapped is not None:
        address = mapped
    return address.is_loopback
|
|
|
|
|
|
def _get_cache_path(mode: str) -> Path:
    """Return the cache file for screening results of ``mode``.

    ``mode`` comes straight from a query parameter, so it is percent-encoded
    (``safe=""``) before becoming part of the file name.  Separators and
    ``..`` sequences can therefore no longer walk out of CACHE_DIR.  Plain
    modes such as ``"china_strict"`` still map to ``screen_<mode>.json``, and
    the encoding is injective, so distinct modes never share a file.
    """
    from urllib.parse import quote

    return CACHE_DIR / f"screen_{quote(mode, safe='')}.json"
|
|
|
|
|
|
def _load_from_cache(mode: str) -> Optional[dict]:
    """Return cached screening results for ``mode`` while they are still fresh.

    An entry is used only if its file exists and was modified at most
    CACHE_TTL_SECONDS ago.  Any error while checking or reading it (including
    malformed JSON) counts as a cache miss, and the function returns None.
    """
    cache_path = _get_cache_path(mode)
    if cache_path.exists():
        try:
            if time.time() - cache_path.stat().st_mtime <= CACHE_TTL_SECONDS:
                with open(cache_path) as handle:
                    return json.load(handle)
        except Exception:
            return None
    return None
|
|
|
|
|
|
def _save_to_cache(mode: str, data: dict):
    """Write screening results for ``mode`` to its cache file, best effort.

    Any failure (unwritable directory, unserializable data, ...) is ignored:
    the cache is only an optimisation and must not break the request that
    produced the data.
    """
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        with _get_cache_path(mode).open("w") as handle:
            json.dump(data, handle)
    except Exception:
        pass
|
|
|
|
|
|
# ============== SEPA Screening ==============
|
|
|
|
def _run_sepa_screening(mode: str) -> dict:
    """Run the SEPA screener for ``mode``; blocking, meant for a worker thread.

    Returns the mode, the size of the screened universe (``len(china_stocks)``),
    the number of stocks that passed, and the passing results.
    """
    repo_root = str(REPO_ROOT)
    # `sepa_screener` lives at the repository root.  Add that directory once:
    # the old code inserted it on every call, growing sys.path without bound.
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    from sepa_screener import screen_all, china_stocks

    results = screen_all(mode=mode, max_workers=5)
    return {
        "mode": mode,
        "total_stocks": len(china_stocks),
        "passed": len(results),
        "results": results,
    }
|
|
|
|
|
|
@app.get("/api/stocks/screen")
async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query(False), api_key: Optional[str] = Header(None)):
    """Screen stocks with the SEPA criteria, serving cached results when fresh.

    ``refresh=true`` bypasses the cache.  The response is the screening result
    plus ``"cached": bool`` telling whether it came from the cache.
    """
    if not _check_api_key(api_key):
        _auth_error()
    if not refresh:
        cached = _load_from_cache(mode)
        if cached:
            return {**cached, "cached": True}

    # The screener blocks; run it in the default thread pool so the event loop
    # keeps serving other requests.  get_running_loop() is the documented way
    # to reach the loop from inside a coroutine.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, _run_sepa_screening, mode)

    _save_to_cache(mode, result)
    return {**result, "cached": False}
|
|
|
|
|
|
# ============== Analysis Execution ==============
|
|
|
|
@app.post("/api/analysis/start")
async def start_analysis(
    payload: AnalysisRequest,
    http_request: Request,
    api_key: Optional[str] = Header(None),
):
    """Start a new analysis task for ``payload.ticker``.

    The task id is ``<ticker>_<YYYYmmdd_HHMMSS>_<6 hex chars>``.  The analysis
    date defaults to today.  Portfolio/peer context, when supplied, is added
    to the request-context metadata.

    Raises:
        HTTPException(400): the analysis service rejected the input (ValueError).
        HTTPException(500): the analysis service failed to start (RuntimeError).
    """
    import re
    import uuid

    if not _check_api_key(api_key):
        _auth_error()

    # Read the clock once so the task-id timestamp and the default analysis
    # date cannot fall on different days around midnight.
    now = datetime.now()
    # Task states are persisted through the result store (presumably keyed by
    # task id), so keep path separators / control characters from the raw
    # ticker out of the id.
    safe_ticker = re.sub(r"[\\/\x00-\x1f]", "_", payload.ticker)
    task_id = f"{safe_ticker}_{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
    date = payload.date or now.strftime("%Y-%m-%d")

    request_context = _build_analysis_request_context(http_request, api_key)
    if payload.portfolio_context is not None or payload.peer_context is not None:
        request_context = clone_request_context(
            request_context,
            metadata_updates={
                "portfolio_context": payload.portfolio_context,
                "peer_context": payload.peer_context,
                "peer_context_mode": payload.peer_context_mode or ("CALLER_PROVIDED" if payload.peer_context else None),
            },
        )

    try:
        return await app.state.analysis_service.start_analysis(
            task_id=task_id,
            ticker=payload.ticker,
            date=date,
            request_context=request_context,
            broadcast_progress=broadcast_progress,
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
|
|
|
|
@app.get("/api/analysis/status/{task_id}")
async def get_task_status(task_id: str, api_key: Optional[str] = Header(None)):
    """Return the public payload of one task, or 404 if the task is unknown."""
    if not _check_api_key(api_key):
        _auth_error()
    task_payload = app.state.task_query_service.public_task_payload(task_id)
    if task_payload is not None:
        return task_payload
    raise HTTPException(status_code=404, detail="Task not found")
|
|
|
|
|
|
@app.get("/api/analysis/tasks")
async def list_tasks(api_key: Optional[str] = Header(None)):
    """Return summaries of all known tasks, both running and recently finished."""
    if not _check_api_key(api_key):
        _auth_error()
    query_service = app.state.task_query_service
    return query_service.list_task_summaries()
|
|
|
|
|
|
@app.delete("/api/analysis/cancel/{task_id}")
async def cancel_task(task_id: str, api_key: Optional[str] = Header(None)):
    """Cancel a running task and return its updated payload (404 if unknown)."""
    if not _check_api_key(api_key):
        _auth_error()
    commands = app.state.task_command_service
    result = await commands.cancel_task(task_id, broadcast_progress=broadcast_progress)
    if result is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return result
|
|
|
|
|
|
# ============== WebSocket ==============
|
|
|
|
@app.websocket("/ws/analysis/{task_id}")
async def websocket_analysis(websocket: WebSocket, task_id: str):
    """Stream real-time analysis progress for ``task_id`` to one client.

    Auth: when DASHBOARD_API_KEY is set, the key must be passed as the
    ``?api_key=`` query parameter; otherwise the socket is closed with code
    4001 before being accepted.

    Once accepted, the socket is registered in
    ``app.state.active_connections[task_id]`` (where broadcast_progress finds
    it).  The current task state is sent immediately if the task is known.
    After that, ``{"type": "ping"}`` messages are answered with
    ``{"type": "pong"}`` until the client disconnects; malformed messages are
    ignored.  The socket is always unregistered when the handler exits.
    Previously an invalid JSON frame or an early send failure left a dead
    socket registered.
    """
    ws_api_key = websocket.query_params.get("api_key")
    if not _check_api_key(ws_api_key):
        await websocket.close(code=4001, reason="Unauthorized")
        return
    await websocket.accept()

    app.state.active_connections.setdefault(task_id, []).append(websocket)
    try:
        # Send current state immediately if available
        if task_id in app.state.task_results:
            await websocket.send_text(json.dumps({
                "type": "progress",
                **_public_task_payload(task_id),
            }))

        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except ValueError:
                continue  # ignore a malformed frame instead of dropping the client
            if isinstance(message, dict) and message.get("type") == "ping":
                await websocket.send_text(json.dumps({"type": "pong"}))
    except WebSocketDisconnect:
        pass
    finally:
        subscribers = app.state.active_connections.get(task_id)
        if subscribers is not None:
            # broadcast_progress may already have dropped this socket as dead.
            if websocket in subscribers:
                subscribers.remove(websocket)
            if not subscribers:
                app.state.active_connections.pop(task_id, None)
|
|
|
|
|
|
async def broadcast_progress(task_id: str, progress: dict):
    """Send a progress update for ``task_id`` to every subscribed WebSocket.

    ``progress`` is rendered through ``_public_task_payload`` as a state
    override, so clients get the same public shape as the status endpoint plus
    ``"type": "progress"``.  Sockets whose send fails are unsubscribed.
    Returns immediately when nobody is subscribed.
    """
    subscribers = app.state.active_connections.get(task_id)
    if not subscribers:
        return

    payload = _public_task_payload(task_id, state_override=progress)
    message = json.dumps({"type": "progress", **payload})

    # Iterate over a snapshot: every send awaits, and websocket_analysis may
    # remove its socket from the live list meanwhile.  Iterating the list
    # directly would then silently skip subscribers.
    dead = []
    for connection in list(subscribers):
        try:
            await connection.send_text(message)
        except Exception:
            dead.append(connection)

    for connection in dead:
        # The socket's own handler may have unregistered it already.
        if connection in subscribers:
            subscribers.remove(connection)
|
|
|
|
|
|
def _load_task_contract(task_id: str, state: Optional[dict] = None) -> Optional[dict]:
    """Delegate to the task query service's ``load_task_contract``.

    ``state`` is forwarded as ``state_override``; the service's result
    (possibly None) is returned unchanged.
    """
    query_service = app.state.task_query_service
    return query_service.load_task_contract(task_id, state_override=state)
|
|
|
|
|
|
def _public_task_payload(task_id: str, state_override: Optional[dict] = None) -> dict:
    """Return the public payload for ``task_id``, optionally built from ``state_override``.

    Raises:
        KeyError: if the query service has no payload for the task.
    """
    query_service = app.state.task_query_service
    found = query_service.public_task_payload(task_id, state_override=state_override)
    if found is not None:
        return found
    raise KeyError(task_id)
|
|
|
|
|
|
def _public_task_summary(task_id: str, state_override: Optional[dict] = None) -> dict:
    """Return the public summary for ``task_id``, optionally built from ``state_override``.

    Raises:
        KeyError: if the query service has no summary for the task.
    """
    query_service = app.state.task_query_service
    found = query_service.public_task_summary(task_id, state_override=state_override)
    if found is not None:
        return found
    raise KeyError(task_id)
|
|
|
|
|
|
# ============== Reports ==============
|
|
|
|
def get_results_dir() -> Path:
    """Return the ``results`` directory three ``.parent`` steps above this file."""
    this_file = Path(__file__)
    root = this_file.parent.parent.parent
    return root.joinpath("results")
|
|
|
|
|
|
def get_reports_list():
    """List saved reports as ``{"ticker", "date", "path"}`` dicts, newest date first.

    Reports live in ``<results>/<ticker>/<date>/``.  Only date folders whose
    name starts with "20" count, and the ``TradingAgentsStrategy_logs``
    directory is skipped.  Returns [] when the results directory is missing.
    """
    results_dir = get_results_dir()
    if not results_dir.exists():
        return []

    found = []
    for ticker_dir in results_dir.iterdir():
        if not ticker_dir.is_dir() or ticker_dir.name == "TradingAgentsStrategy_logs":
            continue
        for date_dir in ticker_dir.iterdir():
            # Only date-named folders (e.g. "2024-05-01") hold reports.
            if date_dir.is_dir() and date_dir.name.startswith("20"):
                found.append({
                    "ticker": ticker_dir.name,
                    "date": date_dir.name,
                    "path": str(date_dir),
                })
    found.sort(key=lambda report: report["date"], reverse=True)
    return found
|
|
|
|
|
|
def get_report_content(ticker: str, date: str) -> Optional[dict]:
    """Load the markdown files of the report for ``ticker`` on ``date``.

    Returns a dict mapping ``"report"`` to ``complete_report.md`` (if present)
    and each stage file name from ``reports/<stage>/*.md`` to its text.  An
    existing directory with no markdown yields ``{}``.  Returns None when an
    argument looks like a path-traversal attempt or the directory does not
    exist.

    Files are decoded as UTF-8 explicitly, because reports contain non-ASCII
    text such as the Chinese section headings; the platform's locale encoding
    is not used.  Undecodable bytes are replaced instead of failing the request.
    """
    # Validate inputs to prevent path traversal
    for part in (ticker, date):
        if ".." in part or "/" in part or "\\" in part:
            return None

    results_dir = get_results_dir()
    report_dir = results_dir / ticker / date
    # Strict traversal check: the resolved path must stay inside results_dir
    try:
        report_dir.resolve().relative_to(results_dir.resolve())
    except ValueError:
        return None
    if not report_dir.exists():
        return None

    content = {}
    complete_report = report_dir / "complete_report.md"
    if complete_report.exists():
        content["report"] = complete_report.read_text(encoding="utf-8", errors="replace")
    for stage in ("1_analysts", "2_research", "3_trading", "4_risk", "5_portfolio"):
        stage_dir = report_dir / "reports" / stage
        if not stage_dir.exists():
            continue
        # Sorted for a deterministic key order.  Keys are bare file names, so
        # a name repeated in a later stage overwrites the earlier one.
        for md_file in sorted(stage_dir.glob("*.md")):
            content[md_file.name] = md_file.read_text(encoding="utf-8", errors="replace")
    return content
|
|
|
|
|
|
@app.get("/api/reports/list")
async def list_reports(
    limit: int = Query(DEFAULT_PAGE_SIZE, ge=1, le=MAX_PAGE_SIZE),
    offset: int = Query(0, ge=0),
    api_key: Optional[str] = Header(None),
):
    """Return one page of saved reports (newest first) plus paging metadata."""
    if not _check_api_key(api_key):
        _auth_error()
    # get_reports_list() already orders reports newest-first; re-sorting by the
    # same key with a stable sort would not change the order.
    all_reports = get_reports_list()
    page = all_reports[offset:offset + limit]
    return {
        "reports": page,
        "total": len(all_reports),
        "limit": limit,
        "offset": offset,
    }
|
|
|
|
|
|
@app.get("/api/reports/{ticker}/{date}")
async def get_report(ticker: str, date: str, api_key: Optional[str] = Header(None)):
    """Return all markdown sections of one report; 404 if missing or empty."""
    if not _check_api_key(api_key):
        _auth_error()
    report = get_report_content(ticker, date)
    if report:
        return report
    raise HTTPException(status_code=404, detail="Report not found")
|
|
|
|
|
|
# ============== Report Export ==============
|
|
|
|
import csv
|
|
import io
|
|
import re
|
|
from fpdf import FPDF
|
|
|
|
|
|
def _extract_decision(markdown_text: str) -> str:
    """Return the first bold rating (``**BUY**`` etc.) in the markdown, or 'UNKNOWN'.

    Recognised ratings: BUY, SELL, HOLD, OVERWEIGHT, UNDERWEIGHT.
    """
    found = re.search(r'\*\*(BUY|SELL|HOLD|OVERWEIGHT|UNDERWEIGHT)\*\*', markdown_text)
    if found is None:
        return 'UNKNOWN'
    return found.group(1)
|
|
|
|
|
|
def _extract_summary(markdown_text: str) -> str:
    """Return up to 200 characters of plain text from the '## 分析摘要' section.

    The section ('分析摘要' means "analysis summary") runs until the next line
    starting with '##' or the end of the text.  Bold/italic markers are
    removed, and runs of '#' characters and newlines collapse to one space.
    Returns '' when there is no such section.
    """
    # Capture the whole section, then truncate.  The previous pattern capped the
    # capture at 300 characters *and* required the section to end right there,
    # so any summary longer than 300 characters matched nothing and came back
    # empty.
    match = re.search(r'## 分析摘要\s*\n+(.*?)(?=\n##|\Z)', markdown_text, re.DOTALL)
    if not match:
        return ''
    text = match.group(1).strip()
    # Strip markdown formatting
    text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
    text = re.sub(r'\*(.*?)\*', r'\1', text)
    text = re.sub(r'[#\n]+', ' ', text)
    return text[:200].strip()
|
|
|
|
|
|
@app.get("/api/reports/export")
async def export_reports_csv(
    api_key: Optional[str] = Header(None),
):
    """Download every saved report as CSV with columns ticker,date,decision,summary.

    Reports without a ``complete_report.md`` are listed with decision
    ``UNKNOWN`` and an empty summary.
    """
    if not _check_api_key(api_key):
        _auth_error()
    entries = get_reports_list()

    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=["ticker", "date", "decision", "summary"])
    writer.writeheader()
    for entry in entries:
        content = get_report_content(entry["ticker"], entry["date"])
        markdown = content.get("report") if content else None
        if markdown:
            decision, summary = _extract_decision(markdown), _extract_summary(markdown)
        else:
            decision, summary = "UNKNOWN", ""
        writer.writerow({
            "ticker": entry["ticker"],
            "date": entry["date"],
            "decision": decision,
            "summary": summary,
        })

    return Response(
        content=buffer.getvalue(),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=tradingagents_reports.csv"},
    )
|
|
|
|
|
|
def _find_pdf_fonts() -> tuple:
    """Return ``(regular, bold)`` paths of the first DejaVu Sans files found.

    Candidates cover macOS, common Linux locations and per-user font
    directories.  Either element is None when no matching file exists.
    """
    home = Path.home()
    candidates = [
        "/System/Library/Fonts/Supplemental/DejaVuSans.ttf",
        "/System/Library/Fonts/Supplemental/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "/usr/share/fonts/dejavu/DejaVuSans.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/dejavu/DejaVuSans-Bold.ttf",
        str(home / ".local/share/fonts/DejaVuSans.ttf"),
        str(home / ".fonts/DejaVuSans.ttf"),
    ]
    regular = bold = None
    for candidate in candidates:
        if not Path(candidate).exists():
            continue
        if "Bold" in candidate:
            bold = bold or candidate
        else:
            regular = regular or candidate
    return regular, bold


# RGB text colour of the PDF decision line per rating; other ratings use amber.
_PDF_DECISION_COLORS = {
    "BUY": (34, 197, 94),
    "OVERWEIGHT": (134, 239, 172),
    "SELL": (220, 38, 38),
    "UNDERWEIGHT": (252, 165, 165),
}
_PDF_DEFAULT_DECISION_COLOR = (245, 158, 11)


@app.get("/api/reports/{ticker}/{date}/pdf")
async def export_report_pdf(ticker: str, date: str, api_key: Optional[str] = Header(None)):
    """Export a single report as a PDF download.

    Contents, in order:
      * title and ticker/date line;
      * the decision, coloured by rating;
      * the summary section;
      * the report text with bold/italic/heading markers stripped (lines
        longer than 120 characters are truncated).

    Fonts: DejaVu Sans is embedded when found; if no bold file exists, the
    regular file is reused for bold.  Otherwise the built-in Helvetica is
    used.  Helvetica only covers Latin-1, so other characters become "?"
    instead of making fpdf raise.
    NOTE(review): DejaVu Sans has no CJK glyphs, so the Chinese labels and
    report text are likely not rendered even when it is found — a CJK-capable
    font may be needed.

    Raises:
        HTTPException(404): no report, or the report has no complete_report.md.
    """
    if not _check_api_key(api_key):
        _auth_error()
    content = get_report_content(ticker, date)
    if not content or not content.get("report"):
        raise HTTPException(status_code=404, detail="Report not found")

    markdown_text = content["report"]
    decision = _extract_decision(markdown_text)
    summary = _extract_summary(markdown_text)

    pdf = FPDF()
    pdf.set_auto_page_break(auto=True, margin=20)

    regular_font, bold_font = _find_pdf_fonts()
    has_unicode_font = regular_font is not None
    if has_unicode_font:
        # fpdf2 embeds TrueType fonts with Unicode support; `add_font` has no
        # `unicode` keyword (the old call raised TypeError).
        pdf.add_font("DejaVu", "", regular_font)
        pdf.add_font("DejaVu", "B", bold_font or regular_font)
        family = "DejaVu"
    else:
        family = "Helvetica"

    def pdf_text(value: str) -> str:
        # Core fonts are Latin-1 only and fpdf raises on anything else, so
        # degrade unsupported characters to "?" when no TrueType font is loaded.
        if has_unicode_font:
            return value
        return value.encode("latin-1", "replace").decode("latin-1")

    pdf.add_page()
    pdf.set_font(family, "B", 18)
    pdf.cell(0, 12, pdf_text("TradingAgents 分析报告"), ln=True, align="C")
    pdf.ln(5)

    pdf.set_font(family, "", 11)
    pdf.cell(0, 8, pdf_text(f"股票: {ticker} 日期: {date}"), ln=True)
    pdf.ln(3)

    # Decision badge
    pdf.set_font(family, "B", 14)
    pdf.set_text_color(*_PDF_DECISION_COLORS.get(decision, _PDF_DEFAULT_DECISION_COLOR))
    pdf.cell(0, 10, pdf_text(f"决策: {decision}"), ln=True)
    pdf.set_text_color(0, 0, 0)
    pdf.ln(5)

    # Summary
    pdf.set_font(family, "B", 12)
    pdf.cell(0, 8, pdf_text("分析摘要"), ln=True)
    pdf.set_font(family, "", 10)
    pdf.multi_cell(0, 6, pdf_text(summary or "无"))
    pdf.ln(5)

    # Full report text (stripped of heavy markdown)
    pdf.set_font(family, "B", 12)
    pdf.cell(0, 8, pdf_text("完整报告"), ln=True)
    pdf.set_font(family, "", 9)
    for raw_line in markdown_text.splitlines():
        line = re.sub(r'\*\*(.*?)\*\*', r'\1', raw_line)
        line = re.sub(r'\*(.*?)\*', r'\1', line)
        line = re.sub(r'#{1,6} ', '', line)
        line = line.strip()
        if not line:
            pdf.ln(2)
            continue
        if len(line) > 120:
            line = line[:120] + "..."
        try:
            pdf.multi_cell(0, 5, pdf_text(line))
        except Exception:
            # Best effort: skip a line fpdf cannot lay out rather than fail the export.
            pass

    # Restrict the download name to safe characters so it cannot break the header.
    safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", f"{ticker}_{date}_report.pdf")
    return Response(
        # fpdf2's output() returns a bytearray, which Starlette's Response does
        # not pass through as-is (it expects bytes), so convert explicitly.
        content=bytes(pdf.output()),
        media_type="application/pdf",
        headers={"Content-Disposition": f'attachment; filename="{safe_name}"'},
    )
|
|
|
|
|
|
# ============== Portfolio ==============
|
|
|
|
import sys
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
from api.portfolio import (
|
|
create_legacy_portfolio_gateway,
|
|
get_watchlist, add_to_watchlist, remove_from_watchlist,
|
|
get_positions, add_position, remove_position,
|
|
get_accounts, create_account, delete_account,
|
|
)
|
|
|
|
|
|
# --- Watchlist ---
|
|
|
|
@app.get("/api/portfolio/watchlist")
async def list_watchlist(api_key: Optional[str] = Header(None)):
    """Return the watchlist wrapped as ``{"watchlist": [...]}``."""
    if not _check_api_key(api_key):
        _auth_error()
    entries = get_watchlist()
    return {"watchlist": entries}
|
|
|
|
|
|
@app.post("/api/portfolio/watchlist")
async def create_watchlist_entry(body: dict, api_key: Optional[str] = Header(None)):
    """Add a ticker to the watchlist and return the created entry.

    Body: ``{"ticker": str, "name": str}``, where ``name`` is optional and
    defaults to the ticker.

    Raises:
        HTTPException(400): ``ticker`` missing (previously an unhandled
            KeyError, i.e. HTTP 500), or the portfolio layer raised ValueError.
    """
    if not _check_api_key(api_key):
        _auth_error()
    if "ticker" not in body:
        raise HTTPException(status_code=400, detail="ticker is required")
    ticker = body["ticker"]
    try:
        return add_to_watchlist(ticker, body.get("name", ticker))
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
|
|
|
|
@app.delete("/api/portfolio/watchlist/{ticker}")
async def delete_watchlist_entry(ticker: str, api_key: Optional[str] = Header(None)):
    """Remove ``ticker`` from the watchlist; 404 if it was not on it."""
    if not _check_api_key(api_key):
        _auth_error()
    removed = remove_from_watchlist(ticker)
    if not removed:
        raise HTTPException(status_code=404, detail="Ticker not found in watchlist")
    return {"ok": True}
|
|
|
|
|
|
# --- Accounts ---
|
|
|
|
@app.get("/api/portfolio/accounts")
async def list_accounts(api_key: Optional[str] = Header(None)):
    """Return the names of all portfolio accounts as ``{"accounts": [...]}``."""
    if not _check_api_key(api_key):
        _auth_error()
    account_map = get_accounts().get("accounts", {})
    names = list(account_map.keys())
    return {"accounts": names}
|
|
|
|
|
|
@app.post("/api/portfolio/accounts")
async def create_account_endpoint(body: dict, api_key: Optional[str] = Header(None)):
    """Create a portfolio account from ``{"account_name": str}``.

    Raises:
        HTTPException(400): ``account_name`` missing (previously an unhandled
            KeyError, i.e. HTTP 500), or the portfolio layer raised ValueError.
    """
    if not _check_api_key(api_key):
        _auth_error()
    if "account_name" not in body:
        raise HTTPException(status_code=400, detail="account_name is required")
    try:
        return create_account(body["account_name"])
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
|
|
|
|
@app.delete("/api/portfolio/accounts/{account_name}")
async def delete_account_endpoint(account_name: str, api_key: Optional[str] = Header(None)):
    """Delete the named account; 404 if it does not exist."""
    if not _check_api_key(api_key):
        _auth_error()
    deleted = delete_account(account_name)
    if not deleted:
        raise HTTPException(status_code=404, detail="Account not found")
    return {"ok": True}
|
|
|
|
|
|
# --- Positions ---
|
|
|
|
@app.get("/api/portfolio/positions")
async def list_positions(account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
    """Return positions, optionally for one ``account``.

    Reads from the result store when the ``use_result_store`` migration flag
    is on, otherwise from the legacy portfolio module.
    """
    if not _check_api_key(api_key):
        _auth_error()
    use_store = app.state.migration_flags.use_result_store
    fetch = app.state.result_store.get_positions if use_store else get_positions
    return {"positions": await fetch(account)}
|
|
|
|
|
|
@app.post("/api/portfolio/positions")
async def create_position(body: dict, api_key: Optional[str] = Header(None)):
    """Add a position and return it.

    Required body fields are ``ticker``, ``shares`` and ``cost_price``.
    Optional fields are ``purchase_date``, ``notes`` (default "") and
    ``account`` (default "默认账户", i.e. "default account").  Any failure,
    including a missing required field, is reported as HTTP 400.
    """
    if not _check_api_key(api_key):
        _auth_error()
    try:
        fields = {
            "ticker": body["ticker"],
            "shares": body["shares"],
            "cost_price": body["cost_price"],
            "purchase_date": body.get("purchase_date"),
            "notes": body.get("notes", ""),
            "account": body.get("account", "默认账户"),
        }
        return add_position(**fields)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
|
|
@app.delete("/api/portfolio/positions/{ticker}")
async def delete_position(ticker: str, position_id: Optional[str] = Query(None), account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
    """Remove a position for ``ticker`` (optionally by id / account); 404 if none matched."""
    if not _check_api_key(api_key):
        _auth_error()
    # An absent position_id is passed on as "" rather than None.
    if not remove_position(ticker, position_id or "", account):
        raise HTTPException(status_code=404, detail="Position not found")
    return {"ok": True}
|
|
|
|
|
|
@app.get("/api/portfolio/positions/export")
async def export_positions_csv(account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
    """Download positions (optionally for one ``account``) as CSV.

    Positions come from the result store when the ``use_result_store``
    migration flag is on, otherwise from the legacy portfolio module.
    A field missing from a position (e.g. no ``purchase_date`` or ``notes``)
    becomes an empty cell; previously the KeyError aborted the whole export.
    """
    import csv
    import io

    if not _check_api_key(api_key):
        _auth_error()
    if app.state.migration_flags.use_result_store:
        positions = await app.state.result_store.get_positions(account)
    else:
        positions = await get_positions(account)

    columns = ["ticker", "shares", "cost_price", "purchase_date", "notes", "account"]
    output = io.StringIO()
    writer = csv.DictWriter(output, fieldnames=columns)
    writer.writeheader()
    for position in positions:
        # Positions are assumed to be dict-like (they were indexed by field name).
        writer.writerow({column: position.get(column, "") for column in columns})
    return Response(content=output.getvalue(), media_type="text/csv", headers={"Content-Disposition": "attachment; filename=positions.csv"})
|
|
|
|
|
|
# --- Recommendations ---
|
|
|
|
@app.get("/api/portfolio/recommendations")
async def list_recommendations(
    date: Optional[str] = Query(None),
    limit: int = Query(DEFAULT_PAGE_SIZE, ge=1, le=MAX_PAGE_SIZE),
    offset: int = Query(0, ge=0),
    api_key: Optional[str] = Header(None),
):
    """Return a page of stored recommendations, optionally for one ``date``."""
    if not _check_api_key(api_key):
        _auth_error()
    store = app.state.result_store
    return store.get_recommendations(date, limit, offset)
|
|
|
|
|
|
@app.get("/api/portfolio/recommendations/{date}/{ticker}")
async def get_recommendation_endpoint(date: str, ticker: str, api_key: Optional[str] = Header(None)):
    """Return the stored recommendation for ``ticker`` on ``date``; 404 if none."""
    if not _check_api_key(api_key):
        _auth_error()
    recommendation = app.state.result_store.get_recommendation(date, ticker)
    if recommendation:
        return recommendation
    raise HTTPException(status_code=404, detail="Recommendation not found")
|
|
|
|
|
|
# --- Batch Analysis ---
|
|
|
|
@app.post("/api/portfolio/analyze")
async def start_portfolio_analysis(
    http_request: Request,
    api_key: Optional[str] = Header(None),
):
    """Start a batch analysis of all watchlist tickers for today.

    The task id is ``port_<YYYY-mm-dd>_<6 hex chars>``.  ValueError from the
    analysis service maps to HTTP 400, RuntimeError to HTTP 500.
    """
    if not _check_api_key(api_key):
        _auth_error()

    import uuid

    today = datetime.now().strftime("%Y-%m-%d")
    task_id = "port_{}_{}".format(today, uuid.uuid4().hex[:6])
    context = _build_analysis_request_context(http_request, api_key)

    service = app.state.analysis_service
    try:
        return await service.start_portfolio_analysis(
            task_id=task_id,
            date=today,
            request_context=context,
            broadcast_progress=broadcast_progress,
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except RuntimeError as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Serve the built React frontend's index.html, or an API banner when it is not built."""
    index_html = Path(__file__).parent.parent.joinpath("frontend", "dist", "index.html")
    if not index_html.exists():
        return {"message": "TradingAgents Web Dashboard API", "version": "0.1.0"}
    return FileResponse(str(index_html))
|
|
|
|
|
|
@app.websocket("/ws/orchestrator")
async def ws_orchestrator(websocket: WebSocket, api_key: Optional[str] = None):
    """Serve live orchestrator signals over a WebSocket.

    Auth: the ``api_key`` query parameter is checked before accepting;
    unauthorized sockets are closed with code 4401.

    Protocol: each client message is JSON ``{"tickers": [...], "date": ...}``.
    The handler runs ``LiveMode.run_once`` for it and replies with
    ``{"contract_version": "v1alpha1", "signals": <results>}``.  On any error
    other than a disconnect it tries once to send ``{"error": "<message>"}``
    and then stops serving the socket.
    """
    # Auth check before accepting — reject unauthenticated connections
    if not _check_api_key(api_key):
        await websocket.close(code=4401)
        return

    # The `orchestrator` package lives at the repository root.  Add that path
    # once: inserting it on every connection grew sys.path without bound.
    repo_root = str(REPO_ROOT)
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    from orchestrator.config import OrchestratorConfig
    from orchestrator.orchestrator import TradingOrchestrator
    from orchestrator.live_mode import LiveMode

    config = OrchestratorConfig(
        quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
    )
    live = LiveMode(TradingOrchestrator(config))

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            payload = json.loads(data)
            tickers = payload.get("tickers", [])
            date = payload.get("date")

            results = await live.run_once(tickers, date)
            await websocket.send_text(json.dumps({
                "contract_version": "v1alpha1",
                "signals": results,
            }))
    except WebSocketDisconnect:
        pass
    except Exception as e:
        try:
            await websocket.send_text(json.dumps({"error": str(e)}))
        except Exception:
            pass  # the socket is already unusable; nothing more to report
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: always answers ``{"status": "ok"}``."""
    status = {"status": "ok"}
    return status
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "8000"))
    # Production mode: serve the built React frontend's static assets.
    # StaticFiles raises at construction when its directory is missing, so
    # check the assets folder itself rather than only frontend/dist.
    assets_dir = Path(__file__).parent.parent / "frontend" / "dist" / "assets"
    if assets_dir.is_dir():
        app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets")
    uvicorn.run(app, host=host, port=port)
|