diff --git a/web_dashboard/backend/main.py b/web_dashboard/backend/main.py index 283e76cd..36f7a023 100644 --- a/web_dashboard/backend/main.py +++ b/web_dashboard/backend/main.py @@ -6,11 +6,8 @@ import asyncio import hmac import json import os -import subprocess import sys -import tempfile import time -import traceback from datetime import datetime from pathlib import Path from typing import Optional @@ -23,6 +20,7 @@ from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from services import AnalysisService, JobService, ResultStore, build_request_context, load_migration_flags +from services.executor import LegacySubprocessAnalysisExecutor # Path to TradingAgents repo root REPO_ROOT = Path(__file__).parent.parent.parent @@ -54,10 +52,12 @@ async def lifespan(app: FastAPI): delete_task=app.state.result_store.delete_task_status, ) app.state.analysis_service = AnalysisService( - analysis_python=ANALYSIS_PYTHON, - repo_root=REPO_ROOT, - analysis_script_template=ANALYSIS_SCRIPT_TEMPLATE, - api_key_resolver=_get_analysis_api_key, + executor=LegacySubprocessAnalysisExecutor( + analysis_python=ANALYSIS_PYTHON, + repo_root=REPO_ROOT, + api_key_resolver=_get_analysis_api_key, + process_registry=app.state.job_service.register_process, + ), result_store=app.state.result_store, job_service=app.state.job_service, retry_count=MAX_RETRY_COUNT, @@ -229,23 +229,6 @@ def _save_to_cache(mode: str, data: dict): pass -def _save_task_status(task_id: str, data: dict): - """Persist task state to disk""" - try: - TASK_STATUS_DIR.mkdir(parents=True, exist_ok=True) - (TASK_STATUS_DIR / f"{task_id}.json").write_text(json.dumps(data, ensure_ascii=False)) - except Exception: - pass - - -def _delete_task_status(task_id: str): - """Remove persisted task state from disk""" - try: - (TASK_STATUS_DIR / f"{task_id}.json").unlink(missing_ok=True) - except Exception: - pass - - # ============== SEPA Screening ============== def _run_sepa_screening(mode: str) -> dict: @@ -282,288 +265,34 @@ async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query # ============== Analysis Execution ============== -# Script template for subprocess-based analysis -# api_key is passed via environment variable (not CLI) for security -ANALYSIS_SCRIPT_TEMPLATE = """ -import sys -import os -import json -ticker = sys.argv[1] -date = sys.argv[2] -repo_root = sys.argv[3] - -sys.path.insert(0, repo_root) -os.environ["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic" -import py_mini_racer -sys.modules["mini_racer"] = py_mini_racer -from pathlib import Path - -print("STAGE:analysts", flush=True) - -from orchestrator.config import OrchestratorConfig -from orchestrator.orchestrator import TradingOrchestrator - -config = OrchestratorConfig( - quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""), - trading_agents_config={ - "llm_provider": "anthropic", - "deep_think_llm": "MiniMax-M2.7-highspeed", - "quick_think_llm": "MiniMax-M2.7-highspeed", - "backend_url": "https://api.minimaxi.com/anthropic", - "max_debate_rounds": 1, - "max_risk_discuss_rounds": 1, - "project_dir": os.path.join(repo_root, "tradingagents"), - "results_dir": os.path.join(repo_root, "results"), - } -) - -print("STAGE:research", flush=True) - -orchestrator = TradingOrchestrator(config) - -print("STAGE:trading", flush=True) - -try: - result = orchestrator.get_combined_signal(ticker, date) -except ValueError as _e: - print("ANALYSIS_ERROR:" + str(_e), file=sys.stderr, flush=True) - sys.exit(1) - -print("STAGE:risk", flush=True) - -# Map direction + confidence to 5-level signal -# FinalSignal is a dataclass, access via attributes not .get() -direction = result.direction -confidence = result.confidence -llm_sig_obj = result.llm_signal -quant_sig_obj = result.quant_signal -# LLM metadata has "rating" field; quant metadata does not — derive from direction -llm_signal = llm_sig_obj.metadata.get("rating", "HOLD") if llm_sig_obj else "HOLD" -if quant_sig_obj is None: - quant_signal = "HOLD" -elif quant_sig_obj.direction == 1: - quant_signal = "BUY" if quant_sig_obj.confidence >= 0.7 else "OVERWEIGHT" -elif quant_sig_obj.direction == -1: - quant_signal = "SELL" if quant_sig_obj.confidence >= 0.7 else "UNDERWEIGHT" -else: - quant_signal = "HOLD" - -if direction == 1: - signal = "BUY" if confidence >= 0.7 else "OVERWEIGHT" -elif direction == -1: - signal = "SELL" if confidence >= 0.7 else "UNDERWEIGHT" -else: - signal = "HOLD" - -results_dir = Path(repo_root) / "results" / ticker / date -results_dir.mkdir(parents=True, exist_ok=True) - -report_content = ( - "# TradingAgents 分析报告\\n\\n" - "**股票**: " + ticker + "\\n" - "**日期**: " + date + "\\n\\n" - "## 最终决策\\n\\n" - "**" + signal + "**\\n\\n" - "## 信号详情\\n\\n" - "- LLM 信号: " + llm_signal + "\\n" - "- Quant 信号: " + quant_signal + "\\n" - "- 置信度: " + f"{confidence:.1%}" + "\\n\\n" - "## 分析摘要\\n\\n" - "N/A\\n" -) - -report_path = results_dir / "complete_report.md" -report_path.write_text(report_content) - -print("STAGE:portfolio", flush=True) -signal_detail = json.dumps({"llm_signal": llm_signal, "quant_signal": quant_signal, "confidence": confidence}) -print("SIGNAL_DETAIL:" + signal_detail, flush=True) -print("ANALYSIS_COMPLETE:" + signal, flush=True) -""" - - @app.post("/api/analysis/start") -async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Header(None)): - """Start a new analysis task""" +async def start_analysis( + payload: AnalysisRequest, + http_request: Request, + api_key: Optional[str] = Header(None), +): + """Start a new analysis task.""" import uuid - task_id = f"{request.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" - date = request.date or datetime.now().strftime("%Y-%m-%d") - # Check dashboard API key (opt-in auth) if not _check_api_key(api_key): _auth_error() - # Validate ANTHROPIC_API_KEY for the analysis subprocess - anthropic_key = _get_analysis_api_key() - if not anthropic_key: - raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set") + task_id = f"{payload.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" + date = payload.date or datetime.now().strftime("%Y-%m-%d") + request_context = build_request_context(http_request, api_key=api_key) - # Initialize task state - app.state.task_results[task_id] = { - "task_id": task_id, - "ticker": request.ticker, - "date": date, - "status": "running", - "progress": 0, - "current_stage": "analysts", - "created_at": datetime.now().isoformat(), - "elapsed": 0, - "stages": [ - {"status": "running", "completed_at": None}, - {"status": "pending", "completed_at": None}, - {"status": "pending", "completed_at": None}, - {"status": "pending", "completed_at": None}, - {"status": "pending", "completed_at": None}, - ], - "logs": [], - "decision": None, - "quant_signal": None, - "llm_signal": None, - "confidence": None, - "error": None, - } - await broadcast_progress(task_id, app.state.task_results[task_id]) - - # Write analysis script to temp file with restrictive permissions (avoids subprocess -c quoting issues) - fd, script_path_str = tempfile.mkstemp(suffix=".py", prefix=f"analysis_{task_id}_") - script_path = Path(script_path_str) - os.chmod(script_path, 0o600) - with os.fdopen(fd, "w") as f: - f.write(ANALYSIS_SCRIPT_TEMPLATE) - - # Store process reference for cancellation - app.state.processes = getattr(app.state, 'processes', {}) - app.state.processes[task_id] = None - - # Cancellation event for the monitor coroutine - cancel_event = asyncio.Event() - - # Stage name to index mapping - STAGE_NAMES = ["analysts", "research", "trading", "risk", "portfolio"] - - def _update_task_stage(stage_name: str): - """Update task state for a completed stage and mark next as running.""" - try: - idx = STAGE_NAMES.index(stage_name) - except ValueError: - return - # Mark all previous stages as completed - for i in range(idx): - if app.state.task_results[task_id]["stages"][i]["status"] != "completed": - app.state.task_results[task_id]["stages"][i]["status"] = "completed" - app.state.task_results[task_id]["stages"][i]["completed_at"] = datetime.now().strftime("%H:%M:%S") - # Mark current as completed - if app.state.task_results[task_id]["stages"][idx]["status"] != "completed": - app.state.task_results[task_id]["stages"][idx]["status"] = "completed" - app.state.task_results[task_id]["stages"][idx]["completed_at"] = datetime.now().strftime("%H:%M:%S") - # Mark next as running - if idx + 1 < 5: - if app.state.task_results[task_id]["stages"][idx + 1]["status"] == "pending": - app.state.task_results[task_id]["stages"][idx + 1]["status"] = "running" - # Update progress - app.state.task_results[task_id]["progress"] = int((idx + 1) / 5 * 100) - app.state.task_results[task_id]["current_stage"] = stage_name - - async def run_analysis(): - """Run analysis subprocess and broadcast progress""" - try: - # Use clean environment - don't inherit parent env - clean_env = {k: v for k, v in os.environ.items() - if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))} - clean_env["ANTHROPIC_API_KEY"] = anthropic_key - clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic" - - proc = await asyncio.create_subprocess_exec( - str(ANALYSIS_PYTHON), - str(script_path), - request.ticker, - date, - str(REPO_ROOT), - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - env=clean_env, - ) - app.state.processes[task_id] = proc - - # Read stdout line-by-line for real-time stage updates - stdout_lines = [] - while True: - try: - line_bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=300.0) - except asyncio.TimeoutError: - break - if not line_bytes: - break - line = line_bytes.decode(errors="replace").rstrip() - stdout_lines.append(line) - if line.startswith("STAGE:"): - stage = line.split(":", 1)[1].strip() - _update_task_stage(stage) - await broadcast_progress(task_id, app.state.task_results[task_id]) - if cancel_event.is_set(): - break - - await proc.wait() - stderr_bytes = await proc.stderr.read() - - # Clean up script file - try: - script_path.unlink() - except Exception: - pass - - if proc.returncode == 0: - output = "\n".join(stdout_lines) - decision = "HOLD" - for line in stdout_lines: - if line.startswith("SIGNAL_DETAIL:"): - try: - detail = json.loads(line.split(":", 1)[1].strip()) - app.state.task_results[task_id]["quant_signal"] = detail.get("quant_signal") - app.state.task_results[task_id]["llm_signal"] = detail.get("llm_signal") - app.state.task_results[task_id]["confidence"] = detail.get("confidence") - except Exception: - pass - if line.startswith("ANALYSIS_COMPLETE:"): - decision = line.split(":", 1)[1].strip() - - app.state.task_results[task_id]["status"] = "completed" - app.state.task_results[task_id]["progress"] = 100 - app.state.task_results[task_id]["decision"] = decision - app.state.task_results[task_id]["current_stage"] = "portfolio" - for i in range(5): - app.state.task_results[task_id]["stages"][i]["status"] = "completed" - if not app.state.task_results[task_id]["stages"][i].get("completed_at"): - app.state.task_results[task_id]["stages"][i]["completed_at"] = datetime.now().strftime("%H:%M:%S") - else: - error_msg = stderr_bytes.decode(errors="replace")[-1000:] if stderr_bytes else "Unknown error" - app.state.task_results[task_id]["status"] = "failed" - app.state.task_results[task_id]["error"] = error_msg - - _save_task_status(task_id, app.state.task_results[task_id]) - - except Exception as e: - cancel_event.set() - app.state.task_results[task_id]["status"] = "failed" - app.state.task_results[task_id]["error"] = str(e) - try: - script_path.unlink() - except Exception: - pass - - _save_task_status(task_id, app.state.task_results[task_id]) - - await broadcast_progress(task_id, app.state.task_results[task_id]) - - task = asyncio.create_task(run_analysis()) - app.state.analysis_tasks[task_id] = task - - return { - "task_id": task_id, - "ticker": request.ticker, - "date": date, - "status": "running", - } + try: + return await app.state.analysis_service.start_analysis( + task_id=task_id, + ticker=payload.ticker, + date=date, + request_context=request_context, + broadcast_progress=broadcast_progress, + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + except RuntimeError as exc: + raise HTTPException(status_code=500, detail=str(exc)) @app.get("/api/analysis/status/{task_id}") @@ -600,13 +329,12 @@ async def list_tasks(api_key: Optional[str] = Header(None)): @app.delete("/api/analysis/cancel/{task_id}") async def cancel_task(task_id: str, api_key: Optional[str] = Header(None)): - """Cancel a running task""" + """Cancel a running task.""" if not _check_api_key(api_key): _auth_error() if task_id not in app.state.task_results: raise HTTPException(status_code=404, detail="Task not found") - # Kill the subprocess if it's still running proc = app.state.processes.get(task_id) if proc and proc.returncode is None: try: @@ -614,26 +342,18 @@ async def cancel_task(task_id: str, api_key: Optional[str] = Header(None)): except Exception: pass - # Cancel the asyncio task task = app.state.analysis_tasks.get(task_id) if task: task.cancel() - app.state.task_results[task_id]["status"] = "failed" - app.state.task_results[task_id]["error"] = "用户取消" - _save_task_status(task_id, app.state.task_results[task_id]) - await broadcast_progress(task_id, app.state.task_results[task_id]) - # Clean up temp script (may use tempfile.mkstemp with random suffix) - for p in Path("/tmp").glob(f"analysis_{task_id}_*.py"): - try: - p.unlink() - except Exception: - pass + state = app.state.task_results[task_id] + state["status"] = "cancelled" + state["error"] = "用户取消" + app.state.result_store.save_task_status(task_id, state) + await broadcast_progress(task_id, state) + app.state.result_store.delete_task_status(task_id) - # Remove persisted task state - _delete_task_status(task_id) - - return {"task_id": task_id, "status": "cancelled"} + return {"contract_version": "v1alpha1", "task_id": task_id, "status": "cancelled"} # ============== WebSocket ============== @@ -1091,169 +811,31 @@ async def get_recommendation_endpoint(date: str, ticker: str, api_key: Optional[ # --- Batch Analysis --- @app.post("/api/portfolio/analyze") -async def start_portfolio_analysis(api_key: Optional[str] = Header(None)): - """ - Trigger batch analysis for all watchlist tickers. - Runs serially, streaming progress via WebSocket (task_id prefixed with 'port_'). - """ +async def start_portfolio_analysis( + http_request: Request, + api_key: Optional[str] = Header(None), +): + """Trigger batch analysis for all watchlist tickers.""" if not _check_api_key(api_key): _auth_error() + import uuid + date = datetime.now().strftime("%Y-%m-%d") task_id = f"port_{date}_{uuid.uuid4().hex[:6]}" + request_context = build_request_context(http_request, api_key=api_key) - if app.state.migration_flags.use_application_services: - request_context = build_request_context(api_key=api_key) - try: - return await app.state.analysis_service.start_portfolio_analysis( - task_id=task_id, - date=date, - request_context=request_context, - broadcast_progress=broadcast_progress, - ) - except ValueError as exc: - raise HTTPException(status_code=400, detail=str(exc)) - except RuntimeError as exc: - raise HTTPException(status_code=500, detail=str(exc)) - - watchlist = get_watchlist() - if not watchlist: - raise HTTPException(status_code=400, detail="自选股为空,请先添加股票") - - total = len(watchlist) - app.state.task_results[task_id] = { - "task_id": task_id, - "type": "portfolio", - "status": "running", - "total": total, - "completed": 0, - "failed": 0, - "current_ticker": None, - "results": [], - "error": None, - } - - api_key = _get_analysis_api_key() - if not api_key: - raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set") - - await broadcast_progress(task_id, app.state.task_results[task_id]) - - async def run_portfolio_analysis(): - max_retries = MAX_RETRY_COUNT - - async def run_single_analysis(ticker: str, stock: dict) -> tuple[bool, str, dict | None]: - """Run analysis for one ticker. Returns (success, decision, rec_or_error).""" - last_error = None - for attempt in range(max_retries + 1): - script_path = None - try: - fd, script_path_str = tempfile.mkstemp(suffix=".py", prefix=f"analysis_{task_id}_{stock['_idx']}_") - script_path = Path(script_path_str) - os.chmod(script_path, 0o600) - with os.fdopen(fd, "w") as f: - f.write(ANALYSIS_SCRIPT_TEMPLATE) - - clean_env = {k: v for k, v in os.environ.items() - if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))} - clean_env["ANTHROPIC_API_KEY"] = api_key - clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic" - - proc = await asyncio.create_subprocess_exec( - str(ANALYSIS_PYTHON), str(script_path), ticker, date, str(REPO_ROOT), - stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - env=clean_env, - ) - app.state.processes[task_id] = proc - - stdout, stderr = await proc.communicate() - - try: - script_path.unlink() - except Exception: - pass - - if proc.returncode == 0: - output = stdout.decode() - decision = "HOLD" - quant_signal = None - llm_signal = None - confidence = None - for line in output.splitlines(): - if line.startswith("SIGNAL_DETAIL:"): - try: - detail = json.loads(line.split(":", 1)[1].strip()) - quant_signal = detail.get("quant_signal") - llm_signal = detail.get("llm_signal") - confidence = detail.get("confidence") - except Exception: - pass - if line.startswith("ANALYSIS_COMPLETE:"): - decision = line.split(":", 1)[1].strip() - rec = { - "ticker": ticker, - "name": stock.get("name", ticker), - "analysis_date": date, - "decision": decision, - "quant_signal": quant_signal, - "llm_signal": llm_signal, - "confidence": confidence, - "created_at": datetime.now().isoformat(), - } - save_recommendation(date, ticker, rec) - return True, decision, rec - else: - last_error = stderr.decode()[-500:] if stderr else f"exit {proc.returncode}" - except Exception as e: - last_error = str(e) - finally: - if script_path: - try: - script_path.unlink() - except Exception: - pass - if attempt < max_retries: - await asyncio.sleep(RETRY_BASE_DELAY_SECS ** attempt) # exponential backoff: 1s, 2s - - return False, "HOLD", None - - try: - for i, stock in enumerate(watchlist): - stock["_idx"] = i # used in temp file name - ticker = stock["ticker"] - app.state.task_results[task_id]["current_ticker"] = ticker - app.state.task_results[task_id]["status"] = "running" - app.state.task_results[task_id]["completed"] = i - await broadcast_progress(task_id, app.state.task_results[task_id]) - - success, decision, rec = await run_single_analysis(ticker, stock) - if success: - app.state.task_results[task_id]["completed"] = i + 1 - app.state.task_results[task_id]["results"].append(rec) - else: - app.state.task_results[task_id]["failed"] += 1 - - await broadcast_progress(task_id, app.state.task_results[task_id]) - - app.state.task_results[task_id]["status"] = "completed" - app.state.task_results[task_id]["current_ticker"] = None - _save_task_status(task_id, app.state.task_results[task_id]) - - except Exception as e: - app.state.task_results[task_id]["status"] = "failed" - app.state.task_results[task_id]["error"] = str(e) - _save_task_status(task_id, app.state.task_results[task_id]) - - await broadcast_progress(task_id, app.state.task_results[task_id]) - - task = asyncio.create_task(run_portfolio_analysis()) - app.state.analysis_tasks[task_id] = task - - return { - "task_id": task_id, - "total": total, - "status": "running", - } + try: + return await app.state.analysis_service.start_portfolio_analysis( + task_id=task_id, + date=date, + request_context=request_context, + broadcast_progress=broadcast_progress, + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + except RuntimeError as exc: + raise HTTPException(status_code=500, detail=str(exc)) diff --git a/web_dashboard/backend/services/analysis_service.py b/web_dashboard/backend/services/analysis_service.py index 5e4fbe0a..9118e7d7 100644 --- a/web_dashboard/backend/services/analysis_service.py +++ b/web_dashboard/backend/services/analysis_service.py @@ -2,15 +2,15 @@ from __future__ import annotations import asyncio import json -import os -import tempfile +import time from datetime import datetime -from pathlib import Path from typing import Awaitable, Callable, Optional +from .executor import AnalysisExecutionOutput, AnalysisExecutor, AnalysisExecutorError from .request_context import RequestContext BroadcastFn = Callable[[str, dict], Awaitable[None]] +ANALYSIS_STAGE_NAMES = ["analysts", "research", "trading", "risk", "portfolio"] class AnalysisService: @@ -19,24 +19,56 @@ class AnalysisService: def __init__( self, *, - analysis_python: Path, - repo_root: Path, - analysis_script_template: str, - api_key_resolver: Callable[[], Optional[str]], + executor: AnalysisExecutor, result_store, job_service, retry_count: int = 2, retry_base_delay_secs: int = 1, ): - self.analysis_python = analysis_python - self.repo_root = repo_root - self.analysis_script_template = analysis_script_template - self.api_key_resolver = api_key_resolver + self.executor = executor self.result_store = result_store self.job_service = job_service self.retry_count = retry_count self.retry_base_delay_secs = retry_base_delay_secs + async def start_analysis( + self, + *, + task_id: str, + ticker: str, + date: str, + request_context: RequestContext, + broadcast_progress: BroadcastFn, + ) -> dict: + state = self.job_service.create_analysis_job( + task_id=task_id, + ticker=ticker, + date=date, + request_id=request_context.request_id, + executor_type=request_context.executor_type, + contract_version=request_context.contract_version, + ) + self.job_service.register_process(task_id, None) + await broadcast_progress(task_id, state) + + task = asyncio.create_task( + self._run_analysis( + task_id=task_id, + ticker=ticker, + date=date, + request_context=request_context, + broadcast_progress=broadcast_progress, + ) + ) + self.job_service.register_background_task(task_id, task) + return { + "contract_version": "v1alpha1", + "task_id": task_id, + "ticker": ticker, + "date": date, + "status": "running", + } + async def start_portfolio_analysis( self, *, @@ -45,16 +77,17 @@ class AnalysisService: request_context: RequestContext, broadcast_progress: BroadcastFn, ) -> dict: - del request_context # Reserved for future auditing/auth propagation. watchlist = self.result_store.get_watchlist() if not watchlist: raise ValueError("自选股为空,请先添加股票") - analysis_api_key = self.api_key_resolver() - if not analysis_api_key: - raise RuntimeError("ANTHROPIC_API_KEY environment variable not set") - - state = self.job_service.create_portfolio_job(task_id=task_id, total=len(watchlist)) + state = self.job_service.create_portfolio_job( + task_id=task_id, + total=len(watchlist), + request_id=request_context.request_id, + executor_type=request_context.executor_type, + contract_version=request_context.contract_version, + ) await broadcast_progress(task_id, state) task = asyncio.create_task( @@ -62,24 +95,111 @@ class AnalysisService: task_id=task_id, date=date, watchlist=watchlist, - analysis_api_key=analysis_api_key, + request_context=request_context, broadcast_progress=broadcast_progress, ) ) self.job_service.register_background_task(task_id, task) return { + "contract_version": "v1alpha1", "task_id": task_id, "total": len(watchlist), "status": "running", } + async def _run_analysis( + self, + *, + task_id: str, + ticker: str, + date: str, + request_context: RequestContext, + broadcast_progress: BroadcastFn, + ) -> None: + start_time = time.monotonic() + try: + output = await self.executor.execute( + task_id=task_id, + ticker=ticker, + date=date, + request_context=request_context, + on_stage=lambda stage: self._handle_analysis_stage( + task_id=task_id, + stage_name=stage, + started_at=start_time, + broadcast_progress=broadcast_progress, + ), + ) + state = self.job_service.task_results[task_id] + elapsed_seconds = int(time.monotonic() - start_time) + contract = output.to_result_contract( + task_id=task_id, + ticker=ticker, + date=date, + created_at=state["created_at"], + elapsed_seconds=elapsed_seconds, + current_stage=ANALYSIS_STAGE_NAMES[-1], + ) + result_ref = self.result_store.save_result_contract(task_id, contract) + self.job_service.complete_analysis_job( + task_id, + contract=contract, + result_ref=result_ref, + executor_type=request_context.executor_type, + ) + except AnalysisExecutorError as exc: + self._fail_analysis_state( + task_id=task_id, + message=str(exc), + started_at=start_time, + ) + except Exception as exc: + self._fail_analysis_state( + task_id=task_id, + message=str(exc), + started_at=start_time, + ) + + await broadcast_progress(task_id, self.job_service.task_results[task_id]) + + async def _handle_analysis_stage( + self, + *, + task_id: str, + stage_name: str, + started_at: float, + broadcast_progress: BroadcastFn, + ) -> None: + state = self.job_service.task_results[task_id] + try: + idx = ANALYSIS_STAGE_NAMES.index(stage_name) + except ValueError: + return + + for i, entry in enumerate(state["stages"]): + if i < idx: + if entry["status"] != "completed": + entry["status"] = "completed" + entry["completed_at"] = datetime.now().strftime("%H:%M:%S") + elif i == idx: + entry["status"] = "completed" + entry["completed_at"] = entry["completed_at"] or datetime.now().strftime("%H:%M:%S") + elif i == idx + 1 and entry["status"] == "pending": + entry["status"] = "running" + + state["progress"] = int((idx + 1) / len(ANALYSIS_STAGE_NAMES) * 100) + state["current_stage"] = stage_name + state["elapsed_seconds"] = int(time.monotonic() - started_at) + state["elapsed"] = state["elapsed_seconds"] + await broadcast_progress(task_id, state) + async def _run_portfolio_analysis( self, *, task_id: str, date: str, watchlist: list[dict], - analysis_api_key: str, + request_context: RequestContext, broadcast_progress: BroadcastFn, ) -> None: try: @@ -96,7 +216,7 @@ class AnalysisService: ticker=ticker, stock=stock, date=date, - analysis_api_key=analysis_api_key, + request_context=request_context, ) if success and rec is not None: self.job_service.append_portfolio_result(task_id, rec) @@ -118,61 +238,27 @@ class AnalysisService: ticker: str, stock: dict, date: str, - analysis_api_key: str, + request_context: RequestContext, ) -> tuple[bool, Optional[dict]]: last_error: Optional[str] = None for attempt in range(self.retry_count + 1): - script_path: Optional[Path] = None try: - fd, script_path_str = tempfile.mkstemp( - suffix=".py", - prefix=f"analysis_{task_id}_{stock['_idx']}_", + output = await self.executor.execute( + task_id=f"{task_id}_{stock['_idx']}", + ticker=ticker, + date=date, + request_context=request_context, ) - script_path = Path(script_path_str) - os.chmod(script_path, 0o600) - with os.fdopen(fd, "w") as handle: - handle.write(self.analysis_script_template) - - clean_env = { - key: value - for key, value in os.environ.items() - if not key.startswith(("PYTHON", "CONDA", "VIRTUAL")) - } - clean_env["ANTHROPIC_API_KEY"] = analysis_api_key - clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic" - - proc = await asyncio.create_subprocess_exec( - str(self.analysis_python), - str(script_path), - ticker, - date, - str(self.repo_root), - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - env=clean_env, + rec = self._build_recommendation_record( + output=output, + ticker=ticker, + stock=stock, + date=date, ) - self.job_service.register_process(task_id, proc) - stdout, stderr = await proc.communicate() - - if proc.returncode == 0: - rec = self._build_recommendation_record( - stdout=stdout.decode(), - ticker=ticker, - stock=stock, - date=date, - ) - self.result_store.save_recommendation(date, ticker, rec) - return True, rec - - last_error = stderr.decode()[-500:] if stderr else f"exit {proc.returncode}" + self.result_store.save_recommendation(date, ticker, rec) + return True, rec except Exception as exc: last_error = str(exc) - finally: - if script_path is not None: - try: - script_path.unlink() - except Exception: - pass if attempt < self.retry_count: await asyncio.sleep(self.retry_base_delay_secs ** attempt) @@ -181,23 +267,45 @@ class AnalysisService: self.job_service.task_results[task_id]["last_error"] = last_error return False, None + def _fail_analysis_state(self, *, task_id: str, message: str, started_at: float) -> None: + state = self.job_service.task_results[task_id] + state["status"] = "failed" + state["elapsed_seconds"] = int(time.monotonic() - started_at) + state["elapsed"] = state["elapsed_seconds"] + state["result"] = None + state["error"] = message + self.result_store.save_task_status(task_id, state) + @staticmethod - def _build_recommendation_record(*, stdout: str, ticker: str, stock: dict, date: str) -> dict: - decision = "HOLD" - quant_signal = None - llm_signal = None - confidence = None - for line in stdout.splitlines(): - if line.startswith("SIGNAL_DETAIL:"): - try: - detail = json.loads(line.split(":", 1)[1].strip()) - except Exception: - continue - quant_signal = detail.get("quant_signal") - llm_signal = detail.get("llm_signal") - confidence = detail.get("confidence") - if line.startswith("ANALYSIS_COMPLETE:"): - decision = line.split(":", 1)[1].strip() + def _build_recommendation_record( + *, + ticker: str, + stock: dict, + date: str, + output: AnalysisExecutionOutput | None = None, + stdout: str | None = None, + ) -> dict: + if output is not None: + decision = output.decision + quant_signal = output.quant_signal + llm_signal = output.llm_signal + confidence = output.confidence + else: + decision = "HOLD" + quant_signal = None + llm_signal = None + confidence = None + for line in (stdout or "").splitlines(): + if line.startswith("SIGNAL_DETAIL:"): + try: + detail = json.loads(line.split(":", 1)[1].strip()) + except Exception: + continue + quant_signal = detail.get("quant_signal") + llm_signal = detail.get("llm_signal") + confidence = detail.get("confidence") + if line.startswith("ANALYSIS_COMPLETE:"): + decision = line.split(":", 1)[1].strip() return { "ticker": ticker, diff --git a/web_dashboard/backend/services/executor.py b/web_dashboard/backend/services/executor.py new file mode 100644 index 00000000..18844d6d --- /dev/null +++ b/web_dashboard/backend/services/executor.py @@ -0,0 +1,381 @@ +from __future__ import annotations + +import asyncio +import json +import os +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import Awaitable, Callable, Optional, Protocol + +from .request_context import ( + CONTRACT_VERSION, + DEFAULT_EXECUTOR_TYPE, + RequestContext, +) + +StageCallback = Callable[[str], Awaitable[None]] +ProcessRegistry = Callable[[str, asyncio.subprocess.Process | None], None] + +LEGACY_ANALYSIS_SCRIPT_TEMPLATE = """ +import json +import os +import sys +from pathlib import Path + +ticker = sys.argv[1] +date = sys.argv[2] +repo_root = sys.argv[3] + +sys.path.insert(0, repo_root) + +import py_mini_racer +sys.modules["mini_racer"] = py_mini_racer + +from orchestrator.config import OrchestratorConfig +from orchestrator.orchestrator import TradingOrchestrator +from tradingagents.default_config import get_default_config + +trading_config = get_default_config() +trading_config["project_dir"] = os.path.join(repo_root, "tradingagents") +trading_config["results_dir"] = os.path.join(repo_root, "results") +trading_config["max_debate_rounds"] = 1 +trading_config["max_risk_discuss_rounds"] = 1 + +print("STAGE:analysts", flush=True) +print("STAGE:research", flush=True) + +config = OrchestratorConfig( + quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""), + trading_agents_config=trading_config, +) + +orchestrator = TradingOrchestrator(config) + +print("STAGE:trading", flush=True) + +try: + result = orchestrator.get_combined_signal(ticker, date) +except ValueError as exc: + print("ANALYSIS_ERROR:" + str(exc), file=sys.stderr, flush=True) + sys.exit(1) + +print("STAGE:risk", flush=True) + +direction = result.direction +confidence = result.confidence +llm_sig_obj = result.llm_signal +quant_sig_obj = result.quant_signal +llm_signal = llm_sig_obj.metadata.get("rating", "HOLD") if llm_sig_obj else "HOLD" +if quant_sig_obj is None: + quant_signal = "HOLD" +elif quant_sig_obj.direction == 1: + quant_signal = "BUY" if quant_sig_obj.confidence >= 0.7 else "OVERWEIGHT" +elif quant_sig_obj.direction == -1: + quant_signal = "SELL" if quant_sig_obj.confidence >= 0.7 else "UNDERWEIGHT" +else: + quant_signal = "HOLD" + +if direction == 1: + signal = "BUY" if confidence >= 0.7 else "OVERWEIGHT" +elif direction == -1: + signal = "SELL" if confidence >= 0.7 else "UNDERWEIGHT" +else: + signal = "HOLD" + +results_dir = Path(repo_root) / "results" / ticker / date +results_dir.mkdir(parents=True, exist_ok=True) + +report_content = ( + "# TradingAgents 分析报告\\n\\n" + "**股票**: " + ticker + "\\n" + "**日期**: " + date + "\\n\\n" + "## 最终决策\\n\\n" + "**" + signal + "**\\n\\n" + "## 信号详情\\n\\n" + "- LLM 信号: " + llm_signal + "\\n" + "- Quant 信号: " + quant_signal + "\\n" + "- 置信度: " + f"{confidence:.1%}" + "\\n\\n" + "## 分析摘要\\n\\n" + "N/A\\n" +) + +report_path = results_dir / "complete_report.md" +report_path.write_text(report_content) + +print("STAGE:portfolio", flush=True) +signal_detail = json.dumps({"llm_signal": llm_signal, "quant_signal": quant_signal, "confidence": confidence}) +print("SIGNAL_DETAIL:" + signal_detail, flush=True) +print("ANALYSIS_COMPLETE:" + signal, flush=True) +""" + + +def _rating_to_direction(rating: Optional[str]) -> int: + if rating in {"BUY", "OVERWEIGHT"}: + return 1 + if rating in {"SELL", "UNDERWEIGHT"}: + return -1 + return 0 + + +@dataclass(frozen=True) +class AnalysisExecutionOutput: + decision: str + quant_signal: Optional[str] + llm_signal: Optional[str] + confidence: Optional[float] + report_path: Optional[str] = None + contract_version: str = CONTRACT_VERSION + executor_type: str = DEFAULT_EXECUTOR_TYPE + + def to_result_contract( + self, + *, + task_id: str, + ticker: str, + date: str, + created_at: str, + elapsed_seconds: int, + current_stage: str = "portfolio", + ) -> dict: + return { + "contract_version": self.contract_version, + "task_id": task_id, + "ticker": ticker, + "date": date, + "status": "completed", + "progress": 100, + "current_stage": current_stage, + "created_at": created_at, + "elapsed_seconds": elapsed_seconds, + "elapsed": elapsed_seconds, + "result": { + "decision": self.decision, + "confidence": self.confidence, + "signals": { + "merged": { + "direction": _rating_to_direction(self.decision), + "rating": self.decision, + }, + "quant": { + "direction": _rating_to_direction(self.quant_signal), + "rating": self.quant_signal, + "available": self.quant_signal is not None, + }, + "llm": { + "direction": _rating_to_direction(self.llm_signal), + "rating": self.llm_signal, + "available": self.llm_signal is not None, + }, + }, + "degraded": self.quant_signal is None or self.llm_signal is None, + "report": { + "path": self.report_path, + "available": bool(self.report_path), + }, + }, + "error": None, + } + + +class AnalysisExecutorError(RuntimeError): + def __init__(self, message: str, *, code: str = "analysis_failed", retryable: bool = False): + super().__init__(message) + self.code = code + self.retryable = retryable + + +class AnalysisExecutor(Protocol): + async def execute( + self, + *, + task_id: str, + ticker: str, + date: str, + request_context: RequestContext, + on_stage: Optional[StageCallback] = None, + ) -> AnalysisExecutionOutput: ... + + +class LegacySubprocessAnalysisExecutor: + """Run the legacy dashboard analysis script behind a stable executor contract.""" + + def __init__( + self, + *, + analysis_python: Path, + repo_root: Path, + api_key_resolver: Callable[[], Optional[str]], + process_registry: Optional[ProcessRegistry] = None, + script_template: str = LEGACY_ANALYSIS_SCRIPT_TEMPLATE, + stdout_timeout_secs: float = 300.0, + ): + self.analysis_python = analysis_python + self.repo_root = repo_root + self.api_key_resolver = api_key_resolver + self.process_registry = process_registry + self.script_template = script_template + self.stdout_timeout_secs = stdout_timeout_secs + + async def execute( + self, + *, + task_id: str, + ticker: str, + date: str, + request_context: RequestContext, + on_stage: Optional[StageCallback] = None, + ) -> AnalysisExecutionOutput: + analysis_api_key = request_context.api_key or self.api_key_resolver() + if not analysis_api_key: + raise RuntimeError("ANTHROPIC_API_KEY environment variable not set") + + script_path: Optional[Path] = None + proc: asyncio.subprocess.Process | None = None + try: + fd, script_path_str = tempfile.mkstemp(suffix=".py", prefix=f"analysis_{task_id}_") + script_path = Path(script_path_str) + os.chmod(script_path, 0o600) + with os.fdopen(fd, "w", encoding="utf-8") as handle: + handle.write(self.script_template) + + clean_env = { + key: value + for key, value in os.environ.items() + if not key.startswith(("PYTHON", "CONDA", "VIRTUAL")) + } + clean_env["ANTHROPIC_API_KEY"] = analysis_api_key + + proc = await asyncio.create_subprocess_exec( + str(self.analysis_python), + str(script_path), + ticker, + date, + str(self.repo_root), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=clean_env, + ) + if self.process_registry is not None: + self.process_registry(task_id, proc) + + stdout_lines: list[str] = [] + assert proc.stdout is not None + while True: + try: + line_bytes = await asyncio.wait_for( + proc.stdout.readline(), + timeout=self.stdout_timeout_secs, + ) + except asyncio.TimeoutError as exc: + await self._terminate_process(proc) + raise AnalysisExecutorError( + f"analysis subprocess timed out after {self.stdout_timeout_secs:g}s", + retryable=True, + ) from exc + if not line_bytes: + break + line = line_bytes.decode(errors="replace").rstrip() + stdout_lines.append(line) + if on_stage is not None and line.startswith("STAGE:"): + await on_stage(line.split(":", 1)[1].strip()) + + await proc.wait() + stderr_bytes = await proc.stderr.read() if proc.stderr is not None else b"" + if proc.returncode != 0: + message = stderr_bytes.decode(errors="replace")[-1000:] if stderr_bytes else f"exit {proc.returncode}" + raise AnalysisExecutorError(message) + + return self._parse_output( + stdout_lines=stdout_lines, + ticker=ticker, + date=date, + contract_version=request_context.contract_version, + executor_type=request_context.executor_type, + ) + finally: + if self.process_registry is not None: + self.process_registry(task_id, None) + if script_path is not None: + try: + script_path.unlink() + except Exception: + pass + + @staticmethod + async def _terminate_process(proc: asyncio.subprocess.Process) -> None: + if proc.returncode is not None: + return + try: + proc.kill() + except ProcessLookupError: + return + await proc.wait() + + @staticmethod + def _parse_output( + *, + stdout_lines: list[str], + ticker: str, + date: str, + contract_version: str, + executor_type: str, + ) -> AnalysisExecutionOutput: + decision: Optional[str] = None + quant_signal = None + llm_signal = None + confidence = None + seen_signal_detail = False + seen_complete = False + + for line in stdout_lines: + if line.startswith("SIGNAL_DETAIL:"): + seen_signal_detail = True + try: + detail = json.loads(line.split(":", 1)[1].strip()) + except Exception as exc: + raise AnalysisExecutorError("failed to parse SIGNAL_DETAIL payload") from exc + quant_signal = detail.get("quant_signal") + llm_signal = detail.get("llm_signal") + confidence = detail.get("confidence") + elif line.startswith("ANALYSIS_COMPLETE:"): + seen_complete = True + decision = line.split(":", 1)[1].strip() + + missing_markers = [] + if not seen_signal_detail: + missing_markers.append("SIGNAL_DETAIL") + if not seen_complete: + missing_markers.append("ANALYSIS_COMPLETE") + if missing_markers: + raise AnalysisExecutorError( + "analysis subprocess completed without required markers: " + + ", ".join(missing_markers) + ) + + report_path = str(Path("results") / ticker / date / "complete_report.md") + return AnalysisExecutionOutput( + decision=decision or "HOLD", + quant_signal=quant_signal, + llm_signal=llm_signal, + confidence=confidence, + report_path=report_path, + contract_version=contract_version, + executor_type=executor_type, + ) + + +class DirectAnalysisExecutor: + """Placeholder for a future in-process executor implementation.""" + + async def execute( + self, + *, + task_id: str, + ticker: str, + date: str, + request_context: RequestContext, + on_stage: Optional[StageCallback] = None, + ) -> AnalysisExecutionOutput: + del task_id, ticker, date, request_context, on_stage + raise NotImplementedError("DirectAnalysisExecutor is not implemented in phase 1") diff --git a/web_dashboard/backend/services/job_service.py b/web_dashboard/backend/services/job_service.py index c510dfcf..7fb55003 100644 --- a/web_dashboard/backend/services/job_service.py +++ b/web_dashboard/backend/services/job_service.py @@ -5,6 +5,10 @@ from datetime import datetime from typing import Any, Callable +CONTRACT_VERSION = "v1alpha1" +DEFAULT_EXECUTOR_TYPE = "legacy_subprocess" + + class JobService: """Application-layer job state orchestrator with legacy-compatible payloads.""" @@ -24,10 +28,71 @@ class JobService: self.delete_task = delete_task def restore_task_results(self, restored: dict[str, dict]) -> None: - self.task_results.update(restored) + self.task_results.update( + { + task_id: self._normalize_task_state(task_id, state) + for task_id, state in restored.items() + } + ) - def create_portfolio_job(self, *, task_id: str, total: int) -> dict: - state = { + def create_analysis_job( + self, + *, + task_id: str, + ticker: str, + date: str, + request_id: str | None = None, + executor_type: str = DEFAULT_EXECUTOR_TYPE, + contract_version: str = CONTRACT_VERSION, + result_ref: str | None = None, + ) -> dict: + state = self._normalize_task_state(task_id, { + "task_id": task_id, + "ticker": ticker, + "date": date, + "status": "running", + "progress": 0, + "current_stage": "analysts", + "created_at": datetime.now().isoformat(), + "elapsed_seconds": 0, + "elapsed": 0, + "stages": [ + { + "name": stage_name, + "status": "running" if index == 0 else "pending", + "completed_at": None, + } + for index, stage_name in enumerate( + ["analysts", "research", "trading", "risk", "portfolio"] + ) + ], + "logs": [], + "decision": None, + "quant_signal": None, + "llm_signal": None, + "confidence": None, + "result": None, + "error": None, + "request_id": request_id, + "executor_type": executor_type, + "contract_version": contract_version, + "result_ref": result_ref, + }) + self.task_results[task_id] = state + self.processes.setdefault(task_id, None) + return state + + def create_portfolio_job( + self, + *, + task_id: str, + total: int, + request_id: str | None = None, + executor_type: str = DEFAULT_EXECUTOR_TYPE, + contract_version: str = CONTRACT_VERSION, + result_ref: str | None = None, + ) -> dict: + state = self._normalize_task_state(task_id, { "task_id": task_id, "type": "portfolio", "status": "running", @@ -38,11 +103,65 @@ class JobService: "results": [], "error": None, "created_at": datetime.now().isoformat(), - } + "request_id": request_id, + "executor_type": executor_type, + "contract_version": contract_version, + "result_ref": result_ref, + }) self.task_results[task_id] = state self.processes.setdefault(task_id, None) return state + def attach_result_contract( + self, + task_id: str, + *, + result_ref: str, + contract_version: str = CONTRACT_VERSION, + executor_type: str | None = None, + ) -> dict: + state = self.task_results[task_id] + state["result_ref"] = result_ref + state["contract_version"] = contract_version or state.get("contract_version") or CONTRACT_VERSION + if executor_type: + state["executor_type"] = executor_type + return state + + def complete_analysis_job( + self, + task_id: str, + *, + contract: dict, + result_ref: str, + executor_type: str | None = None, + ) -> dict: + state = self.task_results[task_id] + result = dict(contract.get("result") or {}) + signals = result.get("signals") or {} + quant = signals.get("quant") or {} + llm = signals.get("llm") or {} + + state["status"] = contract.get("status", "completed") + state["progress"] = contract.get("progress", 100) + state["current_stage"] = contract.get("current_stage", state.get("current_stage")) + state["elapsed_seconds"] = contract.get("elapsed_seconds", state.get("elapsed_seconds", 0)) + state["elapsed"] = contract.get("elapsed", state["elapsed_seconds"]) + state["decision"] = result.get("decision") + state["quant_signal"] = quant.get("rating") + state["llm_signal"] = llm.get("rating") + state["confidence"] = result.get("confidence") + state["result"] = result + state["error"] = contract.get("error") + state["contract_version"] = contract.get("contract_version", state.get("contract_version")) + self.attach_result_contract( + task_id, + result_ref=result_ref, + contract_version=state["contract_version"], + executor_type=executor_type, + ) + self.persist_task(task_id, state) + return state + def update_portfolio_progress(self, task_id: str, *, ticker: str, completed: int) -> dict: state = self.task_results[task_id] state["current_ticker"] = ticker @@ -92,3 +211,12 @@ class JobService: state["error"] = error self.persist_task(task_id, state) return state + + @staticmethod + def _normalize_task_state(task_id: str, state: dict) -> dict: + normalized = dict(state) + normalized.setdefault("request_id", task_id) + normalized.setdefault("executor_type", DEFAULT_EXECUTOR_TYPE) + normalized.setdefault("contract_version", CONTRACT_VERSION) + normalized.setdefault("result_ref", None) + return normalized diff --git a/web_dashboard/backend/services/request_context.py b/web_dashboard/backend/services/request_context.py index 1ab44cea..c88340a0 100644 --- a/web_dashboard/backend/services/request_context.py +++ b/web_dashboard/backend/services/request_context.py @@ -7,11 +7,17 @@ from uuid import uuid4 from fastapi import Request +CONTRACT_VERSION = "v1alpha1" +DEFAULT_EXECUTOR_TYPE = "legacy_subprocess" + + @dataclass(frozen=True) class RequestContext: """Minimal request-scoped metadata passed into application services.""" request_id: str + contract_version: str = CONTRACT_VERSION + executor_type: str = DEFAULT_EXECUTOR_TYPE api_key: Optional[str] = None client_host: Optional[str] = None is_local: bool = False @@ -23,6 +29,8 @@ def build_request_context( *, api_key: Optional[str] = None, request_id: Optional[str] = None, + contract_version: str = CONTRACT_VERSION, + executor_type: str = DEFAULT_EXECUTOR_TYPE, metadata: Optional[dict[str, str]] = None, ) -> RequestContext: """Create a stable request context without leaking FastAPI internals into services.""" @@ -30,6 +38,8 @@ def build_request_context( is_local = client_host in {"127.0.0.1", "::1", "localhost", "testclient"} return RequestContext( request_id=request_id or uuid4().hex, + contract_version=contract_version, + executor_type=executor_type, api_key=api_key, client_host=client_host, is_local=is_local, diff --git a/web_dashboard/backend/services/result_store.py b/web_dashboard/backend/services/result_store.py index 6efa89f7..6f4dcf71 100644 --- a/web_dashboard/backend/services/result_store.py +++ b/web_dashboard/backend/services/result_store.py @@ -5,11 +5,15 @@ from pathlib import Path from typing import Optional +CONTRACT_VERSION = "v1alpha1" + + class ResultStore: """Storage boundary for persisted task state and portfolio results.""" def __init__(self, task_status_dir: Path, portfolio_gateway): self.task_status_dir = task_status_dir + self.result_contract_dir = self.task_status_dir / "result_contracts" self.portfolio_gateway = portfolio_gateway def restore_task_results(self) -> dict[str, dict]: @@ -29,6 +33,15 @@ class ResultStore: self.task_status_dir.mkdir(parents=True, exist_ok=True) (self.task_status_dir / f"{task_id}.json").write_text(json.dumps(data, ensure_ascii=False)) + def save_result_contract(self, task_id: str, contract: dict) -> str: + self.result_contract_dir.mkdir(parents=True, exist_ok=True) + payload = dict(contract) + payload.setdefault("task_id", task_id) + payload.setdefault("contract_version", CONTRACT_VERSION) + file_path = self.result_contract_dir / f"{task_id}.json" + file_path.write_text(json.dumps(payload, ensure_ascii=False)) + return file_path.relative_to(self.task_status_dir).as_posix() + def delete_task_status(self, task_id: str) -> None: (self.task_status_dir / f"{task_id}.json").unlink(missing_ok=True) diff --git a/web_dashboard/backend/tests/test_api_smoke.py b/web_dashboard/backend/tests/test_api_smoke.py index b3ff7225..137b5765 100644 --- a/web_dashboard/backend/tests/test_api_smoke.py +++ b/web_dashboard/backend/tests/test_api_smoke.py @@ -53,3 +53,73 @@ def test_analysis_task_routes_smoke(monkeypatch): assert any(task["task_id"] == "task-smoke" for task in tasks_response.json()["tasks"]) assert status_response.status_code == 200 assert status_response.json()["task_id"] == "task-smoke" + + +def test_analysis_start_route_uses_analysis_service(monkeypatch): + monkeypatch.delenv("DASHBOARD_API_KEY", raising=False) + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + + main = _load_main_module(monkeypatch) + created: dict[str, object] = {} + + class DummyTask: + def cancel(self): + return None + + def fake_create_task(coro): + created["scheduled_coro"] = coro.cr_code.co_name + coro.close() + task = DummyTask() + created["task"] = task + return task + + monkeypatch.setattr(main.asyncio, "create_task", fake_create_task) + + with TestClient(main.app) as client: + response = client.post( + "/api/analysis/start", + json={"ticker": "AAPL", "date": "2026-04-11"}, + headers={"api-key": "test-key"}, + ) + + payload = response.json() + task_id = payload["task_id"] + + assert response.status_code == 200 + assert payload["ticker"] == "AAPL" + assert payload["date"] == "2026-04-11" + assert payload["status"] == "running" + assert created["scheduled_coro"] == "_run_analysis" + assert main.app.state.analysis_tasks[task_id] is created["task"] + assert main.app.state.task_results[task_id]["current_stage"] == "analysts" + assert main.app.state.task_results[task_id]["status"] == "running" + assert main.app.state.task_results[task_id]["request_id"] + assert main.app.state.task_results[task_id]["executor_type"] == "legacy_subprocess" + assert main.app.state.task_results[task_id]["result_ref"] is None + + +def test_portfolio_analyze_route_uses_analysis_service_smoke(monkeypatch): + monkeypatch.delenv("DASHBOARD_API_KEY", raising=False) + monkeypatch.setenv("TRADINGAGENTS_USE_APPLICATION_SERVICES", "1") + monkeypatch.setenv("ANTHROPIC_API_KEY", "service-key") + + main = _load_main_module(monkeypatch) + captured: dict[str, object] = {} + + async def fake_start_portfolio_analysis(*, task_id, date, request_context, broadcast_progress): + captured["task_id"] = task_id + captured["date"] = date + captured["request_context"] = request_context + captured["broadcast_progress"] = broadcast_progress + return {"task_id": task_id, "status": "running", "total": 3} + + with TestClient(main.app) as client: + monkeypatch.setattr(main.app.state.analysis_service, "start_portfolio_analysis", fake_start_portfolio_analysis) + response = client.post("/api/portfolio/analyze", headers={"api-key": "service-key"}) + + assert response.status_code == 200 + assert response.json()["status"] == "running" + assert str(captured["task_id"]).startswith("port_") + assert isinstance(captured["date"], str) + assert captured["request_context"].api_key == "service-key" + assert callable(captured["broadcast_progress"]) diff --git a/web_dashboard/backend/tests/test_executors.py b/web_dashboard/backend/tests/test_executors.py new file mode 100644 index 00000000..dcbe5b62 --- /dev/null +++ b/web_dashboard/backend/tests/test_executors.py @@ -0,0 +1,112 @@ +import asyncio +from pathlib import Path + +import pytest + +from services.executor import AnalysisExecutorError, LegacySubprocessAnalysisExecutor +from services.request_context import build_request_context + + +class _FakeStdout: + def __init__(self, lines, *, stall: bool = False): + self._lines = list(lines) + self._stall = stall + + async def readline(self): + if self._stall: + await asyncio.sleep(3600) + if self._lines: + return self._lines.pop(0) + return b"" + + +class _FakeStderr: + def __init__(self, payload: bytes = b""): + self._payload = payload + + async def read(self): + return self._payload + + +class _FakeProcess: + def __init__(self, stdout, *, stderr: bytes = b"", returncode=None): + self.stdout = stdout + self.stderr = _FakeStderr(stderr) + self.returncode = returncode + self.kill_called = False + self.wait_called = False + + async def wait(self): + self.wait_called = True + if self.returncode is None: + self.returncode = -9 if self.kill_called else 0 + return self.returncode + + def kill(self): + self.kill_called = True + self.returncode = -9 + + +def test_executor_raises_when_required_markers_missing(monkeypatch): + process = _FakeProcess( + _FakeStdout( + [ + b"STAGE:analysts\n", + b"STAGE:portfolio\n", + b"SIGNAL_DETAIL:{\"quant_signal\":\"BUY\",\"llm_signal\":\"BUY\",\"confidence\":0.8}\n", + ], + ), + returncode=0, + ) + + async def fake_create_subprocess_exec(*args, **kwargs): + return process + + monkeypatch.setattr(asyncio, "create_subprocess_exec", fake_create_subprocess_exec) + + executor = LegacySubprocessAnalysisExecutor( + analysis_python=Path("/usr/bin/python3"), + repo_root=Path("."), + api_key_resolver=lambda: "env-key", + ) + + async def scenario(): + with pytest.raises(AnalysisExecutorError, match="required markers: ANALYSIS_COMPLETE"): + await executor.execute( + task_id="task-1", + ticker="AAPL", + date="2026-04-13", + request_context=build_request_context(api_key="ctx-key"), + ) + + asyncio.run(scenario()) + + +def test_executor_kills_subprocess_on_timeout(monkeypatch): + process = _FakeProcess(_FakeStdout([], stall=True)) + + async def fake_create_subprocess_exec(*args, **kwargs): + return process + + monkeypatch.setattr(asyncio, "create_subprocess_exec", fake_create_subprocess_exec) + + executor = LegacySubprocessAnalysisExecutor( + analysis_python=Path("/usr/bin/python3"), + repo_root=Path("."), + api_key_resolver=lambda: "env-key", + stdout_timeout_secs=0.01, + ) + + async def scenario(): + with pytest.raises(AnalysisExecutorError, match="timed out"): + await executor.execute( + task_id="task-2", + ticker="AAPL", + date="2026-04-13", + request_context=build_request_context(api_key="ctx-key"), + ) + + asyncio.run(scenario()) + + assert process.kill_called is True + assert process.wait_called is True diff --git a/web_dashboard/backend/tests/test_services_migration.py b/web_dashboard/backend/tests/test_services_migration.py index 60088633..7bf419c5 100644 --- a/web_dashboard/backend/tests/test_services_migration.py +++ b/web_dashboard/backend/tests/test_services_migration.py @@ -2,6 +2,7 @@ import json import asyncio from services.analysis_service import AnalysisService +from services.executor import AnalysisExecutionOutput from services.job_service import JobService from services.migration_flags import load_migration_flags from services.request_context import build_request_context @@ -48,6 +49,8 @@ def test_build_request_context_defaults(): assert context.api_key == "secret" assert context.request_id + assert context.contract_version == "v1alpha1" + assert context.executor_type == "legacy_subprocess" assert context.metadata == {"source": "test"} @@ -64,6 +67,23 @@ def test_result_store_round_trip(tmp_path): assert positions == [{"ticker": "AAPL", "account": "模拟账户"}] +def test_result_store_saves_result_contract(tmp_path): + gateway = DummyPortfolioGateway() + store = ResultStore(tmp_path / "task_status", gateway) + + result_ref = store.save_result_contract( + "task-2", + {"status": "completed", "result": {"decision": "BUY"}}, + ) + + saved = json.loads((tmp_path / "task_status" / result_ref).read_text()) + + assert result_ref == "result_contracts/task-2.json" + assert saved["task_id"] == "task-2" + assert saved["contract_version"] == "v1alpha1" + assert saved["result"]["decision"] == "BUY" + + def test_job_service_create_and_fail_job(): task_results = {} analysis_tasks = {} @@ -78,15 +98,48 @@ def test_job_service_create_and_fail_job(): delete_task=lambda task_id: persisted.pop(task_id, None), ) - state = service.create_portfolio_job(task_id="port_1", total=2) + state = service.create_portfolio_job( + task_id="port_1", + total=2, + request_id="req-1", + executor_type="analysis_executor", + ) assert state["total"] == 2 assert processes["port_1"] is None + assert state["request_id"] == "req-1" + assert state["executor_type"] == "analysis_executor" + assert state["contract_version"] == "v1alpha1" + assert state["result_ref"] is None + + attached = service.attach_result_contract( + "port_1", + result_ref="result_contracts/port_1.json", + ) + assert attached["result_ref"] == "result_contracts/port_1.json" failed = service.fail_job("port_1", "boom") assert failed["status"] == "failed" assert persisted["port_1"]["error"] == "boom" +def test_job_service_restores_legacy_tasks_with_contract_metadata(): + service = JobService( + task_results={}, + analysis_tasks={}, + processes={}, + persist_task=lambda task_id, data: None, + delete_task=lambda task_id: None, + ) + + service.restore_task_results({"legacy-task": {"task_id": "legacy-task", "status": "running"}}) + + restored = service.task_results["legacy-task"] + assert restored["request_id"] == "legacy-task" + assert restored["executor_type"] == "legacy_subprocess" + assert restored["contract_version"] == "v1alpha1" + assert restored["result_ref"] is None + + def test_analysis_service_build_recommendation_record(): rec = AnalysisService._build_recommendation_record( stdout='\n'.join([ @@ -103,3 +156,74 @@ def test_analysis_service_build_recommendation_record(): assert rec["quant_signal"] == "BUY" assert rec["llm_signal"] == "HOLD" assert rec["confidence"] == 0.75 + + +class FakeExecutor: + async def execute(self, *, task_id, ticker, date, request_context, on_stage=None): + if on_stage is not None: + await on_stage("analysts") + await on_stage("research") + await on_stage("trading") + await on_stage("risk") + await on_stage("portfolio") + return AnalysisExecutionOutput( + decision="BUY", + quant_signal="OVERWEIGHT", + llm_signal="BUY", + confidence=0.82, + report_path=f"results/{ticker}/{date}/complete_report.md", + ) + + +def test_analysis_service_start_analysis_uses_executor(tmp_path): + gateway = DummyPortfolioGateway() + store = ResultStore(tmp_path / "task_status", gateway) + task_results = {} + analysis_tasks = {} + processes = {} + service = JobService( + task_results=task_results, + analysis_tasks=analysis_tasks, + processes=processes, + persist_task=store.save_task_status, + delete_task=store.delete_task_status, + ) + analysis_service = AnalysisService( + executor=FakeExecutor(), + result_store=store, + job_service=service, + ) + broadcasts = [] + + async def _broadcast(task_id, payload): + broadcasts.append((task_id, payload["status"], payload.get("current_stage"))) + + async def scenario(): + response = await analysis_service.start_analysis( + task_id="task-1", + ticker="AAPL", + date="2026-04-13", + request_context=build_request_context(api_key="secret"), + broadcast_progress=_broadcast, + ) + await analysis_tasks["task-1"] + return response + + response = asyncio.run(scenario()) + + assert response == { + "contract_version": "v1alpha1", + "task_id": "task-1", + "ticker": "AAPL", + "date": "2026-04-13", + "status": "running", + } + assert task_results["task-1"]["status"] == "completed" + assert task_results["task-1"]["decision"] == "BUY" + assert task_results["task-1"]["result_ref"] == "result_contracts/task-1.json" + assert task_results["task-1"]["result"]["signals"]["llm"]["rating"] == "BUY" + saved_contract = json.loads((tmp_path / "task_status" / "result_contracts" / "task-1.json").read_text()) + assert saved_contract["status"] == "completed" + assert saved_contract["result"]["signals"]["merged"]["rating"] == "BUY" + assert broadcasts[0] == ("task-1", "running", "analysts") + assert broadcasts[-1][1] == "completed"