Prevent executor regressions from leaking through the dashboard
Phase 1 left the backend halfway between legacy task payloads and the new executor boundary. This commit finishes the review-fix pass so missing protocol markers fail closed, timed-out subprocesses are killed, and successful analysis runs persist a result contract before task state is marked complete. Constraint: env312 lacks pytest-asyncio so async executor tests must run without extra plugins Rejected: Keep missing marker fallback as HOLD | masks protocol regressions as neutral signals Rejected: Leave service success assembly in AnalysisService | breaks contract-first persistence and result_ref wiring Confidence: high Scope-risk: moderate Reversibility: clean Directive: Keep backend success state driven by persisted result contracts; do not reintroduce raw stdout parsing in services Tested: python -m compileall orchestrator tradingagents web_dashboard/backend Tested: python -m pytest web_dashboard/backend/tests/test_executors.py web_dashboard/backend/tests/test_services_migration.py web_dashboard/backend/tests/test_api_smoke.py web_dashboard/backend/tests/test_main_api.py web_dashboard/backend/tests/test_portfolio_api.py -q Tested: python -m pytest orchestrator/tests/test_application_service.py orchestrator/tests/test_trading_graph_config.py -q Not-tested: real provider-backed MiniMax execution Not-tested: full dashboard websocket/manual UI flow
This commit is contained in:
parent
e802af3a1d
commit
255f478cd1
|
|
@ -6,11 +6,8 @@ import asyncio
|
|||
import hmac
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
|
@ -23,6 +20,7 @@ from fastapi.staticfiles import StaticFiles
|
|||
from pydantic import BaseModel
|
||||
|
||||
from services import AnalysisService, JobService, ResultStore, build_request_context, load_migration_flags
|
||||
from services.executor import LegacySubprocessAnalysisExecutor
|
||||
|
||||
# Path to TradingAgents repo root
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent
|
||||
|
|
@ -54,10 +52,12 @@ async def lifespan(app: FastAPI):
|
|||
delete_task=app.state.result_store.delete_task_status,
|
||||
)
|
||||
app.state.analysis_service = AnalysisService(
|
||||
analysis_python=ANALYSIS_PYTHON,
|
||||
repo_root=REPO_ROOT,
|
||||
analysis_script_template=ANALYSIS_SCRIPT_TEMPLATE,
|
||||
api_key_resolver=_get_analysis_api_key,
|
||||
executor=LegacySubprocessAnalysisExecutor(
|
||||
analysis_python=ANALYSIS_PYTHON,
|
||||
repo_root=REPO_ROOT,
|
||||
api_key_resolver=_get_analysis_api_key,
|
||||
process_registry=app.state.job_service.register_process,
|
||||
),
|
||||
result_store=app.state.result_store,
|
||||
job_service=app.state.job_service,
|
||||
retry_count=MAX_RETRY_COUNT,
|
||||
|
|
@ -229,23 +229,6 @@ def _save_to_cache(mode: str, data: dict):
|
|||
pass
|
||||
|
||||
|
||||
def _save_task_status(task_id: str, data: dict):
|
||||
"""Persist task state to disk"""
|
||||
try:
|
||||
TASK_STATUS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
(TASK_STATUS_DIR / f"{task_id}.json").write_text(json.dumps(data, ensure_ascii=False))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _delete_task_status(task_id: str):
|
||||
"""Remove persisted task state from disk"""
|
||||
try:
|
||||
(TASK_STATUS_DIR / f"{task_id}.json").unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ============== SEPA Screening ==============
|
||||
|
||||
def _run_sepa_screening(mode: str) -> dict:
|
||||
|
|
@ -282,288 +265,34 @@ async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query
|
|||
|
||||
# ============== Analysis Execution ==============
|
||||
|
||||
# Script template for subprocess-based analysis
|
||||
# api_key is passed via environment variable (not CLI) for security
|
||||
ANALYSIS_SCRIPT_TEMPLATE = """
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
ticker = sys.argv[1]
|
||||
date = sys.argv[2]
|
||||
repo_root = sys.argv[3]
|
||||
|
||||
sys.path.insert(0, repo_root)
|
||||
os.environ["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
|
||||
import py_mini_racer
|
||||
sys.modules["mini_racer"] = py_mini_racer
|
||||
from pathlib import Path
|
||||
|
||||
print("STAGE:analysts", flush=True)
|
||||
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
from orchestrator.orchestrator import TradingOrchestrator
|
||||
|
||||
config = OrchestratorConfig(
|
||||
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
|
||||
trading_agents_config={
|
||||
"llm_provider": "anthropic",
|
||||
"deep_think_llm": "MiniMax-M2.7-highspeed",
|
||||
"quick_think_llm": "MiniMax-M2.7-highspeed",
|
||||
"backend_url": "https://api.minimaxi.com/anthropic",
|
||||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
"project_dir": os.path.join(repo_root, "tradingagents"),
|
||||
"results_dir": os.path.join(repo_root, "results"),
|
||||
}
|
||||
)
|
||||
|
||||
print("STAGE:research", flush=True)
|
||||
|
||||
orchestrator = TradingOrchestrator(config)
|
||||
|
||||
print("STAGE:trading", flush=True)
|
||||
|
||||
try:
|
||||
result = orchestrator.get_combined_signal(ticker, date)
|
||||
except ValueError as _e:
|
||||
print("ANALYSIS_ERROR:" + str(_e), file=sys.stderr, flush=True)
|
||||
sys.exit(1)
|
||||
|
||||
print("STAGE:risk", flush=True)
|
||||
|
||||
# Map direction + confidence to 5-level signal
|
||||
# FinalSignal is a dataclass, access via attributes not .get()
|
||||
direction = result.direction
|
||||
confidence = result.confidence
|
||||
llm_sig_obj = result.llm_signal
|
||||
quant_sig_obj = result.quant_signal
|
||||
# LLM metadata has "rating" field; quant metadata does not — derive from direction
|
||||
llm_signal = llm_sig_obj.metadata.get("rating", "HOLD") if llm_sig_obj else "HOLD"
|
||||
if quant_sig_obj is None:
|
||||
quant_signal = "HOLD"
|
||||
elif quant_sig_obj.direction == 1:
|
||||
quant_signal = "BUY" if quant_sig_obj.confidence >= 0.7 else "OVERWEIGHT"
|
||||
elif quant_sig_obj.direction == -1:
|
||||
quant_signal = "SELL" if quant_sig_obj.confidence >= 0.7 else "UNDERWEIGHT"
|
||||
else:
|
||||
quant_signal = "HOLD"
|
||||
|
||||
if direction == 1:
|
||||
signal = "BUY" if confidence >= 0.7 else "OVERWEIGHT"
|
||||
elif direction == -1:
|
||||
signal = "SELL" if confidence >= 0.7 else "UNDERWEIGHT"
|
||||
else:
|
||||
signal = "HOLD"
|
||||
|
||||
results_dir = Path(repo_root) / "results" / ticker / date
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
report_content = (
|
||||
"# TradingAgents 分析报告\\n\\n"
|
||||
"**股票**: " + ticker + "\\n"
|
||||
"**日期**: " + date + "\\n\\n"
|
||||
"## 最终决策\\n\\n"
|
||||
"**" + signal + "**\\n\\n"
|
||||
"## 信号详情\\n\\n"
|
||||
"- LLM 信号: " + llm_signal + "\\n"
|
||||
"- Quant 信号: " + quant_signal + "\\n"
|
||||
"- 置信度: " + f"{confidence:.1%}" + "\\n\\n"
|
||||
"## 分析摘要\\n\\n"
|
||||
"N/A\\n"
|
||||
)
|
||||
|
||||
report_path = results_dir / "complete_report.md"
|
||||
report_path.write_text(report_content)
|
||||
|
||||
print("STAGE:portfolio", flush=True)
|
||||
signal_detail = json.dumps({"llm_signal": llm_signal, "quant_signal": quant_signal, "confidence": confidence})
|
||||
print("SIGNAL_DETAIL:" + signal_detail, flush=True)
|
||||
print("ANALYSIS_COMPLETE:" + signal, flush=True)
|
||||
"""
|
||||
|
||||
|
||||
@app.post("/api/analysis/start")
|
||||
async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Header(None)):
|
||||
"""Start a new analysis task"""
|
||||
async def start_analysis(
|
||||
payload: AnalysisRequest,
|
||||
http_request: Request,
|
||||
api_key: Optional[str] = Header(None),
|
||||
):
|
||||
"""Start a new analysis task."""
|
||||
import uuid
|
||||
task_id = f"{request.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
|
||||
date = request.date or datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
# Check dashboard API key (opt-in auth)
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
|
||||
# Validate ANTHROPIC_API_KEY for the analysis subprocess
|
||||
anthropic_key = _get_analysis_api_key()
|
||||
if not anthropic_key:
|
||||
raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set")
|
||||
task_id = f"{payload.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
|
||||
date = payload.date or datetime.now().strftime("%Y-%m-%d")
|
||||
request_context = build_request_context(http_request, api_key=api_key)
|
||||
|
||||
# Initialize task state
|
||||
app.state.task_results[task_id] = {
|
||||
"task_id": task_id,
|
||||
"ticker": request.ticker,
|
||||
"date": date,
|
||||
"status": "running",
|
||||
"progress": 0,
|
||||
"current_stage": "analysts",
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"elapsed": 0,
|
||||
"stages": [
|
||||
{"status": "running", "completed_at": None},
|
||||
{"status": "pending", "completed_at": None},
|
||||
{"status": "pending", "completed_at": None},
|
||||
{"status": "pending", "completed_at": None},
|
||||
{"status": "pending", "completed_at": None},
|
||||
],
|
||||
"logs": [],
|
||||
"decision": None,
|
||||
"quant_signal": None,
|
||||
"llm_signal": None,
|
||||
"confidence": None,
|
||||
"error": None,
|
||||
}
|
||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
||||
|
||||
# Write analysis script to temp file with restrictive permissions (avoids subprocess -c quoting issues)
|
||||
fd, script_path_str = tempfile.mkstemp(suffix=".py", prefix=f"analysis_{task_id}_")
|
||||
script_path = Path(script_path_str)
|
||||
os.chmod(script_path, 0o600)
|
||||
with os.fdopen(fd, "w") as f:
|
||||
f.write(ANALYSIS_SCRIPT_TEMPLATE)
|
||||
|
||||
# Store process reference for cancellation
|
||||
app.state.processes = getattr(app.state, 'processes', {})
|
||||
app.state.processes[task_id] = None
|
||||
|
||||
# Cancellation event for the monitor coroutine
|
||||
cancel_event = asyncio.Event()
|
||||
|
||||
# Stage name to index mapping
|
||||
STAGE_NAMES = ["analysts", "research", "trading", "risk", "portfolio"]
|
||||
|
||||
def _update_task_stage(stage_name: str):
|
||||
"""Update task state for a completed stage and mark next as running."""
|
||||
try:
|
||||
idx = STAGE_NAMES.index(stage_name)
|
||||
except ValueError:
|
||||
return
|
||||
# Mark all previous stages as completed
|
||||
for i in range(idx):
|
||||
if app.state.task_results[task_id]["stages"][i]["status"] != "completed":
|
||||
app.state.task_results[task_id]["stages"][i]["status"] = "completed"
|
||||
app.state.task_results[task_id]["stages"][i]["completed_at"] = datetime.now().strftime("%H:%M:%S")
|
||||
# Mark current as completed
|
||||
if app.state.task_results[task_id]["stages"][idx]["status"] != "completed":
|
||||
app.state.task_results[task_id]["stages"][idx]["status"] = "completed"
|
||||
app.state.task_results[task_id]["stages"][idx]["completed_at"] = datetime.now().strftime("%H:%M:%S")
|
||||
# Mark next as running
|
||||
if idx + 1 < 5:
|
||||
if app.state.task_results[task_id]["stages"][idx + 1]["status"] == "pending":
|
||||
app.state.task_results[task_id]["stages"][idx + 1]["status"] = "running"
|
||||
# Update progress
|
||||
app.state.task_results[task_id]["progress"] = int((idx + 1) / 5 * 100)
|
||||
app.state.task_results[task_id]["current_stage"] = stage_name
|
||||
|
||||
async def run_analysis():
|
||||
"""Run analysis subprocess and broadcast progress"""
|
||||
try:
|
||||
# Use clean environment - don't inherit parent env
|
||||
clean_env = {k: v for k, v in os.environ.items()
|
||||
if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))}
|
||||
clean_env["ANTHROPIC_API_KEY"] = anthropic_key
|
||||
clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
str(ANALYSIS_PYTHON),
|
||||
str(script_path),
|
||||
request.ticker,
|
||||
date,
|
||||
str(REPO_ROOT),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=clean_env,
|
||||
)
|
||||
app.state.processes[task_id] = proc
|
||||
|
||||
# Read stdout line-by-line for real-time stage updates
|
||||
stdout_lines = []
|
||||
while True:
|
||||
try:
|
||||
line_bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=300.0)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
if not line_bytes:
|
||||
break
|
||||
line = line_bytes.decode(errors="replace").rstrip()
|
||||
stdout_lines.append(line)
|
||||
if line.startswith("STAGE:"):
|
||||
stage = line.split(":", 1)[1].strip()
|
||||
_update_task_stage(stage)
|
||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
||||
if cancel_event.is_set():
|
||||
break
|
||||
|
||||
await proc.wait()
|
||||
stderr_bytes = await proc.stderr.read()
|
||||
|
||||
# Clean up script file
|
||||
try:
|
||||
script_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if proc.returncode == 0:
|
||||
output = "\n".join(stdout_lines)
|
||||
decision = "HOLD"
|
||||
for line in stdout_lines:
|
||||
if line.startswith("SIGNAL_DETAIL:"):
|
||||
try:
|
||||
detail = json.loads(line.split(":", 1)[1].strip())
|
||||
app.state.task_results[task_id]["quant_signal"] = detail.get("quant_signal")
|
||||
app.state.task_results[task_id]["llm_signal"] = detail.get("llm_signal")
|
||||
app.state.task_results[task_id]["confidence"] = detail.get("confidence")
|
||||
except Exception:
|
||||
pass
|
||||
if line.startswith("ANALYSIS_COMPLETE:"):
|
||||
decision = line.split(":", 1)[1].strip()
|
||||
|
||||
app.state.task_results[task_id]["status"] = "completed"
|
||||
app.state.task_results[task_id]["progress"] = 100
|
||||
app.state.task_results[task_id]["decision"] = decision
|
||||
app.state.task_results[task_id]["current_stage"] = "portfolio"
|
||||
for i in range(5):
|
||||
app.state.task_results[task_id]["stages"][i]["status"] = "completed"
|
||||
if not app.state.task_results[task_id]["stages"][i].get("completed_at"):
|
||||
app.state.task_results[task_id]["stages"][i]["completed_at"] = datetime.now().strftime("%H:%M:%S")
|
||||
else:
|
||||
error_msg = stderr_bytes.decode(errors="replace")[-1000:] if stderr_bytes else "Unknown error"
|
||||
app.state.task_results[task_id]["status"] = "failed"
|
||||
app.state.task_results[task_id]["error"] = error_msg
|
||||
|
||||
_save_task_status(task_id, app.state.task_results[task_id])
|
||||
|
||||
except Exception as e:
|
||||
cancel_event.set()
|
||||
app.state.task_results[task_id]["status"] = "failed"
|
||||
app.state.task_results[task_id]["error"] = str(e)
|
||||
try:
|
||||
script_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_save_task_status(task_id, app.state.task_results[task_id])
|
||||
|
||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
||||
|
||||
task = asyncio.create_task(run_analysis())
|
||||
app.state.analysis_tasks[task_id] = task
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"ticker": request.ticker,
|
||||
"date": date,
|
||||
"status": "running",
|
||||
}
|
||||
try:
|
||||
return await app.state.analysis_service.start_analysis(
|
||||
task_id=task_id,
|
||||
ticker=payload.ticker,
|
||||
date=date,
|
||||
request_context=request_context,
|
||||
broadcast_progress=broadcast_progress,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@app.get("/api/analysis/status/{task_id}")
|
||||
|
|
@ -600,13 +329,12 @@ async def list_tasks(api_key: Optional[str] = Header(None)):
|
|||
|
||||
@app.delete("/api/analysis/cancel/{task_id}")
|
||||
async def cancel_task(task_id: str, api_key: Optional[str] = Header(None)):
|
||||
"""Cancel a running task"""
|
||||
"""Cancel a running task."""
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
if task_id not in app.state.task_results:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Kill the subprocess if it's still running
|
||||
proc = app.state.processes.get(task_id)
|
||||
if proc and proc.returncode is None:
|
||||
try:
|
||||
|
|
@ -614,26 +342,18 @@ async def cancel_task(task_id: str, api_key: Optional[str] = Header(None)):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
# Cancel the asyncio task
|
||||
task = app.state.analysis_tasks.get(task_id)
|
||||
if task:
|
||||
task.cancel()
|
||||
app.state.task_results[task_id]["status"] = "failed"
|
||||
app.state.task_results[task_id]["error"] = "用户取消"
|
||||
_save_task_status(task_id, app.state.task_results[task_id])
|
||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
||||
|
||||
# Clean up temp script (may use tempfile.mkstemp with random suffix)
|
||||
for p in Path("/tmp").glob(f"analysis_{task_id}_*.py"):
|
||||
try:
|
||||
p.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
state = app.state.task_results[task_id]
|
||||
state["status"] = "cancelled"
|
||||
state["error"] = "用户取消"
|
||||
app.state.result_store.save_task_status(task_id, state)
|
||||
await broadcast_progress(task_id, state)
|
||||
app.state.result_store.delete_task_status(task_id)
|
||||
|
||||
# Remove persisted task state
|
||||
_delete_task_status(task_id)
|
||||
|
||||
return {"task_id": task_id, "status": "cancelled"}
|
||||
return {"contract_version": "v1alpha1", "task_id": task_id, "status": "cancelled"}
|
||||
|
||||
|
||||
# ============== WebSocket ==============
|
||||
|
|
@ -1091,169 +811,31 @@ async def get_recommendation_endpoint(date: str, ticker: str, api_key: Optional[
|
|||
# --- Batch Analysis ---
|
||||
|
||||
@app.post("/api/portfolio/analyze")
|
||||
async def start_portfolio_analysis(api_key: Optional[str] = Header(None)):
|
||||
"""
|
||||
Trigger batch analysis for all watchlist tickers.
|
||||
Runs serially, streaming progress via WebSocket (task_id prefixed with 'port_').
|
||||
"""
|
||||
async def start_portfolio_analysis(
|
||||
http_request: Request,
|
||||
api_key: Optional[str] = Header(None),
|
||||
):
|
||||
"""Trigger batch analysis for all watchlist tickers."""
|
||||
if not _check_api_key(api_key):
|
||||
_auth_error()
|
||||
|
||||
import uuid
|
||||
|
||||
date = datetime.now().strftime("%Y-%m-%d")
|
||||
task_id = f"port_{date}_{uuid.uuid4().hex[:6]}"
|
||||
request_context = build_request_context(http_request, api_key=api_key)
|
||||
|
||||
if app.state.migration_flags.use_application_services:
|
||||
request_context = build_request_context(api_key=api_key)
|
||||
try:
|
||||
return await app.state.analysis_service.start_portfolio_analysis(
|
||||
task_id=task_id,
|
||||
date=date,
|
||||
request_context=request_context,
|
||||
broadcast_progress=broadcast_progress,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
watchlist = get_watchlist()
|
||||
if not watchlist:
|
||||
raise HTTPException(status_code=400, detail="自选股为空,请先添加股票")
|
||||
|
||||
total = len(watchlist)
|
||||
app.state.task_results[task_id] = {
|
||||
"task_id": task_id,
|
||||
"type": "portfolio",
|
||||
"status": "running",
|
||||
"total": total,
|
||||
"completed": 0,
|
||||
"failed": 0,
|
||||
"current_ticker": None,
|
||||
"results": [],
|
||||
"error": None,
|
||||
}
|
||||
|
||||
api_key = _get_analysis_api_key()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set")
|
||||
|
||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
||||
|
||||
async def run_portfolio_analysis():
|
||||
max_retries = MAX_RETRY_COUNT
|
||||
|
||||
async def run_single_analysis(ticker: str, stock: dict) -> tuple[bool, str, dict | None]:
|
||||
"""Run analysis for one ticker. Returns (success, decision, rec_or_error)."""
|
||||
last_error = None
|
||||
for attempt in range(max_retries + 1):
|
||||
script_path = None
|
||||
try:
|
||||
fd, script_path_str = tempfile.mkstemp(suffix=".py", prefix=f"analysis_{task_id}_{stock['_idx']}_")
|
||||
script_path = Path(script_path_str)
|
||||
os.chmod(script_path, 0o600)
|
||||
with os.fdopen(fd, "w") as f:
|
||||
f.write(ANALYSIS_SCRIPT_TEMPLATE)
|
||||
|
||||
clean_env = {k: v for k, v in os.environ.items()
|
||||
if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))}
|
||||
clean_env["ANTHROPIC_API_KEY"] = api_key
|
||||
clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
str(ANALYSIS_PYTHON), str(script_path), ticker, date, str(REPO_ROOT),
|
||||
stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
|
||||
env=clean_env,
|
||||
)
|
||||
app.state.processes[task_id] = proc
|
||||
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
try:
|
||||
script_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if proc.returncode == 0:
|
||||
output = stdout.decode()
|
||||
decision = "HOLD"
|
||||
quant_signal = None
|
||||
llm_signal = None
|
||||
confidence = None
|
||||
for line in output.splitlines():
|
||||
if line.startswith("SIGNAL_DETAIL:"):
|
||||
try:
|
||||
detail = json.loads(line.split(":", 1)[1].strip())
|
||||
quant_signal = detail.get("quant_signal")
|
||||
llm_signal = detail.get("llm_signal")
|
||||
confidence = detail.get("confidence")
|
||||
except Exception:
|
||||
pass
|
||||
if line.startswith("ANALYSIS_COMPLETE:"):
|
||||
decision = line.split(":", 1)[1].strip()
|
||||
rec = {
|
||||
"ticker": ticker,
|
||||
"name": stock.get("name", ticker),
|
||||
"analysis_date": date,
|
||||
"decision": decision,
|
||||
"quant_signal": quant_signal,
|
||||
"llm_signal": llm_signal,
|
||||
"confidence": confidence,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}
|
||||
save_recommendation(date, ticker, rec)
|
||||
return True, decision, rec
|
||||
else:
|
||||
last_error = stderr.decode()[-500:] if stderr else f"exit {proc.returncode}"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
finally:
|
||||
if script_path:
|
||||
try:
|
||||
script_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
if attempt < max_retries:
|
||||
await asyncio.sleep(RETRY_BASE_DELAY_SECS ** attempt) # exponential backoff: 1s, 2s
|
||||
|
||||
return False, "HOLD", None
|
||||
|
||||
try:
|
||||
for i, stock in enumerate(watchlist):
|
||||
stock["_idx"] = i # used in temp file name
|
||||
ticker = stock["ticker"]
|
||||
app.state.task_results[task_id]["current_ticker"] = ticker
|
||||
app.state.task_results[task_id]["status"] = "running"
|
||||
app.state.task_results[task_id]["completed"] = i
|
||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
||||
|
||||
success, decision, rec = await run_single_analysis(ticker, stock)
|
||||
if success:
|
||||
app.state.task_results[task_id]["completed"] = i + 1
|
||||
app.state.task_results[task_id]["results"].append(rec)
|
||||
else:
|
||||
app.state.task_results[task_id]["failed"] += 1
|
||||
|
||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
||||
|
||||
app.state.task_results[task_id]["status"] = "completed"
|
||||
app.state.task_results[task_id]["current_ticker"] = None
|
||||
_save_task_status(task_id, app.state.task_results[task_id])
|
||||
|
||||
except Exception as e:
|
||||
app.state.task_results[task_id]["status"] = "failed"
|
||||
app.state.task_results[task_id]["error"] = str(e)
|
||||
_save_task_status(task_id, app.state.task_results[task_id])
|
||||
|
||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
||||
|
||||
task = asyncio.create_task(run_portfolio_analysis())
|
||||
app.state.analysis_tasks[task_id] = task
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"total": total,
|
||||
"status": "running",
|
||||
}
|
||||
try:
|
||||
return await app.state.analysis_service.start_portfolio_analysis(
|
||||
task_id=task_id,
|
||||
date=date,
|
||||
request_context=request_context,
|
||||
broadcast_progress=broadcast_progress,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,15 +2,15 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from .executor import AnalysisExecutionOutput, AnalysisExecutor, AnalysisExecutorError
|
||||
from .request_context import RequestContext
|
||||
|
||||
BroadcastFn = Callable[[str, dict], Awaitable[None]]
|
||||
ANALYSIS_STAGE_NAMES = ["analysts", "research", "trading", "risk", "portfolio"]
|
||||
|
||||
|
||||
class AnalysisService:
|
||||
|
|
@ -19,24 +19,56 @@ class AnalysisService:
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
analysis_python: Path,
|
||||
repo_root: Path,
|
||||
analysis_script_template: str,
|
||||
api_key_resolver: Callable[[], Optional[str]],
|
||||
executor: AnalysisExecutor,
|
||||
result_store,
|
||||
job_service,
|
||||
retry_count: int = 2,
|
||||
retry_base_delay_secs: int = 1,
|
||||
):
|
||||
self.analysis_python = analysis_python
|
||||
self.repo_root = repo_root
|
||||
self.analysis_script_template = analysis_script_template
|
||||
self.api_key_resolver = api_key_resolver
|
||||
self.executor = executor
|
||||
self.result_store = result_store
|
||||
self.job_service = job_service
|
||||
self.retry_count = retry_count
|
||||
self.retry_base_delay_secs = retry_base_delay_secs
|
||||
|
||||
async def start_analysis(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
ticker: str,
|
||||
date: str,
|
||||
request_context: RequestContext,
|
||||
broadcast_progress: BroadcastFn,
|
||||
) -> dict:
|
||||
state = self.job_service.create_analysis_job(
|
||||
task_id=task_id,
|
||||
ticker=ticker,
|
||||
date=date,
|
||||
request_id=request_context.request_id,
|
||||
executor_type=request_context.executor_type,
|
||||
contract_version=request_context.contract_version,
|
||||
)
|
||||
self.job_service.register_process(task_id, None)
|
||||
await broadcast_progress(task_id, state)
|
||||
|
||||
task = asyncio.create_task(
|
||||
self._run_analysis(
|
||||
task_id=task_id,
|
||||
ticker=ticker,
|
||||
date=date,
|
||||
request_context=request_context,
|
||||
broadcast_progress=broadcast_progress,
|
||||
)
|
||||
)
|
||||
self.job_service.register_background_task(task_id, task)
|
||||
return {
|
||||
"contract_version": "v1alpha1",
|
||||
"task_id": task_id,
|
||||
"ticker": ticker,
|
||||
"date": date,
|
||||
"status": "running",
|
||||
}
|
||||
|
||||
async def start_portfolio_analysis(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -45,16 +77,17 @@ class AnalysisService:
|
|||
request_context: RequestContext,
|
||||
broadcast_progress: BroadcastFn,
|
||||
) -> dict:
|
||||
del request_context # Reserved for future auditing/auth propagation.
|
||||
watchlist = self.result_store.get_watchlist()
|
||||
if not watchlist:
|
||||
raise ValueError("自选股为空,请先添加股票")
|
||||
|
||||
analysis_api_key = self.api_key_resolver()
|
||||
if not analysis_api_key:
|
||||
raise RuntimeError("ANTHROPIC_API_KEY environment variable not set")
|
||||
|
||||
state = self.job_service.create_portfolio_job(task_id=task_id, total=len(watchlist))
|
||||
state = self.job_service.create_portfolio_job(
|
||||
task_id=task_id,
|
||||
total=len(watchlist),
|
||||
request_id=request_context.request_id,
|
||||
executor_type=request_context.executor_type,
|
||||
contract_version=request_context.contract_version,
|
||||
)
|
||||
await broadcast_progress(task_id, state)
|
||||
|
||||
task = asyncio.create_task(
|
||||
|
|
@ -62,24 +95,111 @@ class AnalysisService:
|
|||
task_id=task_id,
|
||||
date=date,
|
||||
watchlist=watchlist,
|
||||
analysis_api_key=analysis_api_key,
|
||||
request_context=request_context,
|
||||
broadcast_progress=broadcast_progress,
|
||||
)
|
||||
)
|
||||
self.job_service.register_background_task(task_id, task)
|
||||
return {
|
||||
"contract_version": "v1alpha1",
|
||||
"task_id": task_id,
|
||||
"total": len(watchlist),
|
||||
"status": "running",
|
||||
}
|
||||
|
||||
async def _run_analysis(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
ticker: str,
|
||||
date: str,
|
||||
request_context: RequestContext,
|
||||
broadcast_progress: BroadcastFn,
|
||||
) -> None:
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
output = await self.executor.execute(
|
||||
task_id=task_id,
|
||||
ticker=ticker,
|
||||
date=date,
|
||||
request_context=request_context,
|
||||
on_stage=lambda stage: self._handle_analysis_stage(
|
||||
task_id=task_id,
|
||||
stage_name=stage,
|
||||
started_at=start_time,
|
||||
broadcast_progress=broadcast_progress,
|
||||
),
|
||||
)
|
||||
state = self.job_service.task_results[task_id]
|
||||
elapsed_seconds = int(time.monotonic() - start_time)
|
||||
contract = output.to_result_contract(
|
||||
task_id=task_id,
|
||||
ticker=ticker,
|
||||
date=date,
|
||||
created_at=state["created_at"],
|
||||
elapsed_seconds=elapsed_seconds,
|
||||
current_stage=ANALYSIS_STAGE_NAMES[-1],
|
||||
)
|
||||
result_ref = self.result_store.save_result_contract(task_id, contract)
|
||||
self.job_service.complete_analysis_job(
|
||||
task_id,
|
||||
contract=contract,
|
||||
result_ref=result_ref,
|
||||
executor_type=request_context.executor_type,
|
||||
)
|
||||
except AnalysisExecutorError as exc:
|
||||
self._fail_analysis_state(
|
||||
task_id=task_id,
|
||||
message=str(exc),
|
||||
started_at=start_time,
|
||||
)
|
||||
except Exception as exc:
|
||||
self._fail_analysis_state(
|
||||
task_id=task_id,
|
||||
message=str(exc),
|
||||
started_at=start_time,
|
||||
)
|
||||
|
||||
await broadcast_progress(task_id, self.job_service.task_results[task_id])
|
||||
|
||||
async def _handle_analysis_stage(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
stage_name: str,
|
||||
started_at: float,
|
||||
broadcast_progress: BroadcastFn,
|
||||
) -> None:
|
||||
state = self.job_service.task_results[task_id]
|
||||
try:
|
||||
idx = ANALYSIS_STAGE_NAMES.index(stage_name)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
for i, entry in enumerate(state["stages"]):
|
||||
if i < idx:
|
||||
if entry["status"] != "completed":
|
||||
entry["status"] = "completed"
|
||||
entry["completed_at"] = datetime.now().strftime("%H:%M:%S")
|
||||
elif i == idx:
|
||||
entry["status"] = "completed"
|
||||
entry["completed_at"] = entry["completed_at"] or datetime.now().strftime("%H:%M:%S")
|
||||
elif i == idx + 1 and entry["status"] == "pending":
|
||||
entry["status"] = "running"
|
||||
|
||||
state["progress"] = int((idx + 1) / len(ANALYSIS_STAGE_NAMES) * 100)
|
||||
state["current_stage"] = stage_name
|
||||
state["elapsed_seconds"] = int(time.monotonic() - started_at)
|
||||
state["elapsed"] = state["elapsed_seconds"]
|
||||
await broadcast_progress(task_id, state)
|
||||
|
||||
async def _run_portfolio_analysis(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
date: str,
|
||||
watchlist: list[dict],
|
||||
analysis_api_key: str,
|
||||
request_context: RequestContext,
|
||||
broadcast_progress: BroadcastFn,
|
||||
) -> None:
|
||||
try:
|
||||
|
|
@ -96,7 +216,7 @@ class AnalysisService:
|
|||
ticker=ticker,
|
||||
stock=stock,
|
||||
date=date,
|
||||
analysis_api_key=analysis_api_key,
|
||||
request_context=request_context,
|
||||
)
|
||||
if success and rec is not None:
|
||||
self.job_service.append_portfolio_result(task_id, rec)
|
||||
|
|
@ -118,61 +238,27 @@ class AnalysisService:
|
|||
ticker: str,
|
||||
stock: dict,
|
||||
date: str,
|
||||
analysis_api_key: str,
|
||||
request_context: RequestContext,
|
||||
) -> tuple[bool, Optional[dict]]:
|
||||
last_error: Optional[str] = None
|
||||
for attempt in range(self.retry_count + 1):
|
||||
script_path: Optional[Path] = None
|
||||
try:
|
||||
fd, script_path_str = tempfile.mkstemp(
|
||||
suffix=".py",
|
||||
prefix=f"analysis_{task_id}_{stock['_idx']}_",
|
||||
output = await self.executor.execute(
|
||||
task_id=f"{task_id}_{stock['_idx']}",
|
||||
ticker=ticker,
|
||||
date=date,
|
||||
request_context=request_context,
|
||||
)
|
||||
script_path = Path(script_path_str)
|
||||
os.chmod(script_path, 0o600)
|
||||
with os.fdopen(fd, "w") as handle:
|
||||
handle.write(self.analysis_script_template)
|
||||
|
||||
clean_env = {
|
||||
key: value
|
||||
for key, value in os.environ.items()
|
||||
if not key.startswith(("PYTHON", "CONDA", "VIRTUAL"))
|
||||
}
|
||||
clean_env["ANTHROPIC_API_KEY"] = analysis_api_key
|
||||
clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
str(self.analysis_python),
|
||||
str(script_path),
|
||||
ticker,
|
||||
date,
|
||||
str(self.repo_root),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=clean_env,
|
||||
rec = self._build_recommendation_record(
|
||||
output=output,
|
||||
ticker=ticker,
|
||||
stock=stock,
|
||||
date=date,
|
||||
)
|
||||
self.job_service.register_process(task_id, proc)
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
if proc.returncode == 0:
|
||||
rec = self._build_recommendation_record(
|
||||
stdout=stdout.decode(),
|
||||
ticker=ticker,
|
||||
stock=stock,
|
||||
date=date,
|
||||
)
|
||||
self.result_store.save_recommendation(date, ticker, rec)
|
||||
return True, rec
|
||||
|
||||
last_error = stderr.decode()[-500:] if stderr else f"exit {proc.returncode}"
|
||||
self.result_store.save_recommendation(date, ticker, rec)
|
||||
return True, rec
|
||||
except Exception as exc:
|
||||
last_error = str(exc)
|
||||
finally:
|
||||
if script_path is not None:
|
||||
try:
|
||||
script_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if attempt < self.retry_count:
|
||||
await asyncio.sleep(self.retry_base_delay_secs ** attempt)
|
||||
|
|
@ -181,23 +267,45 @@ class AnalysisService:
|
|||
self.job_service.task_results[task_id]["last_error"] = last_error
|
||||
return False, None
|
||||
|
||||
def _fail_analysis_state(self, *, task_id: str, message: str, started_at: float) -> None:
|
||||
state = self.job_service.task_results[task_id]
|
||||
state["status"] = "failed"
|
||||
state["elapsed_seconds"] = int(time.monotonic() - started_at)
|
||||
state["elapsed"] = state["elapsed_seconds"]
|
||||
state["result"] = None
|
||||
state["error"] = message
|
||||
self.result_store.save_task_status(task_id, state)
|
||||
|
||||
@staticmethod
|
||||
def _build_recommendation_record(*, stdout: str, ticker: str, stock: dict, date: str) -> dict:
|
||||
decision = "HOLD"
|
||||
quant_signal = None
|
||||
llm_signal = None
|
||||
confidence = None
|
||||
for line in stdout.splitlines():
|
||||
if line.startswith("SIGNAL_DETAIL:"):
|
||||
try:
|
||||
detail = json.loads(line.split(":", 1)[1].strip())
|
||||
except Exception:
|
||||
continue
|
||||
quant_signal = detail.get("quant_signal")
|
||||
llm_signal = detail.get("llm_signal")
|
||||
confidence = detail.get("confidence")
|
||||
if line.startswith("ANALYSIS_COMPLETE:"):
|
||||
decision = line.split(":", 1)[1].strip()
|
||||
def _build_recommendation_record(
|
||||
*,
|
||||
ticker: str,
|
||||
stock: dict,
|
||||
date: str,
|
||||
output: AnalysisExecutionOutput | None = None,
|
||||
stdout: str | None = None,
|
||||
) -> dict:
|
||||
if output is not None:
|
||||
decision = output.decision
|
||||
quant_signal = output.quant_signal
|
||||
llm_signal = output.llm_signal
|
||||
confidence = output.confidence
|
||||
else:
|
||||
decision = "HOLD"
|
||||
quant_signal = None
|
||||
llm_signal = None
|
||||
confidence = None
|
||||
for line in (stdout or "").splitlines():
|
||||
if line.startswith("SIGNAL_DETAIL:"):
|
||||
try:
|
||||
detail = json.loads(line.split(":", 1)[1].strip())
|
||||
except Exception:
|
||||
continue
|
||||
quant_signal = detail.get("quant_signal")
|
||||
llm_signal = detail.get("llm_signal")
|
||||
confidence = detail.get("confidence")
|
||||
if line.startswith("ANALYSIS_COMPLETE:"):
|
||||
decision = line.split(":", 1)[1].strip()
|
||||
|
||||
return {
|
||||
"ticker": ticker,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,381 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable, Optional, Protocol
|
||||
|
||||
from .request_context import (
|
||||
CONTRACT_VERSION,
|
||||
DEFAULT_EXECUTOR_TYPE,
|
||||
RequestContext,
|
||||
)
|
||||
|
||||
StageCallback = Callable[[str], Awaitable[None]]
|
||||
ProcessRegistry = Callable[[str, asyncio.subprocess.Process | None], None]
|
||||
|
||||
LEGACY_ANALYSIS_SCRIPT_TEMPLATE = """
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ticker = sys.argv[1]
|
||||
date = sys.argv[2]
|
||||
repo_root = sys.argv[3]
|
||||
|
||||
sys.path.insert(0, repo_root)
|
||||
|
||||
import py_mini_racer
|
||||
sys.modules["mini_racer"] = py_mini_racer
|
||||
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
from orchestrator.orchestrator import TradingOrchestrator
|
||||
from tradingagents.default_config import get_default_config
|
||||
|
||||
trading_config = get_default_config()
|
||||
trading_config["project_dir"] = os.path.join(repo_root, "tradingagents")
|
||||
trading_config["results_dir"] = os.path.join(repo_root, "results")
|
||||
trading_config["max_debate_rounds"] = 1
|
||||
trading_config["max_risk_discuss_rounds"] = 1
|
||||
|
||||
print("STAGE:analysts", flush=True)
|
||||
print("STAGE:research", flush=True)
|
||||
|
||||
config = OrchestratorConfig(
|
||||
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
|
||||
trading_agents_config=trading_config,
|
||||
)
|
||||
|
||||
orchestrator = TradingOrchestrator(config)
|
||||
|
||||
print("STAGE:trading", flush=True)
|
||||
|
||||
try:
|
||||
result = orchestrator.get_combined_signal(ticker, date)
|
||||
except ValueError as exc:
|
||||
print("ANALYSIS_ERROR:" + str(exc), file=sys.stderr, flush=True)
|
||||
sys.exit(1)
|
||||
|
||||
print("STAGE:risk", flush=True)
|
||||
|
||||
direction = result.direction
|
||||
confidence = result.confidence
|
||||
llm_sig_obj = result.llm_signal
|
||||
quant_sig_obj = result.quant_signal
|
||||
llm_signal = llm_sig_obj.metadata.get("rating", "HOLD") if llm_sig_obj else "HOLD"
|
||||
if quant_sig_obj is None:
|
||||
quant_signal = "HOLD"
|
||||
elif quant_sig_obj.direction == 1:
|
||||
quant_signal = "BUY" if quant_sig_obj.confidence >= 0.7 else "OVERWEIGHT"
|
||||
elif quant_sig_obj.direction == -1:
|
||||
quant_signal = "SELL" if quant_sig_obj.confidence >= 0.7 else "UNDERWEIGHT"
|
||||
else:
|
||||
quant_signal = "HOLD"
|
||||
|
||||
if direction == 1:
|
||||
signal = "BUY" if confidence >= 0.7 else "OVERWEIGHT"
|
||||
elif direction == -1:
|
||||
signal = "SELL" if confidence >= 0.7 else "UNDERWEIGHT"
|
||||
else:
|
||||
signal = "HOLD"
|
||||
|
||||
results_dir = Path(repo_root) / "results" / ticker / date
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
report_content = (
|
||||
"# TradingAgents 分析报告\\n\\n"
|
||||
"**股票**: " + ticker + "\\n"
|
||||
"**日期**: " + date + "\\n\\n"
|
||||
"## 最终决策\\n\\n"
|
||||
"**" + signal + "**\\n\\n"
|
||||
"## 信号详情\\n\\n"
|
||||
"- LLM 信号: " + llm_signal + "\\n"
|
||||
"- Quant 信号: " + quant_signal + "\\n"
|
||||
"- 置信度: " + f"{confidence:.1%}" + "\\n\\n"
|
||||
"## 分析摘要\\n\\n"
|
||||
"N/A\\n"
|
||||
)
|
||||
|
||||
report_path = results_dir / "complete_report.md"
|
||||
report_path.write_text(report_content)
|
||||
|
||||
print("STAGE:portfolio", flush=True)
|
||||
signal_detail = json.dumps({"llm_signal": llm_signal, "quant_signal": quant_signal, "confidence": confidence})
|
||||
print("SIGNAL_DETAIL:" + signal_detail, flush=True)
|
||||
print("ANALYSIS_COMPLETE:" + signal, flush=True)
|
||||
"""
|
||||
|
||||
|
||||
def _rating_to_direction(rating: Optional[str]) -> int:
|
||||
if rating in {"BUY", "OVERWEIGHT"}:
|
||||
return 1
|
||||
if rating in {"SELL", "UNDERWEIGHT"}:
|
||||
return -1
|
||||
return 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AnalysisExecutionOutput:
|
||||
decision: str
|
||||
quant_signal: Optional[str]
|
||||
llm_signal: Optional[str]
|
||||
confidence: Optional[float]
|
||||
report_path: Optional[str] = None
|
||||
contract_version: str = CONTRACT_VERSION
|
||||
executor_type: str = DEFAULT_EXECUTOR_TYPE
|
||||
|
||||
def to_result_contract(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
ticker: str,
|
||||
date: str,
|
||||
created_at: str,
|
||||
elapsed_seconds: int,
|
||||
current_stage: str = "portfolio",
|
||||
) -> dict:
|
||||
return {
|
||||
"contract_version": self.contract_version,
|
||||
"task_id": task_id,
|
||||
"ticker": ticker,
|
||||
"date": date,
|
||||
"status": "completed",
|
||||
"progress": 100,
|
||||
"current_stage": current_stage,
|
||||
"created_at": created_at,
|
||||
"elapsed_seconds": elapsed_seconds,
|
||||
"elapsed": elapsed_seconds,
|
||||
"result": {
|
||||
"decision": self.decision,
|
||||
"confidence": self.confidence,
|
||||
"signals": {
|
||||
"merged": {
|
||||
"direction": _rating_to_direction(self.decision),
|
||||
"rating": self.decision,
|
||||
},
|
||||
"quant": {
|
||||
"direction": _rating_to_direction(self.quant_signal),
|
||||
"rating": self.quant_signal,
|
||||
"available": self.quant_signal is not None,
|
||||
},
|
||||
"llm": {
|
||||
"direction": _rating_to_direction(self.llm_signal),
|
||||
"rating": self.llm_signal,
|
||||
"available": self.llm_signal is not None,
|
||||
},
|
||||
},
|
||||
"degraded": self.quant_signal is None or self.llm_signal is None,
|
||||
"report": {
|
||||
"path": self.report_path,
|
||||
"available": bool(self.report_path),
|
||||
},
|
||||
},
|
||||
"error": None,
|
||||
}
|
||||
|
||||
|
||||
class AnalysisExecutorError(RuntimeError):
|
||||
def __init__(self, message: str, *, code: str = "analysis_failed", retryable: bool = False):
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
self.retryable = retryable
|
||||
|
||||
|
||||
class AnalysisExecutor(Protocol):
|
||||
async def execute(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
ticker: str,
|
||||
date: str,
|
||||
request_context: RequestContext,
|
||||
on_stage: Optional[StageCallback] = None,
|
||||
) -> AnalysisExecutionOutput: ...
|
||||
|
||||
|
||||
class LegacySubprocessAnalysisExecutor:
|
||||
"""Run the legacy dashboard analysis script behind a stable executor contract."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
analysis_python: Path,
|
||||
repo_root: Path,
|
||||
api_key_resolver: Callable[[], Optional[str]],
|
||||
process_registry: Optional[ProcessRegistry] = None,
|
||||
script_template: str = LEGACY_ANALYSIS_SCRIPT_TEMPLATE,
|
||||
stdout_timeout_secs: float = 300.0,
|
||||
):
|
||||
self.analysis_python = analysis_python
|
||||
self.repo_root = repo_root
|
||||
self.api_key_resolver = api_key_resolver
|
||||
self.process_registry = process_registry
|
||||
self.script_template = script_template
|
||||
self.stdout_timeout_secs = stdout_timeout_secs
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
ticker: str,
|
||||
date: str,
|
||||
request_context: RequestContext,
|
||||
on_stage: Optional[StageCallback] = None,
|
||||
) -> AnalysisExecutionOutput:
|
||||
analysis_api_key = request_context.api_key or self.api_key_resolver()
|
||||
if not analysis_api_key:
|
||||
raise RuntimeError("ANTHROPIC_API_KEY environment variable not set")
|
||||
|
||||
script_path: Optional[Path] = None
|
||||
proc: asyncio.subprocess.Process | None = None
|
||||
try:
|
||||
fd, script_path_str = tempfile.mkstemp(suffix=".py", prefix=f"analysis_{task_id}_")
|
||||
script_path = Path(script_path_str)
|
||||
os.chmod(script_path, 0o600)
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as handle:
|
||||
handle.write(self.script_template)
|
||||
|
||||
clean_env = {
|
||||
key: value
|
||||
for key, value in os.environ.items()
|
||||
if not key.startswith(("PYTHON", "CONDA", "VIRTUAL"))
|
||||
}
|
||||
clean_env["ANTHROPIC_API_KEY"] = analysis_api_key
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
str(self.analysis_python),
|
||||
str(script_path),
|
||||
ticker,
|
||||
date,
|
||||
str(self.repo_root),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=clean_env,
|
||||
)
|
||||
if self.process_registry is not None:
|
||||
self.process_registry(task_id, proc)
|
||||
|
||||
stdout_lines: list[str] = []
|
||||
assert proc.stdout is not None
|
||||
while True:
|
||||
try:
|
||||
line_bytes = await asyncio.wait_for(
|
||||
proc.stdout.readline(),
|
||||
timeout=self.stdout_timeout_secs,
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
await self._terminate_process(proc)
|
||||
raise AnalysisExecutorError(
|
||||
f"analysis subprocess timed out after {self.stdout_timeout_secs:g}s",
|
||||
retryable=True,
|
||||
) from exc
|
||||
if not line_bytes:
|
||||
break
|
||||
line = line_bytes.decode(errors="replace").rstrip()
|
||||
stdout_lines.append(line)
|
||||
if on_stage is not None and line.startswith("STAGE:"):
|
||||
await on_stage(line.split(":", 1)[1].strip())
|
||||
|
||||
await proc.wait()
|
||||
stderr_bytes = await proc.stderr.read() if proc.stderr is not None else b""
|
||||
if proc.returncode != 0:
|
||||
message = stderr_bytes.decode(errors="replace")[-1000:] if stderr_bytes else f"exit {proc.returncode}"
|
||||
raise AnalysisExecutorError(message)
|
||||
|
||||
return self._parse_output(
|
||||
stdout_lines=stdout_lines,
|
||||
ticker=ticker,
|
||||
date=date,
|
||||
contract_version=request_context.contract_version,
|
||||
executor_type=request_context.executor_type,
|
||||
)
|
||||
finally:
|
||||
if self.process_registry is not None:
|
||||
self.process_registry(task_id, None)
|
||||
if script_path is not None:
|
||||
try:
|
||||
script_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
async def _terminate_process(proc: asyncio.subprocess.Process) -> None:
|
||||
if proc.returncode is not None:
|
||||
return
|
||||
try:
|
||||
proc.kill()
|
||||
except ProcessLookupError:
|
||||
return
|
||||
await proc.wait()
|
||||
|
||||
@staticmethod
|
||||
def _parse_output(
|
||||
*,
|
||||
stdout_lines: list[str],
|
||||
ticker: str,
|
||||
date: str,
|
||||
contract_version: str,
|
||||
executor_type: str,
|
||||
) -> AnalysisExecutionOutput:
|
||||
decision: Optional[str] = None
|
||||
quant_signal = None
|
||||
llm_signal = None
|
||||
confidence = None
|
||||
seen_signal_detail = False
|
||||
seen_complete = False
|
||||
|
||||
for line in stdout_lines:
|
||||
if line.startswith("SIGNAL_DETAIL:"):
|
||||
seen_signal_detail = True
|
||||
try:
|
||||
detail = json.loads(line.split(":", 1)[1].strip())
|
||||
except Exception as exc:
|
||||
raise AnalysisExecutorError("failed to parse SIGNAL_DETAIL payload") from exc
|
||||
quant_signal = detail.get("quant_signal")
|
||||
llm_signal = detail.get("llm_signal")
|
||||
confidence = detail.get("confidence")
|
||||
elif line.startswith("ANALYSIS_COMPLETE:"):
|
||||
seen_complete = True
|
||||
decision = line.split(":", 1)[1].strip()
|
||||
|
||||
missing_markers = []
|
||||
if not seen_signal_detail:
|
||||
missing_markers.append("SIGNAL_DETAIL")
|
||||
if not seen_complete:
|
||||
missing_markers.append("ANALYSIS_COMPLETE")
|
||||
if missing_markers:
|
||||
raise AnalysisExecutorError(
|
||||
"analysis subprocess completed without required markers: "
|
||||
+ ", ".join(missing_markers)
|
||||
)
|
||||
|
||||
report_path = str(Path("results") / ticker / date / "complete_report.md")
|
||||
return AnalysisExecutionOutput(
|
||||
decision=decision or "HOLD",
|
||||
quant_signal=quant_signal,
|
||||
llm_signal=llm_signal,
|
||||
confidence=confidence,
|
||||
report_path=report_path,
|
||||
contract_version=contract_version,
|
||||
executor_type=executor_type,
|
||||
)
|
||||
|
||||
|
||||
class DirectAnalysisExecutor:
|
||||
"""Placeholder for a future in-process executor implementation."""
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
ticker: str,
|
||||
date: str,
|
||||
request_context: RequestContext,
|
||||
on_stage: Optional[StageCallback] = None,
|
||||
) -> AnalysisExecutionOutput:
|
||||
del task_id, ticker, date, request_context, on_stage
|
||||
raise NotImplementedError("DirectAnalysisExecutor is not implemented in phase 1")
|
||||
|
|
@ -5,6 +5,10 @@ from datetime import datetime
|
|||
from typing import Any, Callable
|
||||
|
||||
|
||||
CONTRACT_VERSION = "v1alpha1"
|
||||
DEFAULT_EXECUTOR_TYPE = "legacy_subprocess"
|
||||
|
||||
|
||||
class JobService:
|
||||
"""Application-layer job state orchestrator with legacy-compatible payloads."""
|
||||
|
||||
|
|
@ -24,10 +28,71 @@ class JobService:
|
|||
self.delete_task = delete_task
|
||||
|
||||
def restore_task_results(self, restored: dict[str, dict]) -> None:
|
||||
self.task_results.update(restored)
|
||||
self.task_results.update(
|
||||
{
|
||||
task_id: self._normalize_task_state(task_id, state)
|
||||
for task_id, state in restored.items()
|
||||
}
|
||||
)
|
||||
|
||||
def create_portfolio_job(self, *, task_id: str, total: int) -> dict:
|
||||
state = {
|
||||
def create_analysis_job(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
ticker: str,
|
||||
date: str,
|
||||
request_id: str | None = None,
|
||||
executor_type: str = DEFAULT_EXECUTOR_TYPE,
|
||||
contract_version: str = CONTRACT_VERSION,
|
||||
result_ref: str | None = None,
|
||||
) -> dict:
|
||||
state = self._normalize_task_state(task_id, {
|
||||
"task_id": task_id,
|
||||
"ticker": ticker,
|
||||
"date": date,
|
||||
"status": "running",
|
||||
"progress": 0,
|
||||
"current_stage": "analysts",
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"elapsed_seconds": 0,
|
||||
"elapsed": 0,
|
||||
"stages": [
|
||||
{
|
||||
"name": stage_name,
|
||||
"status": "running" if index == 0 else "pending",
|
||||
"completed_at": None,
|
||||
}
|
||||
for index, stage_name in enumerate(
|
||||
["analysts", "research", "trading", "risk", "portfolio"]
|
||||
)
|
||||
],
|
||||
"logs": [],
|
||||
"decision": None,
|
||||
"quant_signal": None,
|
||||
"llm_signal": None,
|
||||
"confidence": None,
|
||||
"result": None,
|
||||
"error": None,
|
||||
"request_id": request_id,
|
||||
"executor_type": executor_type,
|
||||
"contract_version": contract_version,
|
||||
"result_ref": result_ref,
|
||||
})
|
||||
self.task_results[task_id] = state
|
||||
self.processes.setdefault(task_id, None)
|
||||
return state
|
||||
|
||||
def create_portfolio_job(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
total: int,
|
||||
request_id: str | None = None,
|
||||
executor_type: str = DEFAULT_EXECUTOR_TYPE,
|
||||
contract_version: str = CONTRACT_VERSION,
|
||||
result_ref: str | None = None,
|
||||
) -> dict:
|
||||
state = self._normalize_task_state(task_id, {
|
||||
"task_id": task_id,
|
||||
"type": "portfolio",
|
||||
"status": "running",
|
||||
|
|
@ -38,11 +103,65 @@ class JobService:
|
|||
"results": [],
|
||||
"error": None,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}
|
||||
"request_id": request_id,
|
||||
"executor_type": executor_type,
|
||||
"contract_version": contract_version,
|
||||
"result_ref": result_ref,
|
||||
})
|
||||
self.task_results[task_id] = state
|
||||
self.processes.setdefault(task_id, None)
|
||||
return state
|
||||
|
||||
def attach_result_contract(
|
||||
self,
|
||||
task_id: str,
|
||||
*,
|
||||
result_ref: str,
|
||||
contract_version: str = CONTRACT_VERSION,
|
||||
executor_type: str | None = None,
|
||||
) -> dict:
|
||||
state = self.task_results[task_id]
|
||||
state["result_ref"] = result_ref
|
||||
state["contract_version"] = contract_version or state.get("contract_version") or CONTRACT_VERSION
|
||||
if executor_type:
|
||||
state["executor_type"] = executor_type
|
||||
return state
|
||||
|
||||
def complete_analysis_job(
|
||||
self,
|
||||
task_id: str,
|
||||
*,
|
||||
contract: dict,
|
||||
result_ref: str,
|
||||
executor_type: str | None = None,
|
||||
) -> dict:
|
||||
state = self.task_results[task_id]
|
||||
result = dict(contract.get("result") or {})
|
||||
signals = result.get("signals") or {}
|
||||
quant = signals.get("quant") or {}
|
||||
llm = signals.get("llm") or {}
|
||||
|
||||
state["status"] = contract.get("status", "completed")
|
||||
state["progress"] = contract.get("progress", 100)
|
||||
state["current_stage"] = contract.get("current_stage", state.get("current_stage"))
|
||||
state["elapsed_seconds"] = contract.get("elapsed_seconds", state.get("elapsed_seconds", 0))
|
||||
state["elapsed"] = contract.get("elapsed", state["elapsed_seconds"])
|
||||
state["decision"] = result.get("decision")
|
||||
state["quant_signal"] = quant.get("rating")
|
||||
state["llm_signal"] = llm.get("rating")
|
||||
state["confidence"] = result.get("confidence")
|
||||
state["result"] = result
|
||||
state["error"] = contract.get("error")
|
||||
state["contract_version"] = contract.get("contract_version", state.get("contract_version"))
|
||||
self.attach_result_contract(
|
||||
task_id,
|
||||
result_ref=result_ref,
|
||||
contract_version=state["contract_version"],
|
||||
executor_type=executor_type,
|
||||
)
|
||||
self.persist_task(task_id, state)
|
||||
return state
|
||||
|
||||
def update_portfolio_progress(self, task_id: str, *, ticker: str, completed: int) -> dict:
|
||||
state = self.task_results[task_id]
|
||||
state["current_ticker"] = ticker
|
||||
|
|
@ -92,3 +211,12 @@ class JobService:
|
|||
state["error"] = error
|
||||
self.persist_task(task_id, state)
|
||||
return state
|
||||
|
||||
@staticmethod
|
||||
def _normalize_task_state(task_id: str, state: dict) -> dict:
|
||||
normalized = dict(state)
|
||||
normalized.setdefault("request_id", task_id)
|
||||
normalized.setdefault("executor_type", DEFAULT_EXECUTOR_TYPE)
|
||||
normalized.setdefault("contract_version", CONTRACT_VERSION)
|
||||
normalized.setdefault("result_ref", None)
|
||||
return normalized
|
||||
|
|
|
|||
|
|
@ -7,11 +7,17 @@ from uuid import uuid4
|
|||
from fastapi import Request
|
||||
|
||||
|
||||
CONTRACT_VERSION = "v1alpha1"
|
||||
DEFAULT_EXECUTOR_TYPE = "legacy_subprocess"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RequestContext:
|
||||
"""Minimal request-scoped metadata passed into application services."""
|
||||
|
||||
request_id: str
|
||||
contract_version: str = CONTRACT_VERSION
|
||||
executor_type: str = DEFAULT_EXECUTOR_TYPE
|
||||
api_key: Optional[str] = None
|
||||
client_host: Optional[str] = None
|
||||
is_local: bool = False
|
||||
|
|
@ -23,6 +29,8 @@ def build_request_context(
|
|||
*,
|
||||
api_key: Optional[str] = None,
|
||||
request_id: Optional[str] = None,
|
||||
contract_version: str = CONTRACT_VERSION,
|
||||
executor_type: str = DEFAULT_EXECUTOR_TYPE,
|
||||
metadata: Optional[dict[str, str]] = None,
|
||||
) -> RequestContext:
|
||||
"""Create a stable request context without leaking FastAPI internals into services."""
|
||||
|
|
@ -30,6 +38,8 @@ def build_request_context(
|
|||
is_local = client_host in {"127.0.0.1", "::1", "localhost", "testclient"}
|
||||
return RequestContext(
|
||||
request_id=request_id or uuid4().hex,
|
||||
contract_version=contract_version,
|
||||
executor_type=executor_type,
|
||||
api_key=api_key,
|
||||
client_host=client_host,
|
||||
is_local=is_local,
|
||||
|
|
|
|||
|
|
@ -5,11 +5,15 @@ from pathlib import Path
|
|||
from typing import Optional
|
||||
|
||||
|
||||
CONTRACT_VERSION = "v1alpha1"
|
||||
|
||||
|
||||
class ResultStore:
|
||||
"""Storage boundary for persisted task state and portfolio results."""
|
||||
|
||||
def __init__(self, task_status_dir: Path, portfolio_gateway):
|
||||
self.task_status_dir = task_status_dir
|
||||
self.result_contract_dir = self.task_status_dir / "result_contracts"
|
||||
self.portfolio_gateway = portfolio_gateway
|
||||
|
||||
def restore_task_results(self) -> dict[str, dict]:
|
||||
|
|
@ -29,6 +33,15 @@ class ResultStore:
|
|||
self.task_status_dir.mkdir(parents=True, exist_ok=True)
|
||||
(self.task_status_dir / f"{task_id}.json").write_text(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
def save_result_contract(self, task_id: str, contract: dict) -> str:
|
||||
self.result_contract_dir.mkdir(parents=True, exist_ok=True)
|
||||
payload = dict(contract)
|
||||
payload.setdefault("task_id", task_id)
|
||||
payload.setdefault("contract_version", CONTRACT_VERSION)
|
||||
file_path = self.result_contract_dir / f"{task_id}.json"
|
||||
file_path.write_text(json.dumps(payload, ensure_ascii=False))
|
||||
return file_path.relative_to(self.task_status_dir).as_posix()
|
||||
|
||||
def delete_task_status(self, task_id: str) -> None:
|
||||
(self.task_status_dir / f"{task_id}.json").unlink(missing_ok=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -53,3 +53,73 @@ def test_analysis_task_routes_smoke(monkeypatch):
|
|||
assert any(task["task_id"] == "task-smoke" for task in tasks_response.json()["tasks"])
|
||||
assert status_response.status_code == 200
|
||||
assert status_response.json()["task_id"] == "task-smoke"
|
||||
|
||||
|
||||
def test_analysis_start_route_uses_analysis_service(monkeypatch):
|
||||
monkeypatch.delenv("DASHBOARD_API_KEY", raising=False)
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||
|
||||
main = _load_main_module(monkeypatch)
|
||||
created: dict[str, object] = {}
|
||||
|
||||
class DummyTask:
|
||||
def cancel(self):
|
||||
return None
|
||||
|
||||
def fake_create_task(coro):
|
||||
created["scheduled_coro"] = coro.cr_code.co_name
|
||||
coro.close()
|
||||
task = DummyTask()
|
||||
created["task"] = task
|
||||
return task
|
||||
|
||||
monkeypatch.setattr(main.asyncio, "create_task", fake_create_task)
|
||||
|
||||
with TestClient(main.app) as client:
|
||||
response = client.post(
|
||||
"/api/analysis/start",
|
||||
json={"ticker": "AAPL", "date": "2026-04-11"},
|
||||
headers={"api-key": "test-key"},
|
||||
)
|
||||
|
||||
payload = response.json()
|
||||
task_id = payload["task_id"]
|
||||
|
||||
assert response.status_code == 200
|
||||
assert payload["ticker"] == "AAPL"
|
||||
assert payload["date"] == "2026-04-11"
|
||||
assert payload["status"] == "running"
|
||||
assert created["scheduled_coro"] == "_run_analysis"
|
||||
assert main.app.state.analysis_tasks[task_id] is created["task"]
|
||||
assert main.app.state.task_results[task_id]["current_stage"] == "analysts"
|
||||
assert main.app.state.task_results[task_id]["status"] == "running"
|
||||
assert main.app.state.task_results[task_id]["request_id"]
|
||||
assert main.app.state.task_results[task_id]["executor_type"] == "legacy_subprocess"
|
||||
assert main.app.state.task_results[task_id]["result_ref"] is None
|
||||
|
||||
|
||||
def test_portfolio_analyze_route_uses_analysis_service_smoke(monkeypatch):
|
||||
monkeypatch.delenv("DASHBOARD_API_KEY", raising=False)
|
||||
monkeypatch.setenv("TRADINGAGENTS_USE_APPLICATION_SERVICES", "1")
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "service-key")
|
||||
|
||||
main = _load_main_module(monkeypatch)
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def fake_start_portfolio_analysis(*, task_id, date, request_context, broadcast_progress):
|
||||
captured["task_id"] = task_id
|
||||
captured["date"] = date
|
||||
captured["request_context"] = request_context
|
||||
captured["broadcast_progress"] = broadcast_progress
|
||||
return {"task_id": task_id, "status": "running", "total": 3}
|
||||
|
||||
with TestClient(main.app) as client:
|
||||
monkeypatch.setattr(main.app.state.analysis_service, "start_portfolio_analysis", fake_start_portfolio_analysis)
|
||||
response = client.post("/api/portfolio/analyze", headers={"api-key": "service-key"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "running"
|
||||
assert str(captured["task_id"]).startswith("port_")
|
||||
assert isinstance(captured["date"], str)
|
||||
assert captured["request_context"].api_key == "service-key"
|
||||
assert callable(captured["broadcast_progress"])
|
||||
|
|
|
|||
|
|
@ -0,0 +1,112 @@
|
|||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from services.executor import AnalysisExecutorError, LegacySubprocessAnalysisExecutor
|
||||
from services.request_context import build_request_context
|
||||
|
||||
|
||||
class _FakeStdout:
|
||||
def __init__(self, lines, *, stall: bool = False):
|
||||
self._lines = list(lines)
|
||||
self._stall = stall
|
||||
|
||||
async def readline(self):
|
||||
if self._stall:
|
||||
await asyncio.sleep(3600)
|
||||
if self._lines:
|
||||
return self._lines.pop(0)
|
||||
return b""
|
||||
|
||||
|
||||
class _FakeStderr:
|
||||
def __init__(self, payload: bytes = b""):
|
||||
self._payload = payload
|
||||
|
||||
async def read(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class _FakeProcess:
|
||||
def __init__(self, stdout, *, stderr: bytes = b"", returncode=None):
|
||||
self.stdout = stdout
|
||||
self.stderr = _FakeStderr(stderr)
|
||||
self.returncode = returncode
|
||||
self.kill_called = False
|
||||
self.wait_called = False
|
||||
|
||||
async def wait(self):
|
||||
self.wait_called = True
|
||||
if self.returncode is None:
|
||||
self.returncode = -9 if self.kill_called else 0
|
||||
return self.returncode
|
||||
|
||||
def kill(self):
|
||||
self.kill_called = True
|
||||
self.returncode = -9
|
||||
|
||||
|
||||
def test_executor_raises_when_required_markers_missing(monkeypatch):
|
||||
process = _FakeProcess(
|
||||
_FakeStdout(
|
||||
[
|
||||
b"STAGE:analysts\n",
|
||||
b"STAGE:portfolio\n",
|
||||
b"SIGNAL_DETAIL:{\"quant_signal\":\"BUY\",\"llm_signal\":\"BUY\",\"confidence\":0.8}\n",
|
||||
],
|
||||
),
|
||||
returncode=0,
|
||||
)
|
||||
|
||||
async def fake_create_subprocess_exec(*args, **kwargs):
|
||||
return process
|
||||
|
||||
monkeypatch.setattr(asyncio, "create_subprocess_exec", fake_create_subprocess_exec)
|
||||
|
||||
executor = LegacySubprocessAnalysisExecutor(
|
||||
analysis_python=Path("/usr/bin/python3"),
|
||||
repo_root=Path("."),
|
||||
api_key_resolver=lambda: "env-key",
|
||||
)
|
||||
|
||||
async def scenario():
|
||||
with pytest.raises(AnalysisExecutorError, match="required markers: ANALYSIS_COMPLETE"):
|
||||
await executor.execute(
|
||||
task_id="task-1",
|
||||
ticker="AAPL",
|
||||
date="2026-04-13",
|
||||
request_context=build_request_context(api_key="ctx-key"),
|
||||
)
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_executor_kills_subprocess_on_timeout(monkeypatch):
|
||||
process = _FakeProcess(_FakeStdout([], stall=True))
|
||||
|
||||
async def fake_create_subprocess_exec(*args, **kwargs):
|
||||
return process
|
||||
|
||||
monkeypatch.setattr(asyncio, "create_subprocess_exec", fake_create_subprocess_exec)
|
||||
|
||||
executor = LegacySubprocessAnalysisExecutor(
|
||||
analysis_python=Path("/usr/bin/python3"),
|
||||
repo_root=Path("."),
|
||||
api_key_resolver=lambda: "env-key",
|
||||
stdout_timeout_secs=0.01,
|
||||
)
|
||||
|
||||
async def scenario():
|
||||
with pytest.raises(AnalysisExecutorError, match="timed out"):
|
||||
await executor.execute(
|
||||
task_id="task-2",
|
||||
ticker="AAPL",
|
||||
date="2026-04-13",
|
||||
request_context=build_request_context(api_key="ctx-key"),
|
||||
)
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
assert process.kill_called is True
|
||||
assert process.wait_called is True
|
||||
|
|
@ -2,6 +2,7 @@ import json
|
|||
import asyncio
|
||||
|
||||
from services.analysis_service import AnalysisService
|
||||
from services.executor import AnalysisExecutionOutput
|
||||
from services.job_service import JobService
|
||||
from services.migration_flags import load_migration_flags
|
||||
from services.request_context import build_request_context
|
||||
|
|
@ -48,6 +49,8 @@ def test_build_request_context_defaults():
|
|||
|
||||
assert context.api_key == "secret"
|
||||
assert context.request_id
|
||||
assert context.contract_version == "v1alpha1"
|
||||
assert context.executor_type == "legacy_subprocess"
|
||||
assert context.metadata == {"source": "test"}
|
||||
|
||||
|
||||
|
|
@ -64,6 +67,23 @@ def test_result_store_round_trip(tmp_path):
|
|||
assert positions == [{"ticker": "AAPL", "account": "模拟账户"}]
|
||||
|
||||
|
||||
def test_result_store_saves_result_contract(tmp_path):
|
||||
gateway = DummyPortfolioGateway()
|
||||
store = ResultStore(tmp_path / "task_status", gateway)
|
||||
|
||||
result_ref = store.save_result_contract(
|
||||
"task-2",
|
||||
{"status": "completed", "result": {"decision": "BUY"}},
|
||||
)
|
||||
|
||||
saved = json.loads((tmp_path / "task_status" / result_ref).read_text())
|
||||
|
||||
assert result_ref == "result_contracts/task-2.json"
|
||||
assert saved["task_id"] == "task-2"
|
||||
assert saved["contract_version"] == "v1alpha1"
|
||||
assert saved["result"]["decision"] == "BUY"
|
||||
|
||||
|
||||
def test_job_service_create_and_fail_job():
|
||||
task_results = {}
|
||||
analysis_tasks = {}
|
||||
|
|
@ -78,15 +98,48 @@ def test_job_service_create_and_fail_job():
|
|||
delete_task=lambda task_id: persisted.pop(task_id, None),
|
||||
)
|
||||
|
||||
state = service.create_portfolio_job(task_id="port_1", total=2)
|
||||
state = service.create_portfolio_job(
|
||||
task_id="port_1",
|
||||
total=2,
|
||||
request_id="req-1",
|
||||
executor_type="analysis_executor",
|
||||
)
|
||||
assert state["total"] == 2
|
||||
assert processes["port_1"] is None
|
||||
assert state["request_id"] == "req-1"
|
||||
assert state["executor_type"] == "analysis_executor"
|
||||
assert state["contract_version"] == "v1alpha1"
|
||||
assert state["result_ref"] is None
|
||||
|
||||
attached = service.attach_result_contract(
|
||||
"port_1",
|
||||
result_ref="result_contracts/port_1.json",
|
||||
)
|
||||
assert attached["result_ref"] == "result_contracts/port_1.json"
|
||||
|
||||
failed = service.fail_job("port_1", "boom")
|
||||
assert failed["status"] == "failed"
|
||||
assert persisted["port_1"]["error"] == "boom"
|
||||
|
||||
|
||||
def test_job_service_restores_legacy_tasks_with_contract_metadata():
|
||||
service = JobService(
|
||||
task_results={},
|
||||
analysis_tasks={},
|
||||
processes={},
|
||||
persist_task=lambda task_id, data: None,
|
||||
delete_task=lambda task_id: None,
|
||||
)
|
||||
|
||||
service.restore_task_results({"legacy-task": {"task_id": "legacy-task", "status": "running"}})
|
||||
|
||||
restored = service.task_results["legacy-task"]
|
||||
assert restored["request_id"] == "legacy-task"
|
||||
assert restored["executor_type"] == "legacy_subprocess"
|
||||
assert restored["contract_version"] == "v1alpha1"
|
||||
assert restored["result_ref"] is None
|
||||
|
||||
|
||||
def test_analysis_service_build_recommendation_record():
|
||||
rec = AnalysisService._build_recommendation_record(
|
||||
stdout='\n'.join([
|
||||
|
|
@ -103,3 +156,74 @@ def test_analysis_service_build_recommendation_record():
|
|||
assert rec["quant_signal"] == "BUY"
|
||||
assert rec["llm_signal"] == "HOLD"
|
||||
assert rec["confidence"] == 0.75
|
||||
|
||||
|
||||
class FakeExecutor:
|
||||
async def execute(self, *, task_id, ticker, date, request_context, on_stage=None):
|
||||
if on_stage is not None:
|
||||
await on_stage("analysts")
|
||||
await on_stage("research")
|
||||
await on_stage("trading")
|
||||
await on_stage("risk")
|
||||
await on_stage("portfolio")
|
||||
return AnalysisExecutionOutput(
|
||||
decision="BUY",
|
||||
quant_signal="OVERWEIGHT",
|
||||
llm_signal="BUY",
|
||||
confidence=0.82,
|
||||
report_path=f"results/{ticker}/{date}/complete_report.md",
|
||||
)
|
||||
|
||||
|
||||
def test_analysis_service_start_analysis_uses_executor(tmp_path):
|
||||
gateway = DummyPortfolioGateway()
|
||||
store = ResultStore(tmp_path / "task_status", gateway)
|
||||
task_results = {}
|
||||
analysis_tasks = {}
|
||||
processes = {}
|
||||
service = JobService(
|
||||
task_results=task_results,
|
||||
analysis_tasks=analysis_tasks,
|
||||
processes=processes,
|
||||
persist_task=store.save_task_status,
|
||||
delete_task=store.delete_task_status,
|
||||
)
|
||||
analysis_service = AnalysisService(
|
||||
executor=FakeExecutor(),
|
||||
result_store=store,
|
||||
job_service=service,
|
||||
)
|
||||
broadcasts = []
|
||||
|
||||
async def _broadcast(task_id, payload):
|
||||
broadcasts.append((task_id, payload["status"], payload.get("current_stage")))
|
||||
|
||||
async def scenario():
|
||||
response = await analysis_service.start_analysis(
|
||||
task_id="task-1",
|
||||
ticker="AAPL",
|
||||
date="2026-04-13",
|
||||
request_context=build_request_context(api_key="secret"),
|
||||
broadcast_progress=_broadcast,
|
||||
)
|
||||
await analysis_tasks["task-1"]
|
||||
return response
|
||||
|
||||
response = asyncio.run(scenario())
|
||||
|
||||
assert response == {
|
||||
"contract_version": "v1alpha1",
|
||||
"task_id": "task-1",
|
||||
"ticker": "AAPL",
|
||||
"date": "2026-04-13",
|
||||
"status": "running",
|
||||
}
|
||||
assert task_results["task-1"]["status"] == "completed"
|
||||
assert task_results["task-1"]["decision"] == "BUY"
|
||||
assert task_results["task-1"]["result_ref"] == "result_contracts/task-1.json"
|
||||
assert task_results["task-1"]["result"]["signals"]["llm"]["rating"] == "BUY"
|
||||
saved_contract = json.loads((tmp_path / "task_status" / "result_contracts" / "task-1.json").read_text())
|
||||
assert saved_contract["status"] == "completed"
|
||||
assert saved_contract["result"]["signals"]["merged"]["rating"] == "BUY"
|
||||
assert broadcasts[0] == ("task-1", "running", "analysts")
|
||||
assert broadcasts[-1][1] == "completed"
|
||||
|
|
|
|||
Loading…
Reference in New Issue