fix: add API key auth, pagination, and configurable CORS to dashboard API (#3)
Security hardening: - API key authentication via X-API-Key header on all endpoints (opt-in: set DASHBOARD_API_KEY or ANTHROPIC_API_KEY env var to enable) If no key is set, endpoints remain open (backward-compatible) - WebSocket auth via ?api_key= query parameter - CORS now configurable via CORS_ORIGINS env var (default: allow all) Pagination (all list endpoints): - GET /api/reports/list — limit/offset with total count - GET /api/portfolio/recommendations — limit/offset with total count - DEFAULT_PAGE_SIZE=50, MAX_PAGE_SIZE=500 Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
f19c1c012e
commit
1cee59dd9f
|
|
@ -279,27 +279,38 @@ def remove_position(ticker: str, position_id: str, account: Optional[str]) -> bo
|
|||
|
||||
# ============== Recommendations ==============
|
||||
|
||||
def get_recommendations(date: Optional[str] = None) -> list:
|
||||
"""List recommendations, optionally filtered by date."""
|
||||
# Pagination defaults (must match main.py constants)
|
||||
DEFAULT_PAGE_SIZE = 50
|
||||
MAX_PAGE_SIZE = 500
|
||||
|
||||
|
||||
def get_recommendations(date: Optional[str] = None, limit: int = DEFAULT_PAGE_SIZE, offset: int = 0) -> dict:
|
||||
"""List recommendations, optionally filtered by date. Returns paginated results."""
|
||||
RECOMMENDATIONS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
all_recs = []
|
||||
|
||||
if date:
|
||||
date_dir = RECOMMENDATIONS_DIR / date
|
||||
if not date_dir.exists():
|
||||
return []
|
||||
return [
|
||||
json.loads(f.read_text())
|
||||
for f in date_dir.glob("*.json")
|
||||
if f.suffix == ".json"
|
||||
]
|
||||
if date_dir.exists():
|
||||
all_recs = [
|
||||
json.loads(f.read_text())
|
||||
for f in sorted(date_dir.glob("*.json"), reverse=True)
|
||||
if f.suffix == ".json"
|
||||
]
|
||||
else:
|
||||
# Return most recent first
|
||||
all_recs = []
|
||||
for date_dir in sorted(RECOMMENDATIONS_DIR.iterdir(), reverse=True):
|
||||
if date_dir.is_dir() and date_dir.name.startswith("20"):
|
||||
for f in date_dir.glob("*.json"):
|
||||
for f in sorted(date_dir.glob("*.json"), reverse=True):
|
||||
if f.suffix == ".json":
|
||||
all_recs.append(json.loads(f.read_text()))
|
||||
return all_recs
|
||||
|
||||
total = len(all_recs)
|
||||
return {
|
||||
"recommendations": all_recs[offset : offset + limit],
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
|
||||
|
||||
def get_recommendation(date: str, ticker: str) -> Optional[dict]:
|
||||
|
|
|
|||
|
|
@ -58,9 +58,13 @@ app = FastAPI(
|
|||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# CORS: allow all if CORS_ORIGINS is not set (development), otherwise comma-separated list
|
||||
_cors_origins = os.environ.get("CORS_ORIGINS", "*")
|
||||
_cors_origins_list = ["*"] if _cors_origins == "*" else [o.strip() for o in _cors_origins.split(",")]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=_cors_origins_list,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
|
@ -83,6 +87,29 @@ MAX_RETRY_COUNT = 2
|
|||
RETRY_BASE_DELAY_SECS = 1
|
||||
MAX_CONCURRENT_YFINANCE = 5
|
||||
|
||||
# Pagination defaults
|
||||
DEFAULT_PAGE_SIZE = 50
|
||||
MAX_PAGE_SIZE = 500
|
||||
|
||||
# Auth — set DASHBOARD_API_KEY env var to enable API key authentication
|
||||
_api_key: Optional[str] = None
|
||||
|
||||
def _get_api_key() -> Optional[str]:
|
||||
global _api_key
|
||||
if _api_key is None:
|
||||
_api_key = os.environ.get("DASHBOARD_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
|
||||
return _api_key
|
||||
|
||||
def _check_api_key(api_key: Optional[str]) -> bool:
|
||||
"""Return True if no key is required, or if the provided key matches."""
|
||||
required = _get_api_key()
|
||||
if not required:
|
||||
return True
|
||||
return api_key == required
|
||||
|
||||
def _auth_error():
|
||||
raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required")
|
||||
|
||||
|
||||
def _get_cache_path(mode: str) -> Path:
|
||||
return CACHE_DIR / f"screen_{mode}.json"
|
||||
|
|
@ -147,8 +174,10 @@ def _run_sepa_screening(mode: str) -> dict:
|
|||
|
||||
|
||||
@app.get("/api/stocks/screen")
|
||||
async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query(False)):
|
||||
async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query(False), api_key: Optional[str] = Header(None)):
|
||||
"""Screen stocks using SEPA criteria with caching"""
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
if not refresh:
|
||||
cached = _load_from_cache(mode)
|
||||
if cached:
|
||||
|
|
@ -225,15 +254,19 @@ print("ANALYSIS_COMPLETE:" + signal, flush=True)
|
|||
|
||||
|
||||
@app.post("/api/analysis/start")
|
||||
async def start_analysis(request: AnalysisRequest):
|
||||
async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Header(None)):
|
||||
"""Start a new analysis task"""
|
||||
import uuid
|
||||
task_id = f"{request.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
|
||||
date = request.date or datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
# Validate API key before storing any task state
|
||||
api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
# Check dashboard API key (opt-in auth)
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
|
||||
# Validate ANTHROPIC_API_KEY for the analysis subprocess
|
||||
anthropic_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
if not anthropic_key:
|
||||
raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set")
|
||||
|
||||
# Initialize task state
|
||||
|
|
@ -410,16 +443,20 @@ async def start_analysis(request: AnalysisRequest):
|
|||
|
||||
|
||||
@app.get("/api/analysis/status/{task_id}")
|
||||
async def get_task_status(task_id: str):
|
||||
async def get_task_status(task_id: str, api_key: Optional[str] = Header(None)):
|
||||
"""Get task status"""
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
if task_id not in app.state.task_results:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return app.state.task_results[task_id]
|
||||
|
||||
|
||||
@app.get("/api/analysis/tasks")
|
||||
async def list_tasks():
|
||||
async def list_tasks(api_key: Optional[str] = Header(None)):
|
||||
"""List all tasks (active and recent)"""
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
tasks = []
|
||||
for task_id, state in app.state.task_results.items():
|
||||
tasks.append({
|
||||
|
|
@ -438,8 +475,10 @@ async def list_tasks():
|
|||
|
||||
|
||||
@app.delete("/api/analysis/cancel/{task_id}")
|
||||
async def cancel_task(task_id: str):
|
||||
async def cancel_task(task_id: str, api_key: Optional[str] = Header(None)):
|
||||
"""Cancel a running task"""
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
if task_id not in app.state.task_results:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
|
|
@ -477,7 +516,12 @@ async def cancel_task(task_id: str):
|
|||
|
||||
@app.websocket("/ws/analysis/{task_id}")
|
||||
async def websocket_analysis(websocket: WebSocket, task_id: str):
|
||||
"""WebSocket for real-time analysis progress"""
|
||||
"""WebSocket for real-time analysis progress. Auth via ?api_key= query param."""
|
||||
# Optional API key check for WebSocket
|
||||
ws_api_key = websocket.query_params.get("api_key")
|
||||
if not _check_api_key(ws_api_key):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
await websocket.accept()
|
||||
|
||||
if task_id not in app.state.active_connections:
|
||||
|
|
@ -574,12 +618,27 @@ def get_report_content(ticker: str, date: str) -> Optional[dict]:
|
|||
|
||||
|
||||
@app.get("/api/reports/list")
|
||||
async def list_reports():
|
||||
return get_reports_list()
|
||||
async def list_reports(
|
||||
limit: int = Query(DEFAULT_PAGE_SIZE, ge=1, le=MAX_PAGE_SIZE),
|
||||
offset: int = Query(0, ge=0),
|
||||
api_key: Optional[str] = Header(None),
|
||||
):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
reports = get_reports_list()
|
||||
total = len(reports)
|
||||
return {
|
||||
"reports": sorted(reports, key=lambda x: x["date"], reverse=True)[offset : offset + limit],
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/reports/{ticker}/{date}")
|
||||
async def get_report(ticker: str, date: str):
|
||||
async def get_report(ticker: str, date: str, api_key: Optional[str] = Header(None)):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
content = get_report_content(ticker, date)
|
||||
if not content:
|
||||
raise HTTPException(status_code=404, detail="Report not found")
|
||||
|
|
@ -614,8 +673,12 @@ def _extract_summary(markdown_text: str) -> str:
|
|||
|
||||
|
||||
@app.get("/api/reports/export")
|
||||
async def export_reports_csv():
|
||||
async def export_reports_csv(
|
||||
api_key: Optional[str] = Header(None),
|
||||
):
|
||||
"""Export all reports as CSV: ticker,date,decision,summary."""
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
reports = get_reports_list()
|
||||
output = io.StringIO()
|
||||
writer = csv.DictWriter(output, fieldnames=["ticker", "date", "decision", "summary"])
|
||||
|
|
@ -644,8 +707,10 @@ async def export_reports_csv():
|
|||
|
||||
|
||||
@app.get("/api/reports/{ticker}/{date}/pdf")
|
||||
async def export_report_pdf(ticker: str, date: str):
|
||||
async def export_report_pdf(ticker: str, date: str, api_key: Optional[str] = Header(None)):
|
||||
"""Export a single report as PDF."""
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
content = get_report_content(ticker, date)
|
||||
if not content or not content.get("report"):
|
||||
raise HTTPException(status_code=404, detail="Report not found")
|
||||
|
|
@ -758,12 +823,16 @@ from api.portfolio import (
|
|||
# --- Watchlist ---
|
||||
|
||||
@app.get("/api/portfolio/watchlist")
|
||||
async def list_watchlist():
|
||||
async def list_watchlist(api_key: Optional[str] = Header(None)):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
return {"watchlist": get_watchlist()}
|
||||
|
||||
|
||||
@app.post("/api/portfolio/watchlist")
|
||||
async def create_watchlist_entry(body: dict):
|
||||
async def create_watchlist_entry(body: dict, api_key: Optional[str] = Header(None)):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
try:
|
||||
entry = add_to_watchlist(body["ticker"], body.get("name", body["ticker"]))
|
||||
return entry
|
||||
|
|
@ -772,7 +841,9 @@ async def create_watchlist_entry(body: dict):
|
|||
|
||||
|
||||
@app.delete("/api/portfolio/watchlist/{ticker}")
|
||||
async def delete_watchlist_entry(ticker: str):
|
||||
async def delete_watchlist_entry(ticker: str, api_key: Optional[str] = Header(None)):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
if remove_from_watchlist(ticker):
|
||||
return {"ok": True}
|
||||
raise HTTPException(status_code=404, detail="Ticker not found in watchlist")
|
||||
|
|
@ -781,13 +852,17 @@ async def delete_watchlist_entry(ticker: str):
|
|||
# --- Accounts ---
|
||||
|
||||
@app.get("/api/portfolio/accounts")
|
||||
async def list_accounts():
|
||||
async def list_accounts(api_key: Optional[str] = Header(None)):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
accounts = get_accounts()
|
||||
return {"accounts": list(accounts.get("accounts", {}).keys())}
|
||||
|
||||
|
||||
@app.post("/api/portfolio/accounts")
|
||||
async def create_account_endpoint(body: dict):
|
||||
async def create_account_endpoint(body: dict, api_key: Optional[str] = Header(None)):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
try:
|
||||
return create_account(body["account_name"])
|
||||
except ValueError as e:
|
||||
|
|
@ -795,7 +870,9 @@ async def create_account_endpoint(body: dict):
|
|||
|
||||
|
||||
@app.delete("/api/portfolio/accounts/{account_name}")
|
||||
async def delete_account_endpoint(account_name: str):
|
||||
async def delete_account_endpoint(account_name: str, api_key: Optional[str] = Header(None)):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
if delete_account(account_name):
|
||||
return {"ok": True}
|
||||
raise HTTPException(status_code=404, detail="Account not found")
|
||||
|
|
@ -804,12 +881,16 @@ async def delete_account_endpoint(account_name: str):
|
|||
# --- Positions ---
|
||||
|
||||
@app.get("/api/portfolio/positions")
|
||||
async def list_positions(account: Optional[str] = Query(None)):
|
||||
async def list_positions(account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
return {"positions": get_positions(account)}
|
||||
|
||||
|
||||
@app.post("/api/portfolio/positions")
|
||||
async def create_position(body: dict):
|
||||
async def create_position(body: dict, api_key: Optional[str] = Header(None)):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
try:
|
||||
pos = add_position(
|
||||
ticker=body["ticker"],
|
||||
|
|
@ -825,7 +906,9 @@ async def create_position(body: dict):
|
|||
|
||||
|
||||
@app.delete("/api/portfolio/positions/{ticker}")
|
||||
async def delete_position(ticker: str, position_id: Optional[str] = Query(None), account: Optional[str] = Query(None)):
|
||||
async def delete_position(ticker: str, position_id: Optional[str] = Query(None), account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
removed = remove_position(ticker, position_id or "", account)
|
||||
if removed:
|
||||
return {"ok": True}
|
||||
|
|
@ -833,7 +916,9 @@ async def delete_position(ticker: str, position_id: Optional[str] = Query(None),
|
|||
|
||||
|
||||
@app.get("/api/portfolio/positions/export")
|
||||
async def export_positions_csv(account: Optional[str] = Query(None)):
|
||||
async def export_positions_csv(account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
positions = get_positions(account)
|
||||
import csv
|
||||
import io
|
||||
|
|
@ -848,12 +933,21 @@ async def export_positions_csv(account: Optional[str] = Query(None)):
|
|||
# --- Recommendations ---
|
||||
|
||||
@app.get("/api/portfolio/recommendations")
|
||||
async def list_recommendations(date: Optional[str] = Query(None)):
|
||||
return {"recommendations": get_recommendations(date)}
|
||||
async def list_recommendations(
|
||||
date: Optional[str] = Query(None),
|
||||
limit: int = Query(DEFAULT_PAGE_SIZE, ge=1, le=MAX_PAGE_SIZE),
|
||||
offset: int = Query(0, ge=0),
|
||||
api_key: Optional[str] = Header(None),
|
||||
):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
return get_recommendations(date, limit, offset)
|
||||
|
||||
|
||||
@app.get("/api/portfolio/recommendations/{date}/{ticker}")
|
||||
async def get_recommendation_endpoint(date: str, ticker: str):
|
||||
async def get_recommendation_endpoint(date: str, ticker: str, api_key: Optional[str] = Header(None)):
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
rec = get_recommendation(date, ticker)
|
||||
if not rec:
|
||||
raise HTTPException(status_code=404, detail="Recommendation not found")
|
||||
|
|
@ -863,11 +957,13 @@ async def get_recommendation_endpoint(date: str, ticker: str):
|
|||
# --- Batch Analysis ---
|
||||
|
||||
@app.post("/api/portfolio/analyze")
|
||||
async def start_portfolio_analysis():
|
||||
async def start_portfolio_analysis(api_key: Optional[str] = Header(None)):
|
||||
"""
|
||||
Trigger batch analysis for all watchlist tickers.
|
||||
Runs serially, streaming progress via WebSocket (task_id prefixed with 'port_').
|
||||
"""
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
import uuid
|
||||
date = datetime.now().strftime("%Y-%m-%d")
|
||||
task_id = f"port_{date}_{uuid.uuid4().hex[:6]}"
|
||||
|
|
|
|||
Loading…
Reference in New Issue