fix: add API key auth, pagination, and configurable CORS to dashboard API (#3)

Security hardening:
- API key authentication via X-API-Key header on all endpoints
  (opt-in: set DASHBOARD_API_KEY or ANTHROPIC_API_KEY env var to enable)
  If no key is set, endpoints remain open (backward-compatible)
- WebSocket auth via ?api_key= query parameter
- CORS now configurable via CORS_ORIGINS env var (default: allow all)

Pagination (all list endpoints):
- GET /api/reports/list — limit/offset with total count
- GET /api/portfolio/recommendations — limit/offset with total count
- DEFAULT_PAGE_SIZE=50, MAX_PAGE_SIZE=500

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Shaojie 2026-04-07 18:57:51 +08:00 committed by GitHub
parent f19c1c012e
commit 1cee59dd9f
2 changed files with 149 additions and 42 deletions

View File

@ -279,27 +279,38 @@ def remove_position(ticker: str, position_id: str, account: Optional[str]) -> bo
# ============== Recommendations ==============
def get_recommendations(date: Optional[str] = None) -> list:
"""List recommendations, optionally filtered by date."""
# Pagination defaults (must match main.py constants)
DEFAULT_PAGE_SIZE = 50
MAX_PAGE_SIZE = 500
def get_recommendations(date: Optional[str] = None, limit: int = DEFAULT_PAGE_SIZE, offset: int = 0) -> dict:
"""List recommendations, optionally filtered by date. Returns paginated results."""
RECOMMENDATIONS_DIR.mkdir(parents=True, exist_ok=True)
all_recs = []
if date:
date_dir = RECOMMENDATIONS_DIR / date
if not date_dir.exists():
return []
return [
json.loads(f.read_text())
for f in date_dir.glob("*.json")
if f.suffix == ".json"
]
if date_dir.exists():
all_recs = [
json.loads(f.read_text())
for f in sorted(date_dir.glob("*.json"), reverse=True)
if f.suffix == ".json"
]
else:
# Return most recent first
all_recs = []
for date_dir in sorted(RECOMMENDATIONS_DIR.iterdir(), reverse=True):
if date_dir.is_dir() and date_dir.name.startswith("20"):
for f in date_dir.glob("*.json"):
for f in sorted(date_dir.glob("*.json"), reverse=True):
if f.suffix == ".json":
all_recs.append(json.loads(f.read_text()))
return all_recs
total = len(all_recs)
return {
"recommendations": all_recs[offset : offset + limit],
"total": total,
"limit": limit,
"offset": offset,
}
def get_recommendation(date: str, ticker: str) -> Optional[dict]:

View File

@ -58,9 +58,13 @@ app = FastAPI(
lifespan=lifespan
)
# CORS: allow all if CORS_ORIGINS is not set (development), otherwise comma-separated list
_cors_origins = os.environ.get("CORS_ORIGINS", "*")
_cors_origins_list = ["*"] if _cors_origins == "*" else [o.strip() for o in _cors_origins.split(",")]
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=_cors_origins_list,
allow_methods=["*"],
allow_headers=["*"],
)
@ -83,6 +87,29 @@ MAX_RETRY_COUNT = 2
RETRY_BASE_DELAY_SECS = 1
MAX_CONCURRENT_YFINANCE = 5
# Pagination defaults
DEFAULT_PAGE_SIZE = 50
MAX_PAGE_SIZE = 500
# Auth — set DASHBOARD_API_KEY env var to enable API key authentication
_api_key: Optional[str] = None
def _get_api_key() -> Optional[str]:
global _api_key
if _api_key is None:
_api_key = os.environ.get("DASHBOARD_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
return _api_key
def _check_api_key(api_key: Optional[str]) -> bool:
"""Return True if no key is required, or if the provided key matches."""
required = _get_api_key()
if not required:
return True
return api_key == required
def _auth_error():
raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required")
def _get_cache_path(mode: str) -> Path:
return CACHE_DIR / f"screen_{mode}.json"
@ -147,8 +174,10 @@ def _run_sepa_screening(mode: str) -> dict:
@app.get("/api/stocks/screen")
async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query(False)):
async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query(False), api_key: Optional[str] = Header(None)):
"""Screen stocks using SEPA criteria with caching"""
if not _check_api_key(api_key):
_auth_error()
if not refresh:
cached = _load_from_cache(mode)
if cached:
@ -225,15 +254,19 @@ print("ANALYSIS_COMPLETE:" + signal, flush=True)
@app.post("/api/analysis/start")
async def start_analysis(request: AnalysisRequest):
async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Header(None)):
"""Start a new analysis task"""
import uuid
task_id = f"{request.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
date = request.date or datetime.now().strftime("%Y-%m-%d")
# Validate API key before storing any task state
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
# Check dashboard API key (opt-in auth)
if not _check_api_key(api_key):
_auth_error()
# Validate ANTHROPIC_API_KEY for the analysis subprocess
anthropic_key = os.environ.get("ANTHROPIC_API_KEY")
if not anthropic_key:
raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set")
# Initialize task state
@ -410,16 +443,20 @@ async def start_analysis(request: AnalysisRequest):
@app.get("/api/analysis/status/{task_id}")
async def get_task_status(task_id: str):
async def get_task_status(task_id: str, api_key: Optional[str] = Header(None)):
"""Get task status"""
if not _check_api_key(api_key):
_auth_error()
if task_id not in app.state.task_results:
raise HTTPException(status_code=404, detail="Task not found")
return app.state.task_results[task_id]
@app.get("/api/analysis/tasks")
async def list_tasks():
async def list_tasks(api_key: Optional[str] = Header(None)):
"""List all tasks (active and recent)"""
if not _check_api_key(api_key):
_auth_error()
tasks = []
for task_id, state in app.state.task_results.items():
tasks.append({
@ -438,8 +475,10 @@ async def list_tasks():
@app.delete("/api/analysis/cancel/{task_id}")
async def cancel_task(task_id: str):
async def cancel_task(task_id: str, api_key: Optional[str] = Header(None)):
"""Cancel a running task"""
if not _check_api_key(api_key):
_auth_error()
if task_id not in app.state.task_results:
raise HTTPException(status_code=404, detail="Task not found")
@ -477,7 +516,12 @@ async def cancel_task(task_id: str):
@app.websocket("/ws/analysis/{task_id}")
async def websocket_analysis(websocket: WebSocket, task_id: str):
"""WebSocket for real-time analysis progress"""
"""WebSocket for real-time analysis progress. Auth via ?api_key= query param."""
# Optional API key check for WebSocket
ws_api_key = websocket.query_params.get("api_key")
if not _check_api_key(ws_api_key):
await websocket.close(code=4001, reason="Unauthorized")
return
await websocket.accept()
if task_id not in app.state.active_connections:
@ -574,12 +618,27 @@ def get_report_content(ticker: str, date: str) -> Optional[dict]:
@app.get("/api/reports/list")
async def list_reports():
return get_reports_list()
async def list_reports(
limit: int = Query(DEFAULT_PAGE_SIZE, ge=1, le=MAX_PAGE_SIZE),
offset: int = Query(0, ge=0),
api_key: Optional[str] = Header(None),
):
if not _check_api_key(api_key):
_auth_error()
reports = get_reports_list()
total = len(reports)
return {
"reports": sorted(reports, key=lambda x: x["date"], reverse=True)[offset : offset + limit],
"total": total,
"limit": limit,
"offset": offset,
}
@app.get("/api/reports/{ticker}/{date}")
async def get_report(ticker: str, date: str):
async def get_report(ticker: str, date: str, api_key: Optional[str] = Header(None)):
if not _check_api_key(api_key):
_auth_error()
content = get_report_content(ticker, date)
if not content:
raise HTTPException(status_code=404, detail="Report not found")
@ -614,8 +673,12 @@ def _extract_summary(markdown_text: str) -> str:
@app.get("/api/reports/export")
async def export_reports_csv():
async def export_reports_csv(
api_key: Optional[str] = Header(None),
):
"""Export all reports as CSV: ticker,date,decision,summary."""
if not _check_api_key(api_key):
_auth_error()
reports = get_reports_list()
output = io.StringIO()
writer = csv.DictWriter(output, fieldnames=["ticker", "date", "decision", "summary"])
@ -644,8 +707,10 @@ async def export_reports_csv():
@app.get("/api/reports/{ticker}/{date}/pdf")
async def export_report_pdf(ticker: str, date: str):
async def export_report_pdf(ticker: str, date: str, api_key: Optional[str] = Header(None)):
"""Export a single report as PDF."""
if not _check_api_key(api_key):
_auth_error()
content = get_report_content(ticker, date)
if not content or not content.get("report"):
raise HTTPException(status_code=404, detail="Report not found")
@ -758,12 +823,16 @@ from api.portfolio import (
# --- Watchlist ---
@app.get("/api/portfolio/watchlist")
async def list_watchlist():
async def list_watchlist(api_key: Optional[str] = Header(None)):
if not _check_api_key(api_key):
_auth_error()
return {"watchlist": get_watchlist()}
@app.post("/api/portfolio/watchlist")
async def create_watchlist_entry(body: dict):
async def create_watchlist_entry(body: dict, api_key: Optional[str] = Header(None)):
if not _check_api_key(api_key):
_auth_error()
try:
entry = add_to_watchlist(body["ticker"], body.get("name", body["ticker"]))
return entry
@ -772,7 +841,9 @@ async def create_watchlist_entry(body: dict):
@app.delete("/api/portfolio/watchlist/{ticker}")
async def delete_watchlist_entry(ticker: str):
async def delete_watchlist_entry(ticker: str, api_key: Optional[str] = Header(None)):
if not _check_api_key(api_key):
_auth_error()
if remove_from_watchlist(ticker):
return {"ok": True}
raise HTTPException(status_code=404, detail="Ticker not found in watchlist")
@ -781,13 +852,17 @@ async def delete_watchlist_entry(ticker: str):
# --- Accounts ---
@app.get("/api/portfolio/accounts")
async def list_accounts():
async def list_accounts(api_key: Optional[str] = Header(None)):
if not _check_api_key(api_key):
_auth_error()
accounts = get_accounts()
return {"accounts": list(accounts.get("accounts", {}).keys())}
@app.post("/api/portfolio/accounts")
async def create_account_endpoint(body: dict):
async def create_account_endpoint(body: dict, api_key: Optional[str] = Header(None)):
if not _check_api_key(api_key):
_auth_error()
try:
return create_account(body["account_name"])
except ValueError as e:
@ -795,7 +870,9 @@ async def create_account_endpoint(body: dict):
@app.delete("/api/portfolio/accounts/{account_name}")
async def delete_account_endpoint(account_name: str):
async def delete_account_endpoint(account_name: str, api_key: Optional[str] = Header(None)):
if not _check_api_key(api_key):
_auth_error()
if delete_account(account_name):
return {"ok": True}
raise HTTPException(status_code=404, detail="Account not found")
@ -804,12 +881,16 @@ async def delete_account_endpoint(account_name: str):
# --- Positions ---
@app.get("/api/portfolio/positions")
async def list_positions(account: Optional[str] = Query(None)):
async def list_positions(account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
if not _check_api_key(api_key):
_auth_error()
return {"positions": get_positions(account)}
@app.post("/api/portfolio/positions")
async def create_position(body: dict):
async def create_position(body: dict, api_key: Optional[str] = Header(None)):
if not _check_api_key(api_key):
_auth_error()
try:
pos = add_position(
ticker=body["ticker"],
@ -825,7 +906,9 @@ async def create_position(body: dict):
@app.delete("/api/portfolio/positions/{ticker}")
async def delete_position(ticker: str, position_id: Optional[str] = Query(None), account: Optional[str] = Query(None)):
async def delete_position(ticker: str, position_id: Optional[str] = Query(None), account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
if not _check_api_key(api_key):
_auth_error()
removed = remove_position(ticker, position_id or "", account)
if removed:
return {"ok": True}
@ -833,7 +916,9 @@ async def delete_position(ticker: str, position_id: Optional[str] = Query(None),
@app.get("/api/portfolio/positions/export")
async def export_positions_csv(account: Optional[str] = Query(None)):
async def export_positions_csv(account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
if not _check_api_key(api_key):
_auth_error()
positions = get_positions(account)
import csv
import io
@ -848,12 +933,21 @@ async def export_positions_csv(account: Optional[str] = Query(None)):
# --- Recommendations ---
@app.get("/api/portfolio/recommendations")
async def list_recommendations(date: Optional[str] = Query(None)):
return {"recommendations": get_recommendations(date)}
async def list_recommendations(
date: Optional[str] = Query(None),
limit: int = Query(DEFAULT_PAGE_SIZE, ge=1, le=MAX_PAGE_SIZE),
offset: int = Query(0, ge=0),
api_key: Optional[str] = Header(None),
):
if not _check_api_key(api_key):
_auth_error()
return get_recommendations(date, limit, offset)
@app.get("/api/portfolio/recommendations/{date}/{ticker}")
async def get_recommendation_endpoint(date: str, ticker: str):
async def get_recommendation_endpoint(date: str, ticker: str, api_key: Optional[str] = Header(None)):
if not _check_api_key(api_key):
_auth_error()
rec = get_recommendation(date, ticker)
if not rec:
raise HTTPException(status_code=404, detail="Recommendation not found")
@ -863,11 +957,13 @@ async def get_recommendation_endpoint(date: str, ticker: str):
# --- Batch Analysis ---
@app.post("/api/portfolio/analyze")
async def start_portfolio_analysis():
async def start_portfolio_analysis(api_key: Optional[str] = Header(None)):
"""
Trigger batch analysis for all watchlist tickers.
Runs serially, streaming progress via WebSocket (task_id prefixed with 'port_').
"""
if not _check_api_key(api_key):
_auth_error()
import uuid
date = datetime.now().strftime("%Y-%m-%d")
task_id = f"port_{date}_{uuid.uuid4().hex[:6]}"