diff --git a/web_dashboard/backend/api/portfolio.py b/web_dashboard/backend/api/portfolio.py index ce23590b..12fef09e 100644 --- a/web_dashboard/backend/api/portfolio.py +++ b/web_dashboard/backend/api/portfolio.py @@ -279,27 +279,38 @@ def remove_position(ticker: str, position_id: str, account: Optional[str]) -> bo # ============== Recommendations ============== -def get_recommendations(date: Optional[str] = None) -> list: - """List recommendations, optionally filtered by date.""" +# Pagination defaults (must match main.py constants) +DEFAULT_PAGE_SIZE = 50 +MAX_PAGE_SIZE = 500 + + +def get_recommendations(date: Optional[str] = None, limit: int = DEFAULT_PAGE_SIZE, offset: int = 0) -> dict: + """List recommendations, optionally filtered by date. Returns paginated results.""" RECOMMENDATIONS_DIR.mkdir(parents=True, exist_ok=True) + all_recs = [] + if date: date_dir = RECOMMENDATIONS_DIR / date - if not date_dir.exists(): - return [] - return [ - json.loads(f.read_text()) - for f in date_dir.glob("*.json") - if f.suffix == ".json" - ] + if date_dir.exists(): + all_recs = [ + json.loads(f.read_text()) + for f in sorted(date_dir.glob("*.json"), reverse=True) + if f.suffix == ".json" + ] else: - # Return most recent first - all_recs = [] for date_dir in sorted(RECOMMENDATIONS_DIR.iterdir(), reverse=True): if date_dir.is_dir() and date_dir.name.startswith("20"): - for f in date_dir.glob("*.json"): + for f in sorted(date_dir.glob("*.json"), reverse=True): if f.suffix == ".json": all_recs.append(json.loads(f.read_text())) - return all_recs + + total = len(all_recs) + return { + "recommendations": all_recs[offset : offset + limit], + "total": total, + "limit": limit, + "offset": offset, + } def get_recommendation(date: str, ticker: str) -> Optional[dict]: diff --git a/web_dashboard/backend/main.py b/web_dashboard/backend/main.py index bb4b054f..f15684c5 100644 --- a/web_dashboard/backend/main.py +++ b/web_dashboard/backend/main.py @@ -58,9 +58,13 @@ app = FastAPI( lifespan=lifespan ) +# CORS: allow all if CORS_ORIGINS is not set (development), otherwise comma-separated list +_cors_origins = os.environ.get("CORS_ORIGINS", "*") +_cors_origins_list = ["*"] if _cors_origins == "*" else [o.strip() for o in _cors_origins.split(",")] + app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=_cors_origins_list, allow_methods=["*"], allow_headers=["*"], ) @@ -83,6 +87,29 @@ MAX_RETRY_COUNT = 2 RETRY_BASE_DELAY_SECS = 1 MAX_CONCURRENT_YFINANCE = 5 +# Pagination defaults +DEFAULT_PAGE_SIZE = 50 +MAX_PAGE_SIZE = 500 + +# Auth — set DASHBOARD_API_KEY env var to enable API key authentication +_api_key: Optional[str] = None + +def _get_api_key() -> Optional[str]: + global _api_key + if _api_key is None: + _api_key = os.environ.get("DASHBOARD_API_KEY") or os.environ.get("ANTHROPIC_API_KEY") + return _api_key + +def _check_api_key(api_key: Optional[str]) -> bool: + """Return True if no key is required, or if the provided key matches.""" + required = _get_api_key() + if not required: + return True + return api_key == required + +def _auth_error(): + raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required") + def _get_cache_path(mode: str) -> Path: return CACHE_DIR / f"screen_{mode}.json" @@ -147,8 +174,10 @@ def _run_sepa_screening(mode: str) -> dict: @app.get("/api/stocks/screen") -async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query(False)): +async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query(False), api_key: Optional[str] = Header(None)): """Screen stocks using SEPA criteria with caching""" + if not _check_api_key(api_key): + _auth_error() if not refresh: cached = _load_from_cache(mode) if cached: @@ -225,15 +254,19 @@ print("ANALYSIS_COMPLETE:" + signal, flush=True) @app.post("/api/analysis/start") -async def start_analysis(request: AnalysisRequest): +async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Header(None)): """Start a new analysis task""" import uuid task_id = f"{request.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" date = request.date or datetime.now().strftime("%Y-%m-%d") - # Validate API key before storing any task state - api_key = os.environ.get("ANTHROPIC_API_KEY") - if not api_key: + # Check dashboard API key (opt-in auth) + if not _check_api_key(api_key): + _auth_error() + + # Validate ANTHROPIC_API_KEY for the analysis subprocess + anthropic_key = os.environ.get("ANTHROPIC_API_KEY") + if not anthropic_key: raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set") # Initialize task state @@ -410,16 +443,20 @@ async def start_analysis(request: AnalysisRequest): @app.get("/api/analysis/status/{task_id}") -async def get_task_status(task_id: str): +async def get_task_status(task_id: str, api_key: Optional[str] = Header(None)): """Get task status""" + if not _check_api_key(api_key): + _auth_error() if task_id not in app.state.task_results: raise HTTPException(status_code=404, detail="Task not found") return app.state.task_results[task_id] @app.get("/api/analysis/tasks") -async def list_tasks(): +async def list_tasks(api_key: Optional[str] = Header(None)): """List all tasks (active and recent)""" + if not _check_api_key(api_key): + _auth_error() tasks = [] for task_id, state in app.state.task_results.items(): tasks.append({ @@ -438,8 +475,10 @@ async def list_tasks(): @app.delete("/api/analysis/cancel/{task_id}") -async def cancel_task(task_id: str): +async def cancel_task(task_id: str, api_key: Optional[str] = Header(None)): """Cancel a running task""" + if not _check_api_key(api_key): + _auth_error() if task_id not in app.state.task_results: raise HTTPException(status_code=404, detail="Task not found") @@ -477,7 +516,12 @@ async def cancel_task(task_id: str): @app.websocket("/ws/analysis/{task_id}") async def websocket_analysis(websocket: WebSocket, task_id: str): - """WebSocket for real-time analysis progress""" + """WebSocket for real-time analysis progress. Auth via ?api_key= query param.""" + # Optional API key check for WebSocket + ws_api_key = websocket.query_params.get("api_key") + if not _check_api_key(ws_api_key): + await websocket.close(code=4001, reason="Unauthorized") + return await websocket.accept() if task_id not in app.state.active_connections: @@ -574,12 +618,27 @@ def get_report_content(ticker: str, date: str) -> Optional[dict]: @app.get("/api/reports/list") -async def list_reports(): - return get_reports_list() +async def list_reports( + limit: int = Query(DEFAULT_PAGE_SIZE, ge=1, le=MAX_PAGE_SIZE), + offset: int = Query(0, ge=0), + api_key: Optional[str] = Header(None), +): + if not _check_api_key(api_key): + _auth_error() + reports = get_reports_list() + total = len(reports) + return { + "reports": sorted(reports, key=lambda x: x["date"], reverse=True)[offset : offset + limit], + "total": total, + "limit": limit, + "offset": offset, + } @app.get("/api/reports/{ticker}/{date}") -async def get_report(ticker: str, date: str): +async def get_report(ticker: str, date: str, api_key: Optional[str] = Header(None)): + if not _check_api_key(api_key): + _auth_error() content = get_report_content(ticker, date) if not content: raise HTTPException(status_code=404, detail="Report not found") @@ -614,8 +673,12 @@ def _extract_summary(markdown_text: str) -> str: @app.get("/api/reports/export") -async def export_reports_csv(): +async def export_reports_csv( + api_key: Optional[str] = Header(None), +): """Export all reports as CSV: ticker,date,decision,summary.""" + if not _check_api_key(api_key): + _auth_error() reports = get_reports_list() output = io.StringIO() writer = csv.DictWriter(output, fieldnames=["ticker", "date", "decision", "summary"]) @@ -644,8 +707,10 @@ async def export_reports_csv(): @app.get("/api/reports/{ticker}/{date}/pdf") -async def export_report_pdf(ticker: str, date: str): +async def export_report_pdf(ticker: str, date: str, api_key: Optional[str] = Header(None)): """Export a single report as PDF.""" + if not _check_api_key(api_key): + _auth_error() content = get_report_content(ticker, date) if not content or not content.get("report"): raise HTTPException(status_code=404, detail="Report not found") @@ -758,12 +823,16 @@ from api.portfolio import ( # --- Watchlist --- @app.get("/api/portfolio/watchlist") -async def list_watchlist(): +async def list_watchlist(api_key: Optional[str] = Header(None)): + if not _check_api_key(api_key): + _auth_error() return {"watchlist": get_watchlist()} @app.post("/api/portfolio/watchlist") -async def create_watchlist_entry(body: dict): +async def create_watchlist_entry(body: dict, api_key: Optional[str] = Header(None)): + if not _check_api_key(api_key): + _auth_error() try: entry = add_to_watchlist(body["ticker"], body.get("name", body["ticker"])) return entry @@ -772,7 +841,9 @@ async def create_watchlist_entry(body: dict): @app.delete("/api/portfolio/watchlist/{ticker}") -async def delete_watchlist_entry(ticker: str): +async def delete_watchlist_entry(ticker: str, api_key: Optional[str] = Header(None)): + if not _check_api_key(api_key): + _auth_error() if remove_from_watchlist(ticker): return {"ok": True} raise HTTPException(status_code=404, detail="Ticker not found in watchlist") @@ -781,13 +852,17 @@ async def delete_watchlist_entry(ticker: str): # --- Accounts --- @app.get("/api/portfolio/accounts") -async def list_accounts(): +async def list_accounts(api_key: Optional[str] = Header(None)): + if not _check_api_key(api_key): + _auth_error() accounts = get_accounts() return {"accounts": list(accounts.get("accounts", {}).keys())} @app.post("/api/portfolio/accounts") -async def create_account_endpoint(body: dict): +async def create_account_endpoint(body: dict, api_key: Optional[str] = Header(None)): + if not _check_api_key(api_key): + _auth_error() try: return create_account(body["account_name"]) except ValueError as e: @@ -795,7 +870,9 @@ async def create_account_endpoint(body: dict): @app.delete("/api/portfolio/accounts/{account_name}") -async def delete_account_endpoint(account_name: str): +async def delete_account_endpoint(account_name: str, api_key: Optional[str] = Header(None)): + if not _check_api_key(api_key): + _auth_error() if delete_account(account_name): return {"ok": True} raise HTTPException(status_code=404, detail="Account not found") @@ -804,12 +881,16 @@ async def delete_account_endpoint(account_name: str): # --- Positions --- @app.get("/api/portfolio/positions") -async def list_positions(account: Optional[str] = Query(None)): +async def list_positions(account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)): + if not _check_api_key(api_key): + _auth_error() return {"positions": get_positions(account)} @app.post("/api/portfolio/positions") -async def create_position(body: dict): +async def create_position(body: dict, api_key: Optional[str] = Header(None)): + if not _check_api_key(api_key): + _auth_error() try: pos = add_position( ticker=body["ticker"], @@ -825,7 +906,9 @@ async def create_position(body: dict): @app.delete("/api/portfolio/positions/{ticker}") -async def delete_position(ticker: str, position_id: Optional[str] = Query(None), account: Optional[str] = Query(None)): +async def delete_position(ticker: str, position_id: Optional[str] = Query(None), account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)): + if not _check_api_key(api_key): + _auth_error() removed = remove_position(ticker, position_id or "", account) if removed: return {"ok": True} @@ -833,7 +916,9 @@ async def delete_position(ticker: str, position_id: Optional[str] = Query(None), @app.get("/api/portfolio/positions/export") -async def export_positions_csv(account: Optional[str] = Query(None)): +async def export_positions_csv(account: Optional[str] = Query(None), api_key: Optional[str] = Header(None)): + if not _check_api_key(api_key): + _auth_error() positions = get_positions(account) import csv import io @@ -848,12 +933,21 @@ async def export_positions_csv(account: Optional[str] = Query(None)): # --- Recommendations --- @app.get("/api/portfolio/recommendations") -async def list_recommendations(date: Optional[str] = Query(None)): - return {"recommendations": get_recommendations(date)} +async def list_recommendations( + date: Optional[str] = Query(None), + limit: int = Query(DEFAULT_PAGE_SIZE, ge=1, le=MAX_PAGE_SIZE), + offset: int = Query(0, ge=0), + api_key: Optional[str] = Header(None), +): + if not _check_api_key(api_key): + _auth_error() + return get_recommendations(date, limit, offset) @app.get("/api/portfolio/recommendations/{date}/{ticker}") -async def get_recommendation_endpoint(date: str, ticker: str): +async def get_recommendation_endpoint(date: str, ticker: str, api_key: Optional[str] = Header(None)): + if not _check_api_key(api_key): + _auth_error() rec = get_recommendation(date, ticker) if not rec: raise HTTPException(status_code=404, detail="Recommendation not found") @@ -863,11 +957,13 @@ async def get_recommendation_endpoint(date: str, ticker: str): # --- Batch Analysis --- @app.post("/api/portfolio/analyze") -async def start_portfolio_analysis(): +async def start_portfolio_analysis(api_key: Optional[str] = Header(None)): """ Trigger batch analysis for all watchlist tickers. Runs serially, streaming progress via WebSocket (task_id prefixed with 'port_'). """ + if not _check_api_key(api_key): + _auth_error() import uuid date = datetime.now().strftime("%Y-%m-%d") task_id = f"port_{date}_{uuid.uuid4().hex[:6]}"