diff --git a/web_dashboard/backend/main.py b/web_dashboard/backend/main.py index 455f8bcd..9150287d 100644 --- a/web_dashboard/backend/main.py +++ b/web_dashboard/backend/main.py @@ -23,6 +23,8 @@ from pydantic import BaseModel REPO_ROOT = Path(__file__).parent.parent.parent # Use the currently running Python interpreter ANALYSIS_PYTHON = Path(sys.executable) +# Task state persistence directory +TASK_STATUS_DIR = Path(__file__).parent / "data" / "task_status" # ============== Lifespan ============== @@ -33,6 +35,16 @@ async def lifespan(app: FastAPI): app.state.active_connections: dict[str, list[WebSocket]] = {} app.state.task_results: dict[str, dict] = {} app.state.analysis_tasks: dict[str, asyncio.Task] = {} + + # Restore persisted task states from disk + TASK_STATUS_DIR.mkdir(parents=True, exist_ok=True) + for f in TASK_STATUS_DIR.glob("*.json"): + try: + data = json.loads(f.read_text()) + app.state.task_results[data["task_id"]] = data + except Exception: + pass + yield @@ -86,7 +98,21 @@ def _load_from_cache(mode: str) -> Optional[dict]: return None -def _save_to_cache(mode: str, data: dict): +def _save_task_status(task_id: str, data: dict): + """Persist task state to disk""" + try: + TASK_STATUS_DIR.mkdir(parents=True, exist_ok=True) + (TASK_STATUS_DIR / f"{task_id}.json").write_text(json.dumps(data, ensure_ascii=False)) + except Exception: + pass + + +def _delete_task_status(task_id: str): + """Remove persisted task state from disk""" + try: + (TASK_STATUS_DIR / f"{task_id}.json").unlink(missing_ok=True) + except Exception: + pass try: CACHE_DIR.mkdir(parents=True, exist_ok=True) cache_path = _get_cache_path(mode) @@ -346,6 +372,8 @@ async def start_analysis(request: AnalysisRequest): app.state.task_results[task_id]["status"] = "failed" app.state.task_results[task_id]["error"] = error_msg + _save_task_status(task_id, app.state.task_results[task_id]) + except Exception as e: cancel_event.set() app.state.task_results[task_id]["status"] = "failed" @@ -355,6 +383,8 @@ async def start_analysis(request: AnalysisRequest): except Exception: pass + _save_task_status(task_id, app.state.task_results[task_id]) + await broadcast_progress(task_id, app.state.task_results[task_id]) task = asyncio.create_task(run_analysis()) @@ -425,6 +455,9 @@ async def cancel_task(task_id: str): except Exception: pass + # Remove persisted task state + _delete_task_status(task_id) + return {"task_id": task_id, "status": "cancelled"}