fix(dashboard): address 4 critical issues found in pre-landing review

1. main.py: move API key validation before task state creation —
   prevents phantom "running" tasks when ANTHROPIC_API_KEY is missing
2. portfolio.py: make get_positions() async and fetch yfinance prices
   concurrently via run_in_executor — no longer blocks event loop
3. portfolio.py: add fcntl.LOCK_EX around all JSON read-modify-write
   operations on watchlist.json and positions.json — eliminates TOCTOU
   lost-write races under concurrent requests
4. main.py: use tempfile.mkstemp with mode 0o600 instead of world-
   readable /tmp/analysis_{task_id}.py — script content no longer
   exposed to other users on shared hosts

Also: remove unused UploadFile/File imports, undefined _save_to_cache
function, dead code in _delete_task_status, and unused
get_or_create_default_account helper.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
陈少杰 2026-04-07 17:27:49 +08:00
parent 17a4ed2513
commit d12c34c333
2 changed files with 179 additions and 114 deletions

View File

@ -2,6 +2,7 @@
Portfolio API 自选股持仓每日建议
"""
import asyncio
import fcntl
import json
import uuid
from datetime import datetime
@ -17,6 +18,8 @@ DATA_DIR.mkdir(parents=True, exist_ok=True)
WATCHLIST_FILE = DATA_DIR / "watchlist.json"
POSITIONS_FILE = DATA_DIR / "positions.json"
RECOMMENDATIONS_DIR = DATA_DIR / "recommendations"
WATCHLIST_LOCK = DATA_DIR / "watchlist.lock"
POSITIONS_LOCK = DATA_DIR / "positions.lock"
# ============== Watchlist ==============
@ -25,36 +28,56 @@ def get_watchlist() -> list:
if not WATCHLIST_FILE.exists():
return []
try:
return json.loads(WATCHLIST_FILE.read_text()).get("watchlist", [])
with open(WATCHLIST_LOCK, "w") as lf:
fcntl.flock(lf.fileno(), fcntl.LOCK_SH)
try:
return json.loads(WATCHLIST_FILE.read_text()).get("watchlist", [])
finally:
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
except Exception:
return []
def _save_watchlist(watchlist: list):
WATCHLIST_FILE.write_text(json.dumps({"watchlist": watchlist}, ensure_ascii=False, indent=2))
with open(WATCHLIST_LOCK, "w") as lf:
fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
try:
WATCHLIST_FILE.write_text(json.dumps({"watchlist": watchlist}, ensure_ascii=False, indent=2))
finally:
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
def add_to_watchlist(ticker: str, name: str) -> dict:
watchlist = get_watchlist()
if any(s["ticker"] == ticker for s in watchlist):
raise ValueError(f"{ticker} 已在自选股中")
entry = {
"ticker": ticker,
"name": name,
"added_at": datetime.now().strftime("%Y-%m-%d"),
}
watchlist.append(entry)
_save_watchlist(watchlist)
return entry
with open(WATCHLIST_LOCK, "w") as lf:
fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
try:
watchlist = json.loads(WATCHLIST_FILE.read_text()).get("watchlist", []) if WATCHLIST_FILE.exists() else []
if any(s["ticker"] == ticker for s in watchlist):
raise ValueError(f"{ticker} 已在自选股中")
entry = {
"ticker": ticker,
"name": name,
"added_at": datetime.now().strftime("%Y-%m-%d"),
}
watchlist.append(entry)
WATCHLIST_FILE.write_text(json.dumps({"watchlist": watchlist}, ensure_ascii=False, indent=2))
return entry
finally:
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
def remove_from_watchlist(ticker: str) -> bool:
watchlist = get_watchlist()
new_list = [s for s in watchlist if s["ticker"] != ticker]
if len(new_list) == len(watchlist):
return False
_save_watchlist(new_list)
return True
with open(WATCHLIST_LOCK, "w") as lf:
fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
try:
watchlist = json.loads(WATCHLIST_FILE.read_text()).get("watchlist", []) if WATCHLIST_FILE.exists() else []
new_list = [s for s in watchlist if s["ticker"] != ticker]
if len(new_list) == len(watchlist):
return False
WATCHLIST_FILE.write_text(json.dumps({"watchlist": new_list}, ensure_ascii=False, indent=2))
return True
finally:
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
# ============== Accounts ==============
@ -63,47 +86,71 @@ def get_accounts() -> dict:
if not POSITIONS_FILE.exists():
return {"accounts": {}}
try:
return json.loads(POSITIONS_FILE.read_text())
with open(POSITIONS_LOCK, "w") as lf:
fcntl.flock(lf.fileno(), fcntl.LOCK_SH)
try:
return json.loads(POSITIONS_FILE.read_text())
finally:
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
except Exception:
return {"accounts": {}}
def _save_accounts(data: dict):
POSITIONS_FILE.write_text(json.dumps(data, ensure_ascii=False, indent=2))
def get_or_create_default_account(accounts: dict) -> dict:
if "默认账户" not in accounts.get("accounts", {}):
accounts["accounts"]["默认账户"] = {"positions": {}}
return accounts["accounts"]["默认账户"]
with open(POSITIONS_LOCK, "w") as lf:
fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
try:
POSITIONS_FILE.write_text(json.dumps(data, ensure_ascii=False, indent=2))
finally:
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
def create_account(account_name: str) -> dict:
accounts = get_accounts()
if account_name in accounts.get("accounts", {}):
raise ValueError(f"账户 {account_name} 已存在")
accounts["accounts"][account_name] = {"positions": {}}
_save_accounts(accounts)
return {"account_name": account_name}
with open(POSITIONS_LOCK, "w") as lf:
fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
try:
accounts = json.loads(POSITIONS_FILE.read_text()) if POSITIONS_FILE.exists() else {"accounts": {}}
if account_name in accounts.get("accounts", {}):
raise ValueError(f"账户 {account_name} 已存在")
accounts["accounts"][account_name] = {"positions": {}}
POSITIONS_FILE.write_text(json.dumps(accounts, ensure_ascii=False, indent=2))
return {"account_name": account_name}
finally:
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
def delete_account(account_name: str) -> bool:
accounts = get_accounts()
if account_name not in accounts.get("accounts", {}):
return False
del accounts["accounts"][account_name]
_save_accounts(accounts)
return True
with open(POSITIONS_LOCK, "w") as lf:
fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
try:
accounts = json.loads(POSITIONS_FILE.read_text()) if POSITIONS_FILE.exists() else {"accounts": {}}
if account_name not in accounts.get("accounts", {}):
return False
del accounts["accounts"][account_name]
POSITIONS_FILE.write_text(json.dumps(accounts, ensure_ascii=False, indent=2))
return True
finally:
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
# ============== Positions ==============
# ============== Positions =============
def get_positions(account: Optional[str] = None) -> list:
def _fetch_price(ticker: str) -> float | None:
"""Fetch current price synchronously (called in thread executor)"""
try:
stock = yfinance.Ticker(ticker)
info = stock.info or {}
return info.get("currentPrice") or info.get("regularMarketPrice")
except Exception:
return None
async def get_positions(account: Optional[str] = None) -> list:
"""
Returns positions with live price from yfinance and computed P&L.
Uses asyncio executor to avoid blocking the event loop on yfinance HTTP calls.
"""
accounts = get_accounts()
result = []
if account:
acc = accounts.get("accounts", {}).get(account)
@ -119,22 +166,22 @@ def get_positions(account: Optional[str] = None) -> list:
for _pos in _positions
]
for ticker, pos in positions:
try:
stock = yfinance.Ticker(ticker)
info = stock.info or {}
current_price = info.get("currentPrice") or info.get("regularMarketPrice")
except Exception:
current_price = None
if not positions:
return []
loop = asyncio.get_event_loop()
tickers = [t for t, _ in positions]
prices = await asyncio.gather(*[loop.run_in_executor(None, _fetch_price, t) for t in tickers])
result = []
for (ticker, pos), current_price in zip(positions, prices):
shares = pos.get("shares", 0)
cost_price = pos.get("cost_price", 0)
unrealized_pnl = None
unrealized_pnl_pct = None
if current_price is not None and cost_price:
unrealized_pnl = (current_price - cost_price) * shares
unrealized_pnl_pct = (current_price / cost_price - 1) * 100
else:
unrealized_pnl = None
unrealized_pnl_pct = None
result.append({
"ticker": ticker,
@ -154,56 +201,68 @@ def get_positions(account: Optional[str] = None) -> list:
def add_position(ticker: str, shares: float, cost_price: float,
purchase_date: Optional[str], notes: str, account: str) -> dict:
accounts = get_accounts()
acc = accounts.get("accounts", {}).get(account)
if not acc:
acc = get_or_create_default_account(accounts)
with open(POSITIONS_LOCK, "w") as lf:
fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
try:
accounts = json.loads(POSITIONS_FILE.read_text()) if POSITIONS_FILE.exists() else {"accounts": {}}
acc = accounts.get("accounts", {}).get(account)
if not acc:
if "默认账户" not in accounts.get("accounts", {}):
accounts["accounts"]["默认账户"] = {"positions": {}}
acc = accounts["accounts"]["默认账户"]
position_id = f"pos_{uuid.uuid4().hex[:6]}"
position = {
"position_id": position_id,
"shares": shares,
"cost_price": cost_price,
"purchase_date": purchase_date,
"notes": notes,
"account": account,
"name": ticker,
}
position_id = f"pos_{uuid.uuid4().hex[:6]}"
position = {
"position_id": position_id,
"shares": shares,
"cost_price": cost_price,
"purchase_date": purchase_date,
"notes": notes,
"account": account,
"name": ticker,
}
if ticker not in acc["positions"]:
acc["positions"][ticker] = []
acc["positions"][ticker].append(position)
_save_accounts(accounts)
return position
if ticker not in acc["positions"]:
acc["positions"][ticker] = []
acc["positions"][ticker].append(position)
POSITIONS_FILE.write_text(json.dumps(accounts, ensure_ascii=False, indent=2))
return position
finally:
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
def remove_position(ticker: str, position_id: str, account: Optional[str]) -> bool:
accounts = get_accounts()
if account:
acc = accounts.get("accounts", {}).get(account)
if acc and ticker in acc.get("positions", {}):
acc["positions"][ticker] = [
p for p in acc["positions"][ticker]
if p.get("position_id") != position_id
]
if not acc["positions"][ticker]:
del acc["positions"][ticker]
_save_accounts(accounts)
return True
else:
for acc_data in accounts.get("accounts", {}).values():
if ticker in acc_data.get("positions", {}):
original_len = len(acc_data["positions"][ticker])
acc_data["positions"][ticker] = [
p for p in acc_data["positions"][ticker]
if p.get("position_id") != position_id
]
if len(acc_data["positions"][ticker]) < original_len:
if not acc_data["positions"][ticker]:
del acc_data["positions"][ticker]
_save_accounts(accounts)
with open(POSITIONS_LOCK, "w") as lf:
fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
try:
accounts = json.loads(POSITIONS_FILE.read_text()) if POSITIONS_FILE.exists() else {"accounts": {}}
if account:
acc = accounts.get("accounts", {}).get(account)
if acc and ticker in acc.get("positions", {}):
acc["positions"][ticker] = [
p for p in acc["positions"][ticker]
if p.get("position_id") != position_id
]
if not acc["positions"][ticker]:
del acc["positions"][ticker]
POSITIONS_FILE.write_text(json.dumps(accounts, ensure_ascii=False, indent=2))
return True
return False
else:
for acc_data in accounts.get("accounts", {}).values():
if ticker in acc_data.get("positions", {}):
original_len = len(acc_data["positions"][ticker])
acc_data["positions"][ticker] = [
p for p in acc_data["positions"][ticker]
if p.get("position_id") != position_id
]
if len(acc_data["positions"][ticker]) < original_len:
if not acc_data["positions"][ticker]:
del acc_data["positions"][ticker]
POSITIONS_FILE.write_text(json.dumps(accounts, ensure_ascii=False, indent=2))
return True
return False
finally:
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
# ============== Recommendations ==============

View File

@ -8,6 +8,7 @@ import json
import os
import subprocess
import sys
import tempfile
import time
import traceback
from datetime import datetime
@ -15,7 +16,7 @@ from pathlib import Path
from typing import Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Query, UploadFile, File
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi.responses import Response
@ -60,7 +61,6 @@ app = FastAPI(
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@ -99,6 +99,17 @@ def _load_from_cache(mode: str) -> Optional[dict]:
return None
def _save_to_cache(mode: str, data: dict):
"""Save screening result to cache"""
try:
CACHE_DIR.mkdir(parents=True, exist_ok=True)
cache_path = _get_cache_path(mode)
with open(cache_path, "w") as f:
json.dump(data, f)
except Exception:
pass
def _save_task_status(task_id: str, data: dict):
"""Persist task state to disk"""
try:
@ -114,13 +125,6 @@ def _delete_task_status(task_id: str):
(TASK_STATUS_DIR / f"{task_id}.json").unlink(missing_ok=True)
except Exception:
pass
try:
CACHE_DIR.mkdir(parents=True, exist_ok=True)
cache_path = _get_cache_path(mode)
with open(cache_path, "w") as f:
json.dump(data, f)
except Exception:
pass
# ============== SEPA Screening ==============
@ -224,6 +228,11 @@ async def start_analysis(request: AnalysisRequest):
task_id = f"{request.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
date = request.date or datetime.now().strftime("%Y-%m-%d")
# Validate API key before storing any task state
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set")
# Initialize task state
app.state.task_results[task_id] = {
"task_id": task_id,
@ -245,17 +254,14 @@ async def start_analysis(request: AnalysisRequest):
"decision": None,
"error": None,
}
# Get API key - fail fast before storing a running task
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set")
await broadcast_progress(task_id, app.state.task_results[task_id])
# Write analysis script to temp file (avoids subprocess -c quoting issues)
script_path = Path(f"/tmp/analysis_{task_id}.py")
script_content = ANALYSIS_SCRIPT_TEMPLATE
script_path.write_text(script_content)
# Write analysis script to temp file with restrictive permissions (avoids subprocess -c quoting issues)
fd, script_path_str = tempfile.mkstemp(suffix=".py", prefix=f"analysis_{task_id}_")
script_path = Path(script_path_str)
os.chmod(script_path, 0o600)
with os.fdopen(fd, "w") as f:
f.write(ANALYSIS_SCRIPT_TEMPLATE)
# Store process reference for cancellation
app.state.processes = getattr(app.state, 'processes', {})