Prevent executor regressions from leaking through the dashboard
Phase 1 left the backend halfway between legacy task payloads and the new executor boundary. This commit finishes the review-fix pass so missing protocol markers fail closed, timed-out subprocesses are killed, and successful analysis runs persist a result contract before task state is marked complete. Constraint: env312 lacks pytest-asyncio so async executor tests must run without extra plugins Rejected: Keep missing marker fallback as HOLD | masks protocol regressions as neutral signals Rejected: Leave service success assembly in AnalysisService | breaks contract-first persistence and result_ref wiring Confidence: high Scope-risk: moderate Reversibility: clean Directive: Keep backend success state driven by persisted result contracts; do not reintroduce raw stdout parsing in services Tested: python -m compileall orchestrator tradingagents web_dashboard/backend Tested: python -m pytest web_dashboard/backend/tests/test_executors.py web_dashboard/backend/tests/test_services_migration.py web_dashboard/backend/tests/test_api_smoke.py web_dashboard/backend/tests/test_main_api.py web_dashboard/backend/tests/test_portfolio_api.py -q Tested: python -m pytest orchestrator/tests/test_application_service.py orchestrator/tests/test_trading_graph_config.py -q Not-tested: real provider-backed MiniMax execution Not-tested: full dashboard websocket/manual UI flow
This commit is contained in:
parent
b6e57d01e3
commit
a4fb0c4060
|
|
@ -6,11 +6,8 @@ import asyncio
|
||||||
import hmac
|
import hmac
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
|
||||||
import time
|
import time
|
||||||
import traceback
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
@ -23,6 +20,7 @@ from fastapi.staticfiles import StaticFiles
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from services import AnalysisService, JobService, ResultStore, build_request_context, load_migration_flags
|
from services import AnalysisService, JobService, ResultStore, build_request_context, load_migration_flags
|
||||||
|
from services.executor import LegacySubprocessAnalysisExecutor
|
||||||
|
|
||||||
# Path to TradingAgents repo root
|
# Path to TradingAgents repo root
|
||||||
REPO_ROOT = Path(__file__).parent.parent.parent
|
REPO_ROOT = Path(__file__).parent.parent.parent
|
||||||
|
|
@ -54,10 +52,12 @@ async def lifespan(app: FastAPI):
|
||||||
delete_task=app.state.result_store.delete_task_status,
|
delete_task=app.state.result_store.delete_task_status,
|
||||||
)
|
)
|
||||||
app.state.analysis_service = AnalysisService(
|
app.state.analysis_service = AnalysisService(
|
||||||
analysis_python=ANALYSIS_PYTHON,
|
executor=LegacySubprocessAnalysisExecutor(
|
||||||
repo_root=REPO_ROOT,
|
analysis_python=ANALYSIS_PYTHON,
|
||||||
analysis_script_template=ANALYSIS_SCRIPT_TEMPLATE,
|
repo_root=REPO_ROOT,
|
||||||
api_key_resolver=_get_analysis_api_key,
|
api_key_resolver=_get_analysis_api_key,
|
||||||
|
process_registry=app.state.job_service.register_process,
|
||||||
|
),
|
||||||
result_store=app.state.result_store,
|
result_store=app.state.result_store,
|
||||||
job_service=app.state.job_service,
|
job_service=app.state.job_service,
|
||||||
retry_count=MAX_RETRY_COUNT,
|
retry_count=MAX_RETRY_COUNT,
|
||||||
|
|
@ -229,23 +229,6 @@ def _save_to_cache(mode: str, data: dict):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _save_task_status(task_id: str, data: dict):
|
|
||||||
"""Persist task state to disk"""
|
|
||||||
try:
|
|
||||||
TASK_STATUS_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
(TASK_STATUS_DIR / f"{task_id}.json").write_text(json.dumps(data, ensure_ascii=False))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _delete_task_status(task_id: str):
|
|
||||||
"""Remove persisted task state from disk"""
|
|
||||||
try:
|
|
||||||
(TASK_STATUS_DIR / f"{task_id}.json").unlink(missing_ok=True)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# ============== SEPA Screening ==============
|
# ============== SEPA Screening ==============
|
||||||
|
|
||||||
def _run_sepa_screening(mode: str) -> dict:
|
def _run_sepa_screening(mode: str) -> dict:
|
||||||
|
|
@ -282,288 +265,34 @@ async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query
|
||||||
|
|
||||||
# ============== Analysis Execution ==============
|
# ============== Analysis Execution ==============
|
||||||
|
|
||||||
# Script template for subprocess-based analysis
|
|
||||||
# api_key is passed via environment variable (not CLI) for security
|
|
||||||
ANALYSIS_SCRIPT_TEMPLATE = """
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
ticker = sys.argv[1]
|
|
||||||
date = sys.argv[2]
|
|
||||||
repo_root = sys.argv[3]
|
|
||||||
|
|
||||||
sys.path.insert(0, repo_root)
|
|
||||||
os.environ["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
|
|
||||||
import py_mini_racer
|
|
||||||
sys.modules["mini_racer"] = py_mini_racer
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
print("STAGE:analysts", flush=True)
|
|
||||||
|
|
||||||
from orchestrator.config import OrchestratorConfig
|
|
||||||
from orchestrator.orchestrator import TradingOrchestrator
|
|
||||||
|
|
||||||
config = OrchestratorConfig(
|
|
||||||
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
|
|
||||||
trading_agents_config={
|
|
||||||
"llm_provider": "anthropic",
|
|
||||||
"deep_think_llm": "MiniMax-M2.7-highspeed",
|
|
||||||
"quick_think_llm": "MiniMax-M2.7-highspeed",
|
|
||||||
"backend_url": "https://api.minimaxi.com/anthropic",
|
|
||||||
"max_debate_rounds": 1,
|
|
||||||
"max_risk_discuss_rounds": 1,
|
|
||||||
"project_dir": os.path.join(repo_root, "tradingagents"),
|
|
||||||
"results_dir": os.path.join(repo_root, "results"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
print("STAGE:research", flush=True)
|
|
||||||
|
|
||||||
orchestrator = TradingOrchestrator(config)
|
|
||||||
|
|
||||||
print("STAGE:trading", flush=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = orchestrator.get_combined_signal(ticker, date)
|
|
||||||
except ValueError as _e:
|
|
||||||
print("ANALYSIS_ERROR:" + str(_e), file=sys.stderr, flush=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
print("STAGE:risk", flush=True)
|
|
||||||
|
|
||||||
# Map direction + confidence to 5-level signal
|
|
||||||
# FinalSignal is a dataclass, access via attributes not .get()
|
|
||||||
direction = result.direction
|
|
||||||
confidence = result.confidence
|
|
||||||
llm_sig_obj = result.llm_signal
|
|
||||||
quant_sig_obj = result.quant_signal
|
|
||||||
# LLM metadata has "rating" field; quant metadata does not — derive from direction
|
|
||||||
llm_signal = llm_sig_obj.metadata.get("rating", "HOLD") if llm_sig_obj else "HOLD"
|
|
||||||
if quant_sig_obj is None:
|
|
||||||
quant_signal = "HOLD"
|
|
||||||
elif quant_sig_obj.direction == 1:
|
|
||||||
quant_signal = "BUY" if quant_sig_obj.confidence >= 0.7 else "OVERWEIGHT"
|
|
||||||
elif quant_sig_obj.direction == -1:
|
|
||||||
quant_signal = "SELL" if quant_sig_obj.confidence >= 0.7 else "UNDERWEIGHT"
|
|
||||||
else:
|
|
||||||
quant_signal = "HOLD"
|
|
||||||
|
|
||||||
if direction == 1:
|
|
||||||
signal = "BUY" if confidence >= 0.7 else "OVERWEIGHT"
|
|
||||||
elif direction == -1:
|
|
||||||
signal = "SELL" if confidence >= 0.7 else "UNDERWEIGHT"
|
|
||||||
else:
|
|
||||||
signal = "HOLD"
|
|
||||||
|
|
||||||
results_dir = Path(repo_root) / "results" / ticker / date
|
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
report_content = (
|
|
||||||
"# TradingAgents 分析报告\\n\\n"
|
|
||||||
"**股票**: " + ticker + "\\n"
|
|
||||||
"**日期**: " + date + "\\n\\n"
|
|
||||||
"## 最终决策\\n\\n"
|
|
||||||
"**" + signal + "**\\n\\n"
|
|
||||||
"## 信号详情\\n\\n"
|
|
||||||
"- LLM 信号: " + llm_signal + "\\n"
|
|
||||||
"- Quant 信号: " + quant_signal + "\\n"
|
|
||||||
"- 置信度: " + f"{confidence:.1%}" + "\\n\\n"
|
|
||||||
"## 分析摘要\\n\\n"
|
|
||||||
"N/A\\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
report_path = results_dir / "complete_report.md"
|
|
||||||
report_path.write_text(report_content)
|
|
||||||
|
|
||||||
print("STAGE:portfolio", flush=True)
|
|
||||||
signal_detail = json.dumps({"llm_signal": llm_signal, "quant_signal": quant_signal, "confidence": confidence})
|
|
||||||
print("SIGNAL_DETAIL:" + signal_detail, flush=True)
|
|
||||||
print("ANALYSIS_COMPLETE:" + signal, flush=True)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/analysis/start")
|
@app.post("/api/analysis/start")
|
||||||
async def start_analysis(request: AnalysisRequest, api_key: Optional[str] = Header(None)):
|
async def start_analysis(
|
||||||
"""Start a new analysis task"""
|
payload: AnalysisRequest,
|
||||||
|
http_request: Request,
|
||||||
|
api_key: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
|
"""Start a new analysis task."""
|
||||||
import uuid
|
import uuid
|
||||||
task_id = f"{request.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
|
|
||||||
date = request.date or datetime.now().strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
# Check dashboard API key (opt-in auth)
|
|
||||||
if not _check_api_key(api_key):
|
if not _check_api_key(api_key):
|
||||||
_auth_error()
|
_auth_error()
|
||||||
|
|
||||||
# Validate ANTHROPIC_API_KEY for the analysis subprocess
|
task_id = f"{payload.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
|
||||||
anthropic_key = _get_analysis_api_key()
|
date = payload.date or datetime.now().strftime("%Y-%m-%d")
|
||||||
if not anthropic_key:
|
request_context = build_request_context(http_request, api_key=api_key)
|
||||||
raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set")
|
|
||||||
|
|
||||||
# Initialize task state
|
try:
|
||||||
app.state.task_results[task_id] = {
|
return await app.state.analysis_service.start_analysis(
|
||||||
"task_id": task_id,
|
task_id=task_id,
|
||||||
"ticker": request.ticker,
|
ticker=payload.ticker,
|
||||||
"date": date,
|
date=date,
|
||||||
"status": "running",
|
request_context=request_context,
|
||||||
"progress": 0,
|
broadcast_progress=broadcast_progress,
|
||||||
"current_stage": "analysts",
|
)
|
||||||
"created_at": datetime.now().isoformat(),
|
except ValueError as exc:
|
||||||
"elapsed": 0,
|
raise HTTPException(status_code=400, detail=str(exc))
|
||||||
"stages": [
|
except RuntimeError as exc:
|
||||||
{"status": "running", "completed_at": None},
|
raise HTTPException(status_code=500, detail=str(exc))
|
||||||
{"status": "pending", "completed_at": None},
|
|
||||||
{"status": "pending", "completed_at": None},
|
|
||||||
{"status": "pending", "completed_at": None},
|
|
||||||
{"status": "pending", "completed_at": None},
|
|
||||||
],
|
|
||||||
"logs": [],
|
|
||||||
"decision": None,
|
|
||||||
"quant_signal": None,
|
|
||||||
"llm_signal": None,
|
|
||||||
"confidence": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
|
||||||
|
|
||||||
# Write analysis script to temp file with restrictive permissions (avoids subprocess -c quoting issues)
|
|
||||||
fd, script_path_str = tempfile.mkstemp(suffix=".py", prefix=f"analysis_{task_id}_")
|
|
||||||
script_path = Path(script_path_str)
|
|
||||||
os.chmod(script_path, 0o600)
|
|
||||||
with os.fdopen(fd, "w") as f:
|
|
||||||
f.write(ANALYSIS_SCRIPT_TEMPLATE)
|
|
||||||
|
|
||||||
# Store process reference for cancellation
|
|
||||||
app.state.processes = getattr(app.state, 'processes', {})
|
|
||||||
app.state.processes[task_id] = None
|
|
||||||
|
|
||||||
# Cancellation event for the monitor coroutine
|
|
||||||
cancel_event = asyncio.Event()
|
|
||||||
|
|
||||||
# Stage name to index mapping
|
|
||||||
STAGE_NAMES = ["analysts", "research", "trading", "risk", "portfolio"]
|
|
||||||
|
|
||||||
def _update_task_stage(stage_name: str):
|
|
||||||
"""Update task state for a completed stage and mark next as running."""
|
|
||||||
try:
|
|
||||||
idx = STAGE_NAMES.index(stage_name)
|
|
||||||
except ValueError:
|
|
||||||
return
|
|
||||||
# Mark all previous stages as completed
|
|
||||||
for i in range(idx):
|
|
||||||
if app.state.task_results[task_id]["stages"][i]["status"] != "completed":
|
|
||||||
app.state.task_results[task_id]["stages"][i]["status"] = "completed"
|
|
||||||
app.state.task_results[task_id]["stages"][i]["completed_at"] = datetime.now().strftime("%H:%M:%S")
|
|
||||||
# Mark current as completed
|
|
||||||
if app.state.task_results[task_id]["stages"][idx]["status"] != "completed":
|
|
||||||
app.state.task_results[task_id]["stages"][idx]["status"] = "completed"
|
|
||||||
app.state.task_results[task_id]["stages"][idx]["completed_at"] = datetime.now().strftime("%H:%M:%S")
|
|
||||||
# Mark next as running
|
|
||||||
if idx + 1 < 5:
|
|
||||||
if app.state.task_results[task_id]["stages"][idx + 1]["status"] == "pending":
|
|
||||||
app.state.task_results[task_id]["stages"][idx + 1]["status"] = "running"
|
|
||||||
# Update progress
|
|
||||||
app.state.task_results[task_id]["progress"] = int((idx + 1) / 5 * 100)
|
|
||||||
app.state.task_results[task_id]["current_stage"] = stage_name
|
|
||||||
|
|
||||||
async def run_analysis():
|
|
||||||
"""Run analysis subprocess and broadcast progress"""
|
|
||||||
try:
|
|
||||||
# Use clean environment - don't inherit parent env
|
|
||||||
clean_env = {k: v for k, v in os.environ.items()
|
|
||||||
if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))}
|
|
||||||
clean_env["ANTHROPIC_API_KEY"] = anthropic_key
|
|
||||||
clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
|
|
||||||
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
|
||||||
str(ANALYSIS_PYTHON),
|
|
||||||
str(script_path),
|
|
||||||
request.ticker,
|
|
||||||
date,
|
|
||||||
str(REPO_ROOT),
|
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.PIPE,
|
|
||||||
env=clean_env,
|
|
||||||
)
|
|
||||||
app.state.processes[task_id] = proc
|
|
||||||
|
|
||||||
# Read stdout line-by-line for real-time stage updates
|
|
||||||
stdout_lines = []
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
line_bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=300.0)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
break
|
|
||||||
if not line_bytes:
|
|
||||||
break
|
|
||||||
line = line_bytes.decode(errors="replace").rstrip()
|
|
||||||
stdout_lines.append(line)
|
|
||||||
if line.startswith("STAGE:"):
|
|
||||||
stage = line.split(":", 1)[1].strip()
|
|
||||||
_update_task_stage(stage)
|
|
||||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
|
||||||
if cancel_event.is_set():
|
|
||||||
break
|
|
||||||
|
|
||||||
await proc.wait()
|
|
||||||
stderr_bytes = await proc.stderr.read()
|
|
||||||
|
|
||||||
# Clean up script file
|
|
||||||
try:
|
|
||||||
script_path.unlink()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if proc.returncode == 0:
|
|
||||||
output = "\n".join(stdout_lines)
|
|
||||||
decision = "HOLD"
|
|
||||||
for line in stdout_lines:
|
|
||||||
if line.startswith("SIGNAL_DETAIL:"):
|
|
||||||
try:
|
|
||||||
detail = json.loads(line.split(":", 1)[1].strip())
|
|
||||||
app.state.task_results[task_id]["quant_signal"] = detail.get("quant_signal")
|
|
||||||
app.state.task_results[task_id]["llm_signal"] = detail.get("llm_signal")
|
|
||||||
app.state.task_results[task_id]["confidence"] = detail.get("confidence")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
if line.startswith("ANALYSIS_COMPLETE:"):
|
|
||||||
decision = line.split(":", 1)[1].strip()
|
|
||||||
|
|
||||||
app.state.task_results[task_id]["status"] = "completed"
|
|
||||||
app.state.task_results[task_id]["progress"] = 100
|
|
||||||
app.state.task_results[task_id]["decision"] = decision
|
|
||||||
app.state.task_results[task_id]["current_stage"] = "portfolio"
|
|
||||||
for i in range(5):
|
|
||||||
app.state.task_results[task_id]["stages"][i]["status"] = "completed"
|
|
||||||
if not app.state.task_results[task_id]["stages"][i].get("completed_at"):
|
|
||||||
app.state.task_results[task_id]["stages"][i]["completed_at"] = datetime.now().strftime("%H:%M:%S")
|
|
||||||
else:
|
|
||||||
error_msg = stderr_bytes.decode(errors="replace")[-1000:] if stderr_bytes else "Unknown error"
|
|
||||||
app.state.task_results[task_id]["status"] = "failed"
|
|
||||||
app.state.task_results[task_id]["error"] = error_msg
|
|
||||||
|
|
||||||
_save_task_status(task_id, app.state.task_results[task_id])
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
cancel_event.set()
|
|
||||||
app.state.task_results[task_id]["status"] = "failed"
|
|
||||||
app.state.task_results[task_id]["error"] = str(e)
|
|
||||||
try:
|
|
||||||
script_path.unlink()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
_save_task_status(task_id, app.state.task_results[task_id])
|
|
||||||
|
|
||||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
|
||||||
|
|
||||||
task = asyncio.create_task(run_analysis())
|
|
||||||
app.state.analysis_tasks[task_id] = task
|
|
||||||
|
|
||||||
return {
|
|
||||||
"task_id": task_id,
|
|
||||||
"ticker": request.ticker,
|
|
||||||
"date": date,
|
|
||||||
"status": "running",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/analysis/status/{task_id}")
|
@app.get("/api/analysis/status/{task_id}")
|
||||||
|
|
@ -600,13 +329,12 @@ async def list_tasks(api_key: Optional[str] = Header(None)):
|
||||||
|
|
||||||
@app.delete("/api/analysis/cancel/{task_id}")
|
@app.delete("/api/analysis/cancel/{task_id}")
|
||||||
async def cancel_task(task_id: str, api_key: Optional[str] = Header(None)):
|
async def cancel_task(task_id: str, api_key: Optional[str] = Header(None)):
|
||||||
"""Cancel a running task"""
|
"""Cancel a running task."""
|
||||||
if not _check_api_key(api_key):
|
if not _check_api_key(api_key):
|
||||||
_auth_error()
|
_auth_error()
|
||||||
if task_id not in app.state.task_results:
|
if task_id not in app.state.task_results:
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
# Kill the subprocess if it's still running
|
|
||||||
proc = app.state.processes.get(task_id)
|
proc = app.state.processes.get(task_id)
|
||||||
if proc and proc.returncode is None:
|
if proc and proc.returncode is None:
|
||||||
try:
|
try:
|
||||||
|
|
@ -614,26 +342,18 @@ async def cancel_task(task_id: str, api_key: Optional[str] = Header(None)):
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Cancel the asyncio task
|
|
||||||
task = app.state.analysis_tasks.get(task_id)
|
task = app.state.analysis_tasks.get(task_id)
|
||||||
if task:
|
if task:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
app.state.task_results[task_id]["status"] = "failed"
|
|
||||||
app.state.task_results[task_id]["error"] = "用户取消"
|
|
||||||
_save_task_status(task_id, app.state.task_results[task_id])
|
|
||||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
|
||||||
|
|
||||||
# Clean up temp script (may use tempfile.mkstemp with random suffix)
|
state = app.state.task_results[task_id]
|
||||||
for p in Path("/tmp").glob(f"analysis_{task_id}_*.py"):
|
state["status"] = "cancelled"
|
||||||
try:
|
state["error"] = "用户取消"
|
||||||
p.unlink()
|
app.state.result_store.save_task_status(task_id, state)
|
||||||
except Exception:
|
await broadcast_progress(task_id, state)
|
||||||
pass
|
app.state.result_store.delete_task_status(task_id)
|
||||||
|
|
||||||
# Remove persisted task state
|
return {"contract_version": "v1alpha1", "task_id": task_id, "status": "cancelled"}
|
||||||
_delete_task_status(task_id)
|
|
||||||
|
|
||||||
return {"task_id": task_id, "status": "cancelled"}
|
|
||||||
|
|
||||||
|
|
||||||
# ============== WebSocket ==============
|
# ============== WebSocket ==============
|
||||||
|
|
@ -1091,169 +811,31 @@ async def get_recommendation_endpoint(date: str, ticker: str, api_key: Optional[
|
||||||
# --- Batch Analysis ---
|
# --- Batch Analysis ---
|
||||||
|
|
||||||
@app.post("/api/portfolio/analyze")
|
@app.post("/api/portfolio/analyze")
|
||||||
async def start_portfolio_analysis(api_key: Optional[str] = Header(None)):
|
async def start_portfolio_analysis(
|
||||||
"""
|
http_request: Request,
|
||||||
Trigger batch analysis for all watchlist tickers.
|
api_key: Optional[str] = Header(None),
|
||||||
Runs serially, streaming progress via WebSocket (task_id prefixed with 'port_').
|
):
|
||||||
"""
|
"""Trigger batch analysis for all watchlist tickers."""
|
||||||
if not _check_api_key(api_key):
|
if not _check_api_key(api_key):
|
||||||
_auth_error()
|
_auth_error()
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
date = datetime.now().strftime("%Y-%m-%d")
|
date = datetime.now().strftime("%Y-%m-%d")
|
||||||
task_id = f"port_{date}_{uuid.uuid4().hex[:6]}"
|
task_id = f"port_{date}_{uuid.uuid4().hex[:6]}"
|
||||||
|
request_context = build_request_context(http_request, api_key=api_key)
|
||||||
|
|
||||||
if app.state.migration_flags.use_application_services:
|
try:
|
||||||
request_context = build_request_context(api_key=api_key)
|
return await app.state.analysis_service.start_portfolio_analysis(
|
||||||
try:
|
task_id=task_id,
|
||||||
return await app.state.analysis_service.start_portfolio_analysis(
|
date=date,
|
||||||
task_id=task_id,
|
request_context=request_context,
|
||||||
date=date,
|
broadcast_progress=broadcast_progress,
|
||||||
request_context=request_context,
|
)
|
||||||
broadcast_progress=broadcast_progress,
|
except ValueError as exc:
|
||||||
)
|
raise HTTPException(status_code=400, detail=str(exc))
|
||||||
except ValueError as exc:
|
except RuntimeError as exc:
|
||||||
raise HTTPException(status_code=400, detail=str(exc))
|
raise HTTPException(status_code=500, detail=str(exc))
|
||||||
except RuntimeError as exc:
|
|
||||||
raise HTTPException(status_code=500, detail=str(exc))
|
|
||||||
|
|
||||||
watchlist = get_watchlist()
|
|
||||||
if not watchlist:
|
|
||||||
raise HTTPException(status_code=400, detail="自选股为空,请先添加股票")
|
|
||||||
|
|
||||||
total = len(watchlist)
|
|
||||||
app.state.task_results[task_id] = {
|
|
||||||
"task_id": task_id,
|
|
||||||
"type": "portfolio",
|
|
||||||
"status": "running",
|
|
||||||
"total": total,
|
|
||||||
"completed": 0,
|
|
||||||
"failed": 0,
|
|
||||||
"current_ticker": None,
|
|
||||||
"results": [],
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
api_key = _get_analysis_api_key()
|
|
||||||
if not api_key:
|
|
||||||
raise HTTPException(status_code=500, detail="ANTHROPIC_API_KEY environment variable not set")
|
|
||||||
|
|
||||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
|
||||||
|
|
||||||
async def run_portfolio_analysis():
|
|
||||||
max_retries = MAX_RETRY_COUNT
|
|
||||||
|
|
||||||
async def run_single_analysis(ticker: str, stock: dict) -> tuple[bool, str, dict | None]:
|
|
||||||
"""Run analysis for one ticker. Returns (success, decision, rec_or_error)."""
|
|
||||||
last_error = None
|
|
||||||
for attempt in range(max_retries + 1):
|
|
||||||
script_path = None
|
|
||||||
try:
|
|
||||||
fd, script_path_str = tempfile.mkstemp(suffix=".py", prefix=f"analysis_{task_id}_{stock['_idx']}_")
|
|
||||||
script_path = Path(script_path_str)
|
|
||||||
os.chmod(script_path, 0o600)
|
|
||||||
with os.fdopen(fd, "w") as f:
|
|
||||||
f.write(ANALYSIS_SCRIPT_TEMPLATE)
|
|
||||||
|
|
||||||
clean_env = {k: v for k, v in os.environ.items()
|
|
||||||
if not k.startswith(("PYTHON", "CONDA", "VIRTUAL"))}
|
|
||||||
clean_env["ANTHROPIC_API_KEY"] = api_key
|
|
||||||
clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
|
|
||||||
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
|
||||||
str(ANALYSIS_PYTHON), str(script_path), ticker, date, str(REPO_ROOT),
|
|
||||||
stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
|
|
||||||
env=clean_env,
|
|
||||||
)
|
|
||||||
app.state.processes[task_id] = proc
|
|
||||||
|
|
||||||
stdout, stderr = await proc.communicate()
|
|
||||||
|
|
||||||
try:
|
|
||||||
script_path.unlink()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if proc.returncode == 0:
|
|
||||||
output = stdout.decode()
|
|
||||||
decision = "HOLD"
|
|
||||||
quant_signal = None
|
|
||||||
llm_signal = None
|
|
||||||
confidence = None
|
|
||||||
for line in output.splitlines():
|
|
||||||
if line.startswith("SIGNAL_DETAIL:"):
|
|
||||||
try:
|
|
||||||
detail = json.loads(line.split(":", 1)[1].strip())
|
|
||||||
quant_signal = detail.get("quant_signal")
|
|
||||||
llm_signal = detail.get("llm_signal")
|
|
||||||
confidence = detail.get("confidence")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
if line.startswith("ANALYSIS_COMPLETE:"):
|
|
||||||
decision = line.split(":", 1)[1].strip()
|
|
||||||
rec = {
|
|
||||||
"ticker": ticker,
|
|
||||||
"name": stock.get("name", ticker),
|
|
||||||
"analysis_date": date,
|
|
||||||
"decision": decision,
|
|
||||||
"quant_signal": quant_signal,
|
|
||||||
"llm_signal": llm_signal,
|
|
||||||
"confidence": confidence,
|
|
||||||
"created_at": datetime.now().isoformat(),
|
|
||||||
}
|
|
||||||
save_recommendation(date, ticker, rec)
|
|
||||||
return True, decision, rec
|
|
||||||
else:
|
|
||||||
last_error = stderr.decode()[-500:] if stderr else f"exit {proc.returncode}"
|
|
||||||
except Exception as e:
|
|
||||||
last_error = str(e)
|
|
||||||
finally:
|
|
||||||
if script_path:
|
|
||||||
try:
|
|
||||||
script_path.unlink()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
if attempt < max_retries:
|
|
||||||
await asyncio.sleep(RETRY_BASE_DELAY_SECS ** attempt) # exponential backoff: 1s, 2s
|
|
||||||
|
|
||||||
return False, "HOLD", None
|
|
||||||
|
|
||||||
try:
|
|
||||||
for i, stock in enumerate(watchlist):
|
|
||||||
stock["_idx"] = i # used in temp file name
|
|
||||||
ticker = stock["ticker"]
|
|
||||||
app.state.task_results[task_id]["current_ticker"] = ticker
|
|
||||||
app.state.task_results[task_id]["status"] = "running"
|
|
||||||
app.state.task_results[task_id]["completed"] = i
|
|
||||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
|
||||||
|
|
||||||
success, decision, rec = await run_single_analysis(ticker, stock)
|
|
||||||
if success:
|
|
||||||
app.state.task_results[task_id]["completed"] = i + 1
|
|
||||||
app.state.task_results[task_id]["results"].append(rec)
|
|
||||||
else:
|
|
||||||
app.state.task_results[task_id]["failed"] += 1
|
|
||||||
|
|
||||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
|
||||||
|
|
||||||
app.state.task_results[task_id]["status"] = "completed"
|
|
||||||
app.state.task_results[task_id]["current_ticker"] = None
|
|
||||||
_save_task_status(task_id, app.state.task_results[task_id])
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
app.state.task_results[task_id]["status"] = "failed"
|
|
||||||
app.state.task_results[task_id]["error"] = str(e)
|
|
||||||
_save_task_status(task_id, app.state.task_results[task_id])
|
|
||||||
|
|
||||||
await broadcast_progress(task_id, app.state.task_results[task_id])
|
|
||||||
|
|
||||||
task = asyncio.create_task(run_portfolio_analysis())
|
|
||||||
app.state.analysis_tasks[task_id] = task
|
|
||||||
|
|
||||||
return {
|
|
||||||
"task_id": task_id,
|
|
||||||
"total": total,
|
|
||||||
"status": "running",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,15 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import time
|
||||||
import tempfile
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
|
||||||
from typing import Awaitable, Callable, Optional
|
from typing import Awaitable, Callable, Optional
|
||||||
|
|
||||||
|
from .executor import AnalysisExecutionOutput, AnalysisExecutor, AnalysisExecutorError
|
||||||
from .request_context import RequestContext
|
from .request_context import RequestContext
|
||||||
|
|
||||||
BroadcastFn = Callable[[str, dict], Awaitable[None]]
|
BroadcastFn = Callable[[str, dict], Awaitable[None]]
|
||||||
|
ANALYSIS_STAGE_NAMES = ["analysts", "research", "trading", "risk", "portfolio"]
|
||||||
|
|
||||||
|
|
||||||
class AnalysisService:
|
class AnalysisService:
|
||||||
|
|
@ -19,24 +19,56 @@ class AnalysisService:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
analysis_python: Path,
|
executor: AnalysisExecutor,
|
||||||
repo_root: Path,
|
|
||||||
analysis_script_template: str,
|
|
||||||
api_key_resolver: Callable[[], Optional[str]],
|
|
||||||
result_store,
|
result_store,
|
||||||
job_service,
|
job_service,
|
||||||
retry_count: int = 2,
|
retry_count: int = 2,
|
||||||
retry_base_delay_secs: int = 1,
|
retry_base_delay_secs: int = 1,
|
||||||
):
|
):
|
||||||
self.analysis_python = analysis_python
|
self.executor = executor
|
||||||
self.repo_root = repo_root
|
|
||||||
self.analysis_script_template = analysis_script_template
|
|
||||||
self.api_key_resolver = api_key_resolver
|
|
||||||
self.result_store = result_store
|
self.result_store = result_store
|
||||||
self.job_service = job_service
|
self.job_service = job_service
|
||||||
self.retry_count = retry_count
|
self.retry_count = retry_count
|
||||||
self.retry_base_delay_secs = retry_base_delay_secs
|
self.retry_base_delay_secs = retry_base_delay_secs
|
||||||
|
|
||||||
|
async def start_analysis(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
task_id: str,
|
||||||
|
ticker: str,
|
||||||
|
date: str,
|
||||||
|
request_context: RequestContext,
|
||||||
|
broadcast_progress: BroadcastFn,
|
||||||
|
) -> dict:
|
||||||
|
state = self.job_service.create_analysis_job(
|
||||||
|
task_id=task_id,
|
||||||
|
ticker=ticker,
|
||||||
|
date=date,
|
||||||
|
request_id=request_context.request_id,
|
||||||
|
executor_type=request_context.executor_type,
|
||||||
|
contract_version=request_context.contract_version,
|
||||||
|
)
|
||||||
|
self.job_service.register_process(task_id, None)
|
||||||
|
await broadcast_progress(task_id, state)
|
||||||
|
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self._run_analysis(
|
||||||
|
task_id=task_id,
|
||||||
|
ticker=ticker,
|
||||||
|
date=date,
|
||||||
|
request_context=request_context,
|
||||||
|
broadcast_progress=broadcast_progress,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.job_service.register_background_task(task_id, task)
|
||||||
|
return {
|
||||||
|
"contract_version": "v1alpha1",
|
||||||
|
"task_id": task_id,
|
||||||
|
"ticker": ticker,
|
||||||
|
"date": date,
|
||||||
|
"status": "running",
|
||||||
|
}
|
||||||
|
|
||||||
async def start_portfolio_analysis(
|
async def start_portfolio_analysis(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|
@ -45,16 +77,17 @@ class AnalysisService:
|
||||||
request_context: RequestContext,
|
request_context: RequestContext,
|
||||||
broadcast_progress: BroadcastFn,
|
broadcast_progress: BroadcastFn,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
del request_context # Reserved for future auditing/auth propagation.
|
|
||||||
watchlist = self.result_store.get_watchlist()
|
watchlist = self.result_store.get_watchlist()
|
||||||
if not watchlist:
|
if not watchlist:
|
||||||
raise ValueError("自选股为空,请先添加股票")
|
raise ValueError("自选股为空,请先添加股票")
|
||||||
|
|
||||||
analysis_api_key = self.api_key_resolver()
|
state = self.job_service.create_portfolio_job(
|
||||||
if not analysis_api_key:
|
task_id=task_id,
|
||||||
raise RuntimeError("ANTHROPIC_API_KEY environment variable not set")
|
total=len(watchlist),
|
||||||
|
request_id=request_context.request_id,
|
||||||
state = self.job_service.create_portfolio_job(task_id=task_id, total=len(watchlist))
|
executor_type=request_context.executor_type,
|
||||||
|
contract_version=request_context.contract_version,
|
||||||
|
)
|
||||||
await broadcast_progress(task_id, state)
|
await broadcast_progress(task_id, state)
|
||||||
|
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
|
|
@ -62,24 +95,111 @@ class AnalysisService:
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
date=date,
|
date=date,
|
||||||
watchlist=watchlist,
|
watchlist=watchlist,
|
||||||
analysis_api_key=analysis_api_key,
|
request_context=request_context,
|
||||||
broadcast_progress=broadcast_progress,
|
broadcast_progress=broadcast_progress,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.job_service.register_background_task(task_id, task)
|
self.job_service.register_background_task(task_id, task)
|
||||||
return {
|
return {
|
||||||
|
"contract_version": "v1alpha1",
|
||||||
"task_id": task_id,
|
"task_id": task_id,
|
||||||
"total": len(watchlist),
|
"total": len(watchlist),
|
||||||
"status": "running",
|
"status": "running",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def _run_analysis(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
task_id: str,
|
||||||
|
ticker: str,
|
||||||
|
date: str,
|
||||||
|
request_context: RequestContext,
|
||||||
|
broadcast_progress: BroadcastFn,
|
||||||
|
) -> None:
|
||||||
|
start_time = time.monotonic()
|
||||||
|
try:
|
||||||
|
output = await self.executor.execute(
|
||||||
|
task_id=task_id,
|
||||||
|
ticker=ticker,
|
||||||
|
date=date,
|
||||||
|
request_context=request_context,
|
||||||
|
on_stage=lambda stage: self._handle_analysis_stage(
|
||||||
|
task_id=task_id,
|
||||||
|
stage_name=stage,
|
||||||
|
started_at=start_time,
|
||||||
|
broadcast_progress=broadcast_progress,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
state = self.job_service.task_results[task_id]
|
||||||
|
elapsed_seconds = int(time.monotonic() - start_time)
|
||||||
|
contract = output.to_result_contract(
|
||||||
|
task_id=task_id,
|
||||||
|
ticker=ticker,
|
||||||
|
date=date,
|
||||||
|
created_at=state["created_at"],
|
||||||
|
elapsed_seconds=elapsed_seconds,
|
||||||
|
current_stage=ANALYSIS_STAGE_NAMES[-1],
|
||||||
|
)
|
||||||
|
result_ref = self.result_store.save_result_contract(task_id, contract)
|
||||||
|
self.job_service.complete_analysis_job(
|
||||||
|
task_id,
|
||||||
|
contract=contract,
|
||||||
|
result_ref=result_ref,
|
||||||
|
executor_type=request_context.executor_type,
|
||||||
|
)
|
||||||
|
except AnalysisExecutorError as exc:
|
||||||
|
self._fail_analysis_state(
|
||||||
|
task_id=task_id,
|
||||||
|
message=str(exc),
|
||||||
|
started_at=start_time,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
self._fail_analysis_state(
|
||||||
|
task_id=task_id,
|
||||||
|
message=str(exc),
|
||||||
|
started_at=start_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
await broadcast_progress(task_id, self.job_service.task_results[task_id])
|
||||||
|
|
||||||
|
async def _handle_analysis_stage(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
task_id: str,
|
||||||
|
stage_name: str,
|
||||||
|
started_at: float,
|
||||||
|
broadcast_progress: BroadcastFn,
|
||||||
|
) -> None:
|
||||||
|
state = self.job_service.task_results[task_id]
|
||||||
|
try:
|
||||||
|
idx = ANALYSIS_STAGE_NAMES.index(stage_name)
|
||||||
|
except ValueError:
|
||||||
|
return
|
||||||
|
|
||||||
|
for i, entry in enumerate(state["stages"]):
|
||||||
|
if i < idx:
|
||||||
|
if entry["status"] != "completed":
|
||||||
|
entry["status"] = "completed"
|
||||||
|
entry["completed_at"] = datetime.now().strftime("%H:%M:%S")
|
||||||
|
elif i == idx:
|
||||||
|
entry["status"] = "completed"
|
||||||
|
entry["completed_at"] = entry["completed_at"] or datetime.now().strftime("%H:%M:%S")
|
||||||
|
elif i == idx + 1 and entry["status"] == "pending":
|
||||||
|
entry["status"] = "running"
|
||||||
|
|
||||||
|
state["progress"] = int((idx + 1) / len(ANALYSIS_STAGE_NAMES) * 100)
|
||||||
|
state["current_stage"] = stage_name
|
||||||
|
state["elapsed_seconds"] = int(time.monotonic() - started_at)
|
||||||
|
state["elapsed"] = state["elapsed_seconds"]
|
||||||
|
await broadcast_progress(task_id, state)
|
||||||
|
|
||||||
async def _run_portfolio_analysis(
|
async def _run_portfolio_analysis(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
date: str,
|
date: str,
|
||||||
watchlist: list[dict],
|
watchlist: list[dict],
|
||||||
analysis_api_key: str,
|
request_context: RequestContext,
|
||||||
broadcast_progress: BroadcastFn,
|
broadcast_progress: BroadcastFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
|
|
@ -96,7 +216,7 @@ class AnalysisService:
|
||||||
ticker=ticker,
|
ticker=ticker,
|
||||||
stock=stock,
|
stock=stock,
|
||||||
date=date,
|
date=date,
|
||||||
analysis_api_key=analysis_api_key,
|
request_context=request_context,
|
||||||
)
|
)
|
||||||
if success and rec is not None:
|
if success and rec is not None:
|
||||||
self.job_service.append_portfolio_result(task_id, rec)
|
self.job_service.append_portfolio_result(task_id, rec)
|
||||||
|
|
@ -118,61 +238,27 @@ class AnalysisService:
|
||||||
ticker: str,
|
ticker: str,
|
||||||
stock: dict,
|
stock: dict,
|
||||||
date: str,
|
date: str,
|
||||||
analysis_api_key: str,
|
request_context: RequestContext,
|
||||||
) -> tuple[bool, Optional[dict]]:
|
) -> tuple[bool, Optional[dict]]:
|
||||||
last_error: Optional[str] = None
|
last_error: Optional[str] = None
|
||||||
for attempt in range(self.retry_count + 1):
|
for attempt in range(self.retry_count + 1):
|
||||||
script_path: Optional[Path] = None
|
|
||||||
try:
|
try:
|
||||||
fd, script_path_str = tempfile.mkstemp(
|
output = await self.executor.execute(
|
||||||
suffix=".py",
|
task_id=f"{task_id}_{stock['_idx']}",
|
||||||
prefix=f"analysis_{task_id}_{stock['_idx']}_",
|
ticker=ticker,
|
||||||
|
date=date,
|
||||||
|
request_context=request_context,
|
||||||
)
|
)
|
||||||
script_path = Path(script_path_str)
|
rec = self._build_recommendation_record(
|
||||||
os.chmod(script_path, 0o600)
|
output=output,
|
||||||
with os.fdopen(fd, "w") as handle:
|
ticker=ticker,
|
||||||
handle.write(self.analysis_script_template)
|
stock=stock,
|
||||||
|
date=date,
|
||||||
clean_env = {
|
|
||||||
key: value
|
|
||||||
for key, value in os.environ.items()
|
|
||||||
if not key.startswith(("PYTHON", "CONDA", "VIRTUAL"))
|
|
||||||
}
|
|
||||||
clean_env["ANTHROPIC_API_KEY"] = analysis_api_key
|
|
||||||
clean_env["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
|
|
||||||
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
|
||||||
str(self.analysis_python),
|
|
||||||
str(script_path),
|
|
||||||
ticker,
|
|
||||||
date,
|
|
||||||
str(self.repo_root),
|
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.PIPE,
|
|
||||||
env=clean_env,
|
|
||||||
)
|
)
|
||||||
self.job_service.register_process(task_id, proc)
|
self.result_store.save_recommendation(date, ticker, rec)
|
||||||
stdout, stderr = await proc.communicate()
|
return True, rec
|
||||||
|
|
||||||
if proc.returncode == 0:
|
|
||||||
rec = self._build_recommendation_record(
|
|
||||||
stdout=stdout.decode(),
|
|
||||||
ticker=ticker,
|
|
||||||
stock=stock,
|
|
||||||
date=date,
|
|
||||||
)
|
|
||||||
self.result_store.save_recommendation(date, ticker, rec)
|
|
||||||
return True, rec
|
|
||||||
|
|
||||||
last_error = stderr.decode()[-500:] if stderr else f"exit {proc.returncode}"
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
last_error = str(exc)
|
last_error = str(exc)
|
||||||
finally:
|
|
||||||
if script_path is not None:
|
|
||||||
try:
|
|
||||||
script_path.unlink()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if attempt < self.retry_count:
|
if attempt < self.retry_count:
|
||||||
await asyncio.sleep(self.retry_base_delay_secs ** attempt)
|
await asyncio.sleep(self.retry_base_delay_secs ** attempt)
|
||||||
|
|
@ -181,23 +267,45 @@ class AnalysisService:
|
||||||
self.job_service.task_results[task_id]["last_error"] = last_error
|
self.job_service.task_results[task_id]["last_error"] = last_error
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
|
def _fail_analysis_state(self, *, task_id: str, message: str, started_at: float) -> None:
|
||||||
|
state = self.job_service.task_results[task_id]
|
||||||
|
state["status"] = "failed"
|
||||||
|
state["elapsed_seconds"] = int(time.monotonic() - started_at)
|
||||||
|
state["elapsed"] = state["elapsed_seconds"]
|
||||||
|
state["result"] = None
|
||||||
|
state["error"] = message
|
||||||
|
self.result_store.save_task_status(task_id, state)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_recommendation_record(*, stdout: str, ticker: str, stock: dict, date: str) -> dict:
|
def _build_recommendation_record(
|
||||||
decision = "HOLD"
|
*,
|
||||||
quant_signal = None
|
ticker: str,
|
||||||
llm_signal = None
|
stock: dict,
|
||||||
confidence = None
|
date: str,
|
||||||
for line in stdout.splitlines():
|
output: AnalysisExecutionOutput | None = None,
|
||||||
if line.startswith("SIGNAL_DETAIL:"):
|
stdout: str | None = None,
|
||||||
try:
|
) -> dict:
|
||||||
detail = json.loads(line.split(":", 1)[1].strip())
|
if output is not None:
|
||||||
except Exception:
|
decision = output.decision
|
||||||
continue
|
quant_signal = output.quant_signal
|
||||||
quant_signal = detail.get("quant_signal")
|
llm_signal = output.llm_signal
|
||||||
llm_signal = detail.get("llm_signal")
|
confidence = output.confidence
|
||||||
confidence = detail.get("confidence")
|
else:
|
||||||
if line.startswith("ANALYSIS_COMPLETE:"):
|
decision = "HOLD"
|
||||||
decision = line.split(":", 1)[1].strip()
|
quant_signal = None
|
||||||
|
llm_signal = None
|
||||||
|
confidence = None
|
||||||
|
for line in (stdout or "").splitlines():
|
||||||
|
if line.startswith("SIGNAL_DETAIL:"):
|
||||||
|
try:
|
||||||
|
detail = json.loads(line.split(":", 1)[1].strip())
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
quant_signal = detail.get("quant_signal")
|
||||||
|
llm_signal = detail.get("llm_signal")
|
||||||
|
confidence = detail.get("confidence")
|
||||||
|
if line.startswith("ANALYSIS_COMPLETE:"):
|
||||||
|
decision = line.split(":", 1)[1].strip()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"ticker": ticker,
|
"ticker": ticker,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,381 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Awaitable, Callable, Optional, Protocol
|
||||||
|
|
||||||
|
from .request_context import (
|
||||||
|
CONTRACT_VERSION,
|
||||||
|
DEFAULT_EXECUTOR_TYPE,
|
||||||
|
RequestContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
StageCallback = Callable[[str], Awaitable[None]]
|
||||||
|
ProcessRegistry = Callable[[str, asyncio.subprocess.Process | None], None]
|
||||||
|
|
||||||
|
LEGACY_ANALYSIS_SCRIPT_TEMPLATE = """
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
ticker = sys.argv[1]
|
||||||
|
date = sys.argv[2]
|
||||||
|
repo_root = sys.argv[3]
|
||||||
|
|
||||||
|
sys.path.insert(0, repo_root)
|
||||||
|
|
||||||
|
import py_mini_racer
|
||||||
|
sys.modules["mini_racer"] = py_mini_racer
|
||||||
|
|
||||||
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
from orchestrator.orchestrator import TradingOrchestrator
|
||||||
|
from tradingagents.default_config import get_default_config
|
||||||
|
|
||||||
|
trading_config = get_default_config()
|
||||||
|
trading_config["project_dir"] = os.path.join(repo_root, "tradingagents")
|
||||||
|
trading_config["results_dir"] = os.path.join(repo_root, "results")
|
||||||
|
trading_config["max_debate_rounds"] = 1
|
||||||
|
trading_config["max_risk_discuss_rounds"] = 1
|
||||||
|
|
||||||
|
print("STAGE:analysts", flush=True)
|
||||||
|
print("STAGE:research", flush=True)
|
||||||
|
|
||||||
|
config = OrchestratorConfig(
|
||||||
|
quant_backtest_path=os.environ.get("QUANT_BACKTEST_PATH", ""),
|
||||||
|
trading_agents_config=trading_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
orchestrator = TradingOrchestrator(config)
|
||||||
|
|
||||||
|
print("STAGE:trading", flush=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = orchestrator.get_combined_signal(ticker, date)
|
||||||
|
except ValueError as exc:
|
||||||
|
print("ANALYSIS_ERROR:" + str(exc), file=sys.stderr, flush=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print("STAGE:risk", flush=True)
|
||||||
|
|
||||||
|
direction = result.direction
|
||||||
|
confidence = result.confidence
|
||||||
|
llm_sig_obj = result.llm_signal
|
||||||
|
quant_sig_obj = result.quant_signal
|
||||||
|
llm_signal = llm_sig_obj.metadata.get("rating", "HOLD") if llm_sig_obj else "HOLD"
|
||||||
|
if quant_sig_obj is None:
|
||||||
|
quant_signal = "HOLD"
|
||||||
|
elif quant_sig_obj.direction == 1:
|
||||||
|
quant_signal = "BUY" if quant_sig_obj.confidence >= 0.7 else "OVERWEIGHT"
|
||||||
|
elif quant_sig_obj.direction == -1:
|
||||||
|
quant_signal = "SELL" if quant_sig_obj.confidence >= 0.7 else "UNDERWEIGHT"
|
||||||
|
else:
|
||||||
|
quant_signal = "HOLD"
|
||||||
|
|
||||||
|
if direction == 1:
|
||||||
|
signal = "BUY" if confidence >= 0.7 else "OVERWEIGHT"
|
||||||
|
elif direction == -1:
|
||||||
|
signal = "SELL" if confidence >= 0.7 else "UNDERWEIGHT"
|
||||||
|
else:
|
||||||
|
signal = "HOLD"
|
||||||
|
|
||||||
|
results_dir = Path(repo_root) / "results" / ticker / date
|
||||||
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
report_content = (
|
||||||
|
"# TradingAgents 分析报告\\n\\n"
|
||||||
|
"**股票**: " + ticker + "\\n"
|
||||||
|
"**日期**: " + date + "\\n\\n"
|
||||||
|
"## 最终决策\\n\\n"
|
||||||
|
"**" + signal + "**\\n\\n"
|
||||||
|
"## 信号详情\\n\\n"
|
||||||
|
"- LLM 信号: " + llm_signal + "\\n"
|
||||||
|
"- Quant 信号: " + quant_signal + "\\n"
|
||||||
|
"- 置信度: " + f"{confidence:.1%}" + "\\n\\n"
|
||||||
|
"## 分析摘要\\n\\n"
|
||||||
|
"N/A\\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
report_path = results_dir / "complete_report.md"
|
||||||
|
report_path.write_text(report_content)
|
||||||
|
|
||||||
|
print("STAGE:portfolio", flush=True)
|
||||||
|
signal_detail = json.dumps({"llm_signal": llm_signal, "quant_signal": quant_signal, "confidence": confidence})
|
||||||
|
print("SIGNAL_DETAIL:" + signal_detail, flush=True)
|
||||||
|
print("ANALYSIS_COMPLETE:" + signal, flush=True)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _rating_to_direction(rating: Optional[str]) -> int:
|
||||||
|
if rating in {"BUY", "OVERWEIGHT"}:
|
||||||
|
return 1
|
||||||
|
if rating in {"SELL", "UNDERWEIGHT"}:
|
||||||
|
return -1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AnalysisExecutionOutput:
|
||||||
|
decision: str
|
||||||
|
quant_signal: Optional[str]
|
||||||
|
llm_signal: Optional[str]
|
||||||
|
confidence: Optional[float]
|
||||||
|
report_path: Optional[str] = None
|
||||||
|
contract_version: str = CONTRACT_VERSION
|
||||||
|
executor_type: str = DEFAULT_EXECUTOR_TYPE
|
||||||
|
|
||||||
|
def to_result_contract(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
task_id: str,
|
||||||
|
ticker: str,
|
||||||
|
date: str,
|
||||||
|
created_at: str,
|
||||||
|
elapsed_seconds: int,
|
||||||
|
current_stage: str = "portfolio",
|
||||||
|
) -> dict:
|
||||||
|
return {
|
||||||
|
"contract_version": self.contract_version,
|
||||||
|
"task_id": task_id,
|
||||||
|
"ticker": ticker,
|
||||||
|
"date": date,
|
||||||
|
"status": "completed",
|
||||||
|
"progress": 100,
|
||||||
|
"current_stage": current_stage,
|
||||||
|
"created_at": created_at,
|
||||||
|
"elapsed_seconds": elapsed_seconds,
|
||||||
|
"elapsed": elapsed_seconds,
|
||||||
|
"result": {
|
||||||
|
"decision": self.decision,
|
||||||
|
"confidence": self.confidence,
|
||||||
|
"signals": {
|
||||||
|
"merged": {
|
||||||
|
"direction": _rating_to_direction(self.decision),
|
||||||
|
"rating": self.decision,
|
||||||
|
},
|
||||||
|
"quant": {
|
||||||
|
"direction": _rating_to_direction(self.quant_signal),
|
||||||
|
"rating": self.quant_signal,
|
||||||
|
"available": self.quant_signal is not None,
|
||||||
|
},
|
||||||
|
"llm": {
|
||||||
|
"direction": _rating_to_direction(self.llm_signal),
|
||||||
|
"rating": self.llm_signal,
|
||||||
|
"available": self.llm_signal is not None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"degraded": self.quant_signal is None or self.llm_signal is None,
|
||||||
|
"report": {
|
||||||
|
"path": self.report_path,
|
||||||
|
"available": bool(self.report_path),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisExecutorError(RuntimeError):
|
||||||
|
def __init__(self, message: str, *, code: str = "analysis_failed", retryable: bool = False):
|
||||||
|
super().__init__(message)
|
||||||
|
self.code = code
|
||||||
|
self.retryable = retryable
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisExecutor(Protocol):
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
task_id: str,
|
||||||
|
ticker: str,
|
||||||
|
date: str,
|
||||||
|
request_context: RequestContext,
|
||||||
|
on_stage: Optional[StageCallback] = None,
|
||||||
|
) -> AnalysisExecutionOutput: ...
|
||||||
|
|
||||||
|
|
||||||
|
class LegacySubprocessAnalysisExecutor:
|
||||||
|
"""Run the legacy dashboard analysis script behind a stable executor contract."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
analysis_python: Path,
|
||||||
|
repo_root: Path,
|
||||||
|
api_key_resolver: Callable[[], Optional[str]],
|
||||||
|
process_registry: Optional[ProcessRegistry] = None,
|
||||||
|
script_template: str = LEGACY_ANALYSIS_SCRIPT_TEMPLATE,
|
||||||
|
stdout_timeout_secs: float = 300.0,
|
||||||
|
):
|
||||||
|
self.analysis_python = analysis_python
|
||||||
|
self.repo_root = repo_root
|
||||||
|
self.api_key_resolver = api_key_resolver
|
||||||
|
self.process_registry = process_registry
|
||||||
|
self.script_template = script_template
|
||||||
|
self.stdout_timeout_secs = stdout_timeout_secs
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
task_id: str,
|
||||||
|
ticker: str,
|
||||||
|
date: str,
|
||||||
|
request_context: RequestContext,
|
||||||
|
on_stage: Optional[StageCallback] = None,
|
||||||
|
) -> AnalysisExecutionOutput:
|
||||||
|
analysis_api_key = request_context.api_key or self.api_key_resolver()
|
||||||
|
if not analysis_api_key:
|
||||||
|
raise RuntimeError("ANTHROPIC_API_KEY environment variable not set")
|
||||||
|
|
||||||
|
script_path: Optional[Path] = None
|
||||||
|
proc: asyncio.subprocess.Process | None = None
|
||||||
|
try:
|
||||||
|
fd, script_path_str = tempfile.mkstemp(suffix=".py", prefix=f"analysis_{task_id}_")
|
||||||
|
script_path = Path(script_path_str)
|
||||||
|
os.chmod(script_path, 0o600)
|
||||||
|
with os.fdopen(fd, "w", encoding="utf-8") as handle:
|
||||||
|
handle.write(self.script_template)
|
||||||
|
|
||||||
|
clean_env = {
|
||||||
|
key: value
|
||||||
|
for key, value in os.environ.items()
|
||||||
|
if not key.startswith(("PYTHON", "CONDA", "VIRTUAL"))
|
||||||
|
}
|
||||||
|
clean_env["ANTHROPIC_API_KEY"] = analysis_api_key
|
||||||
|
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
str(self.analysis_python),
|
||||||
|
str(script_path),
|
||||||
|
ticker,
|
||||||
|
date,
|
||||||
|
str(self.repo_root),
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
env=clean_env,
|
||||||
|
)
|
||||||
|
if self.process_registry is not None:
|
||||||
|
self.process_registry(task_id, proc)
|
||||||
|
|
||||||
|
stdout_lines: list[str] = []
|
||||||
|
assert proc.stdout is not None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
line_bytes = await asyncio.wait_for(
|
||||||
|
proc.stdout.readline(),
|
||||||
|
timeout=self.stdout_timeout_secs,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError as exc:
|
||||||
|
await self._terminate_process(proc)
|
||||||
|
raise AnalysisExecutorError(
|
||||||
|
f"analysis subprocess timed out after {self.stdout_timeout_secs:g}s",
|
||||||
|
retryable=True,
|
||||||
|
) from exc
|
||||||
|
if not line_bytes:
|
||||||
|
break
|
||||||
|
line = line_bytes.decode(errors="replace").rstrip()
|
||||||
|
stdout_lines.append(line)
|
||||||
|
if on_stage is not None and line.startswith("STAGE:"):
|
||||||
|
await on_stage(line.split(":", 1)[1].strip())
|
||||||
|
|
||||||
|
await proc.wait()
|
||||||
|
stderr_bytes = await proc.stderr.read() if proc.stderr is not None else b""
|
||||||
|
if proc.returncode != 0:
|
||||||
|
message = stderr_bytes.decode(errors="replace")[-1000:] if stderr_bytes else f"exit {proc.returncode}"
|
||||||
|
raise AnalysisExecutorError(message)
|
||||||
|
|
||||||
|
return self._parse_output(
|
||||||
|
stdout_lines=stdout_lines,
|
||||||
|
ticker=ticker,
|
||||||
|
date=date,
|
||||||
|
contract_version=request_context.contract_version,
|
||||||
|
executor_type=request_context.executor_type,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if self.process_registry is not None:
|
||||||
|
self.process_registry(task_id, None)
|
||||||
|
if script_path is not None:
|
||||||
|
try:
|
||||||
|
script_path.unlink()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _terminate_process(proc: asyncio.subprocess.Process) -> None:
|
||||||
|
if proc.returncode is not None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
proc.kill()
|
||||||
|
except ProcessLookupError:
|
||||||
|
return
|
||||||
|
await proc.wait()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_output(
|
||||||
|
*,
|
||||||
|
stdout_lines: list[str],
|
||||||
|
ticker: str,
|
||||||
|
date: str,
|
||||||
|
contract_version: str,
|
||||||
|
executor_type: str,
|
||||||
|
) -> AnalysisExecutionOutput:
|
||||||
|
decision: Optional[str] = None
|
||||||
|
quant_signal = None
|
||||||
|
llm_signal = None
|
||||||
|
confidence = None
|
||||||
|
seen_signal_detail = False
|
||||||
|
seen_complete = False
|
||||||
|
|
||||||
|
for line in stdout_lines:
|
||||||
|
if line.startswith("SIGNAL_DETAIL:"):
|
||||||
|
seen_signal_detail = True
|
||||||
|
try:
|
||||||
|
detail = json.loads(line.split(":", 1)[1].strip())
|
||||||
|
except Exception as exc:
|
||||||
|
raise AnalysisExecutorError("failed to parse SIGNAL_DETAIL payload") from exc
|
||||||
|
quant_signal = detail.get("quant_signal")
|
||||||
|
llm_signal = detail.get("llm_signal")
|
||||||
|
confidence = detail.get("confidence")
|
||||||
|
elif line.startswith("ANALYSIS_COMPLETE:"):
|
||||||
|
seen_complete = True
|
||||||
|
decision = line.split(":", 1)[1].strip()
|
||||||
|
|
||||||
|
missing_markers = []
|
||||||
|
if not seen_signal_detail:
|
||||||
|
missing_markers.append("SIGNAL_DETAIL")
|
||||||
|
if not seen_complete:
|
||||||
|
missing_markers.append("ANALYSIS_COMPLETE")
|
||||||
|
if missing_markers:
|
||||||
|
raise AnalysisExecutorError(
|
||||||
|
"analysis subprocess completed without required markers: "
|
||||||
|
+ ", ".join(missing_markers)
|
||||||
|
)
|
||||||
|
|
||||||
|
report_path = str(Path("results") / ticker / date / "complete_report.md")
|
||||||
|
return AnalysisExecutionOutput(
|
||||||
|
decision=decision or "HOLD",
|
||||||
|
quant_signal=quant_signal,
|
||||||
|
llm_signal=llm_signal,
|
||||||
|
confidence=confidence,
|
||||||
|
report_path=report_path,
|
||||||
|
contract_version=contract_version,
|
||||||
|
executor_type=executor_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DirectAnalysisExecutor:
|
||||||
|
"""Placeholder for a future in-process executor implementation."""
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
task_id: str,
|
||||||
|
ticker: str,
|
||||||
|
date: str,
|
||||||
|
request_context: RequestContext,
|
||||||
|
on_stage: Optional[StageCallback] = None,
|
||||||
|
) -> AnalysisExecutionOutput:
|
||||||
|
del task_id, ticker, date, request_context, on_stage
|
||||||
|
raise NotImplementedError("DirectAnalysisExecutor is not implemented in phase 1")
|
||||||
|
|
@ -5,6 +5,10 @@ from datetime import datetime
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
|
||||||
|
CONTRACT_VERSION = "v1alpha1"
|
||||||
|
DEFAULT_EXECUTOR_TYPE = "legacy_subprocess"
|
||||||
|
|
||||||
|
|
||||||
class JobService:
|
class JobService:
|
||||||
"""Application-layer job state orchestrator with legacy-compatible payloads."""
|
"""Application-layer job state orchestrator with legacy-compatible payloads."""
|
||||||
|
|
||||||
|
|
@ -24,10 +28,71 @@ class JobService:
|
||||||
self.delete_task = delete_task
|
self.delete_task = delete_task
|
||||||
|
|
||||||
def restore_task_results(self, restored: dict[str, dict]) -> None:
|
def restore_task_results(self, restored: dict[str, dict]) -> None:
|
||||||
self.task_results.update(restored)
|
self.task_results.update(
|
||||||
|
{
|
||||||
|
task_id: self._normalize_task_state(task_id, state)
|
||||||
|
for task_id, state in restored.items()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def create_portfolio_job(self, *, task_id: str, total: int) -> dict:
|
def create_analysis_job(
|
||||||
state = {
|
self,
|
||||||
|
*,
|
||||||
|
task_id: str,
|
||||||
|
ticker: str,
|
||||||
|
date: str,
|
||||||
|
request_id: str | None = None,
|
||||||
|
executor_type: str = DEFAULT_EXECUTOR_TYPE,
|
||||||
|
contract_version: str = CONTRACT_VERSION,
|
||||||
|
result_ref: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
state = self._normalize_task_state(task_id, {
|
||||||
|
"task_id": task_id,
|
||||||
|
"ticker": ticker,
|
||||||
|
"date": date,
|
||||||
|
"status": "running",
|
||||||
|
"progress": 0,
|
||||||
|
"current_stage": "analysts",
|
||||||
|
"created_at": datetime.now().isoformat(),
|
||||||
|
"elapsed_seconds": 0,
|
||||||
|
"elapsed": 0,
|
||||||
|
"stages": [
|
||||||
|
{
|
||||||
|
"name": stage_name,
|
||||||
|
"status": "running" if index == 0 else "pending",
|
||||||
|
"completed_at": None,
|
||||||
|
}
|
||||||
|
for index, stage_name in enumerate(
|
||||||
|
["analysts", "research", "trading", "risk", "portfolio"]
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"logs": [],
|
||||||
|
"decision": None,
|
||||||
|
"quant_signal": None,
|
||||||
|
"llm_signal": None,
|
||||||
|
"confidence": None,
|
||||||
|
"result": None,
|
||||||
|
"error": None,
|
||||||
|
"request_id": request_id,
|
||||||
|
"executor_type": executor_type,
|
||||||
|
"contract_version": contract_version,
|
||||||
|
"result_ref": result_ref,
|
||||||
|
})
|
||||||
|
self.task_results[task_id] = state
|
||||||
|
self.processes.setdefault(task_id, None)
|
||||||
|
return state
|
||||||
|
|
||||||
|
def create_portfolio_job(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
task_id: str,
|
||||||
|
total: int,
|
||||||
|
request_id: str | None = None,
|
||||||
|
executor_type: str = DEFAULT_EXECUTOR_TYPE,
|
||||||
|
contract_version: str = CONTRACT_VERSION,
|
||||||
|
result_ref: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
state = self._normalize_task_state(task_id, {
|
||||||
"task_id": task_id,
|
"task_id": task_id,
|
||||||
"type": "portfolio",
|
"type": "portfolio",
|
||||||
"status": "running",
|
"status": "running",
|
||||||
|
|
@ -38,11 +103,65 @@ class JobService:
|
||||||
"results": [],
|
"results": [],
|
||||||
"error": None,
|
"error": None,
|
||||||
"created_at": datetime.now().isoformat(),
|
"created_at": datetime.now().isoformat(),
|
||||||
}
|
"request_id": request_id,
|
||||||
|
"executor_type": executor_type,
|
||||||
|
"contract_version": contract_version,
|
||||||
|
"result_ref": result_ref,
|
||||||
|
})
|
||||||
self.task_results[task_id] = state
|
self.task_results[task_id] = state
|
||||||
self.processes.setdefault(task_id, None)
|
self.processes.setdefault(task_id, None)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
def attach_result_contract(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
*,
|
||||||
|
result_ref: str,
|
||||||
|
contract_version: str = CONTRACT_VERSION,
|
||||||
|
executor_type: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
state = self.task_results[task_id]
|
||||||
|
state["result_ref"] = result_ref
|
||||||
|
state["contract_version"] = contract_version or state.get("contract_version") or CONTRACT_VERSION
|
||||||
|
if executor_type:
|
||||||
|
state["executor_type"] = executor_type
|
||||||
|
return state
|
||||||
|
|
||||||
|
def complete_analysis_job(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
*,
|
||||||
|
contract: dict,
|
||||||
|
result_ref: str,
|
||||||
|
executor_type: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
state = self.task_results[task_id]
|
||||||
|
result = dict(contract.get("result") or {})
|
||||||
|
signals = result.get("signals") or {}
|
||||||
|
quant = signals.get("quant") or {}
|
||||||
|
llm = signals.get("llm") or {}
|
||||||
|
|
||||||
|
state["status"] = contract.get("status", "completed")
|
||||||
|
state["progress"] = contract.get("progress", 100)
|
||||||
|
state["current_stage"] = contract.get("current_stage", state.get("current_stage"))
|
||||||
|
state["elapsed_seconds"] = contract.get("elapsed_seconds", state.get("elapsed_seconds", 0))
|
||||||
|
state["elapsed"] = contract.get("elapsed", state["elapsed_seconds"])
|
||||||
|
state["decision"] = result.get("decision")
|
||||||
|
state["quant_signal"] = quant.get("rating")
|
||||||
|
state["llm_signal"] = llm.get("rating")
|
||||||
|
state["confidence"] = result.get("confidence")
|
||||||
|
state["result"] = result
|
||||||
|
state["error"] = contract.get("error")
|
||||||
|
state["contract_version"] = contract.get("contract_version", state.get("contract_version"))
|
||||||
|
self.attach_result_contract(
|
||||||
|
task_id,
|
||||||
|
result_ref=result_ref,
|
||||||
|
contract_version=state["contract_version"],
|
||||||
|
executor_type=executor_type,
|
||||||
|
)
|
||||||
|
self.persist_task(task_id, state)
|
||||||
|
return state
|
||||||
|
|
||||||
def update_portfolio_progress(self, task_id: str, *, ticker: str, completed: int) -> dict:
|
def update_portfolio_progress(self, task_id: str, *, ticker: str, completed: int) -> dict:
|
||||||
state = self.task_results[task_id]
|
state = self.task_results[task_id]
|
||||||
state["current_ticker"] = ticker
|
state["current_ticker"] = ticker
|
||||||
|
|
@ -92,3 +211,12 @@ class JobService:
|
||||||
state["error"] = error
|
state["error"] = error
|
||||||
self.persist_task(task_id, state)
|
self.persist_task(task_id, state)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_task_state(task_id: str, state: dict) -> dict:
|
||||||
|
normalized = dict(state)
|
||||||
|
normalized.setdefault("request_id", task_id)
|
||||||
|
normalized.setdefault("executor_type", DEFAULT_EXECUTOR_TYPE)
|
||||||
|
normalized.setdefault("contract_version", CONTRACT_VERSION)
|
||||||
|
normalized.setdefault("result_ref", None)
|
||||||
|
return normalized
|
||||||
|
|
|
||||||
|
|
@ -7,11 +7,17 @@ from uuid import uuid4
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
|
|
||||||
|
CONTRACT_VERSION = "v1alpha1"
|
||||||
|
DEFAULT_EXECUTOR_TYPE = "legacy_subprocess"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class RequestContext:
|
class RequestContext:
|
||||||
"""Minimal request-scoped metadata passed into application services."""
|
"""Minimal request-scoped metadata passed into application services."""
|
||||||
|
|
||||||
request_id: str
|
request_id: str
|
||||||
|
contract_version: str = CONTRACT_VERSION
|
||||||
|
executor_type: str = DEFAULT_EXECUTOR_TYPE
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
client_host: Optional[str] = None
|
client_host: Optional[str] = None
|
||||||
is_local: bool = False
|
is_local: bool = False
|
||||||
|
|
@ -23,6 +29,8 @@ def build_request_context(
|
||||||
*,
|
*,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
request_id: Optional[str] = None,
|
request_id: Optional[str] = None,
|
||||||
|
contract_version: str = CONTRACT_VERSION,
|
||||||
|
executor_type: str = DEFAULT_EXECUTOR_TYPE,
|
||||||
metadata: Optional[dict[str, str]] = None,
|
metadata: Optional[dict[str, str]] = None,
|
||||||
) -> RequestContext:
|
) -> RequestContext:
|
||||||
"""Create a stable request context without leaking FastAPI internals into services."""
|
"""Create a stable request context without leaking FastAPI internals into services."""
|
||||||
|
|
@ -30,6 +38,8 @@ def build_request_context(
|
||||||
is_local = client_host in {"127.0.0.1", "::1", "localhost", "testclient"}
|
is_local = client_host in {"127.0.0.1", "::1", "localhost", "testclient"}
|
||||||
return RequestContext(
|
return RequestContext(
|
||||||
request_id=request_id or uuid4().hex,
|
request_id=request_id or uuid4().hex,
|
||||||
|
contract_version=contract_version,
|
||||||
|
executor_type=executor_type,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
client_host=client_host,
|
client_host=client_host,
|
||||||
is_local=is_local,
|
is_local=is_local,
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,15 @@ from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
CONTRACT_VERSION = "v1alpha1"
|
||||||
|
|
||||||
|
|
||||||
class ResultStore:
|
class ResultStore:
|
||||||
"""Storage boundary for persisted task state and portfolio results."""
|
"""Storage boundary for persisted task state and portfolio results."""
|
||||||
|
|
||||||
def __init__(self, task_status_dir: Path, portfolio_gateway):
|
def __init__(self, task_status_dir: Path, portfolio_gateway):
|
||||||
self.task_status_dir = task_status_dir
|
self.task_status_dir = task_status_dir
|
||||||
|
self.result_contract_dir = self.task_status_dir / "result_contracts"
|
||||||
self.portfolio_gateway = portfolio_gateway
|
self.portfolio_gateway = portfolio_gateway
|
||||||
|
|
||||||
def restore_task_results(self) -> dict[str, dict]:
|
def restore_task_results(self) -> dict[str, dict]:
|
||||||
|
|
@ -29,6 +33,15 @@ class ResultStore:
|
||||||
self.task_status_dir.mkdir(parents=True, exist_ok=True)
|
self.task_status_dir.mkdir(parents=True, exist_ok=True)
|
||||||
(self.task_status_dir / f"{task_id}.json").write_text(json.dumps(data, ensure_ascii=False))
|
(self.task_status_dir / f"{task_id}.json").write_text(json.dumps(data, ensure_ascii=False))
|
||||||
|
|
||||||
|
def save_result_contract(self, task_id: str, contract: dict) -> str:
|
||||||
|
self.result_contract_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
payload = dict(contract)
|
||||||
|
payload.setdefault("task_id", task_id)
|
||||||
|
payload.setdefault("contract_version", CONTRACT_VERSION)
|
||||||
|
file_path = self.result_contract_dir / f"{task_id}.json"
|
||||||
|
file_path.write_text(json.dumps(payload, ensure_ascii=False))
|
||||||
|
return file_path.relative_to(self.task_status_dir).as_posix()
|
||||||
|
|
||||||
def delete_task_status(self, task_id: str) -> None:
|
def delete_task_status(self, task_id: str) -> None:
|
||||||
(self.task_status_dir / f"{task_id}.json").unlink(missing_ok=True)
|
(self.task_status_dir / f"{task_id}.json").unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,3 +53,73 @@ def test_analysis_task_routes_smoke(monkeypatch):
|
||||||
assert any(task["task_id"] == "task-smoke" for task in tasks_response.json()["tasks"])
|
assert any(task["task_id"] == "task-smoke" for task in tasks_response.json()["tasks"])
|
||||||
assert status_response.status_code == 200
|
assert status_response.status_code == 200
|
||||||
assert status_response.json()["task_id"] == "task-smoke"
|
assert status_response.json()["task_id"] == "task-smoke"
|
||||||
|
|
||||||
|
|
||||||
|
def test_analysis_start_route_uses_analysis_service(monkeypatch):
|
||||||
|
monkeypatch.delenv("DASHBOARD_API_KEY", raising=False)
|
||||||
|
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||||
|
|
||||||
|
main = _load_main_module(monkeypatch)
|
||||||
|
created: dict[str, object] = {}
|
||||||
|
|
||||||
|
class DummyTask:
|
||||||
|
def cancel(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def fake_create_task(coro):
|
||||||
|
created["scheduled_coro"] = coro.cr_code.co_name
|
||||||
|
coro.close()
|
||||||
|
task = DummyTask()
|
||||||
|
created["task"] = task
|
||||||
|
return task
|
||||||
|
|
||||||
|
monkeypatch.setattr(main.asyncio, "create_task", fake_create_task)
|
||||||
|
|
||||||
|
with TestClient(main.app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/api/analysis/start",
|
||||||
|
json={"ticker": "AAPL", "date": "2026-04-11"},
|
||||||
|
headers={"api-key": "test-key"},
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = response.json()
|
||||||
|
task_id = payload["task_id"]
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert payload["ticker"] == "AAPL"
|
||||||
|
assert payload["date"] == "2026-04-11"
|
||||||
|
assert payload["status"] == "running"
|
||||||
|
assert created["scheduled_coro"] == "_run_analysis"
|
||||||
|
assert main.app.state.analysis_tasks[task_id] is created["task"]
|
||||||
|
assert main.app.state.task_results[task_id]["current_stage"] == "analysts"
|
||||||
|
assert main.app.state.task_results[task_id]["status"] == "running"
|
||||||
|
assert main.app.state.task_results[task_id]["request_id"]
|
||||||
|
assert main.app.state.task_results[task_id]["executor_type"] == "legacy_subprocess"
|
||||||
|
assert main.app.state.task_results[task_id]["result_ref"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_portfolio_analyze_route_uses_analysis_service_smoke(monkeypatch):
|
||||||
|
monkeypatch.delenv("DASHBOARD_API_KEY", raising=False)
|
||||||
|
monkeypatch.setenv("TRADINGAGENTS_USE_APPLICATION_SERVICES", "1")
|
||||||
|
monkeypatch.setenv("ANTHROPIC_API_KEY", "service-key")
|
||||||
|
|
||||||
|
main = _load_main_module(monkeypatch)
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
async def fake_start_portfolio_analysis(*, task_id, date, request_context, broadcast_progress):
|
||||||
|
captured["task_id"] = task_id
|
||||||
|
captured["date"] = date
|
||||||
|
captured["request_context"] = request_context
|
||||||
|
captured["broadcast_progress"] = broadcast_progress
|
||||||
|
return {"task_id": task_id, "status": "running", "total": 3}
|
||||||
|
|
||||||
|
with TestClient(main.app) as client:
|
||||||
|
monkeypatch.setattr(main.app.state.analysis_service, "start_portfolio_analysis", fake_start_portfolio_analysis)
|
||||||
|
response = client.post("/api/portfolio/analyze", headers={"api-key": "service-key"})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["status"] == "running"
|
||||||
|
assert str(captured["task_id"]).startswith("port_")
|
||||||
|
assert isinstance(captured["date"], str)
|
||||||
|
assert captured["request_context"].api_key == "service-key"
|
||||||
|
assert callable(captured["broadcast_progress"])
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,112 @@
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from services.executor import AnalysisExecutorError, LegacySubprocessAnalysisExecutor
|
||||||
|
from services.request_context import build_request_context
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeStdout:
|
||||||
|
def __init__(self, lines, *, stall: bool = False):
|
||||||
|
self._lines = list(lines)
|
||||||
|
self._stall = stall
|
||||||
|
|
||||||
|
async def readline(self):
|
||||||
|
if self._stall:
|
||||||
|
await asyncio.sleep(3600)
|
||||||
|
if self._lines:
|
||||||
|
return self._lines.pop(0)
|
||||||
|
return b""
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeStderr:
|
||||||
|
def __init__(self, payload: bytes = b""):
|
||||||
|
self._payload = payload
|
||||||
|
|
||||||
|
async def read(self):
|
||||||
|
return self._payload
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeProcess:
|
||||||
|
def __init__(self, stdout, *, stderr: bytes = b"", returncode=None):
|
||||||
|
self.stdout = stdout
|
||||||
|
self.stderr = _FakeStderr(stderr)
|
||||||
|
self.returncode = returncode
|
||||||
|
self.kill_called = False
|
||||||
|
self.wait_called = False
|
||||||
|
|
||||||
|
async def wait(self):
|
||||||
|
self.wait_called = True
|
||||||
|
if self.returncode is None:
|
||||||
|
self.returncode = -9 if self.kill_called else 0
|
||||||
|
return self.returncode
|
||||||
|
|
||||||
|
def kill(self):
|
||||||
|
self.kill_called = True
|
||||||
|
self.returncode = -9
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_raises_when_required_markers_missing(monkeypatch):
|
||||||
|
process = _FakeProcess(
|
||||||
|
_FakeStdout(
|
||||||
|
[
|
||||||
|
b"STAGE:analysts\n",
|
||||||
|
b"STAGE:portfolio\n",
|
||||||
|
b"SIGNAL_DETAIL:{\"quant_signal\":\"BUY\",\"llm_signal\":\"BUY\",\"confidence\":0.8}\n",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
returncode=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fake_create_subprocess_exec(*args, **kwargs):
|
||||||
|
return process
|
||||||
|
|
||||||
|
monkeypatch.setattr(asyncio, "create_subprocess_exec", fake_create_subprocess_exec)
|
||||||
|
|
||||||
|
executor = LegacySubprocessAnalysisExecutor(
|
||||||
|
analysis_python=Path("/usr/bin/python3"),
|
||||||
|
repo_root=Path("."),
|
||||||
|
api_key_resolver=lambda: "env-key",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
with pytest.raises(AnalysisExecutorError, match="required markers: ANALYSIS_COMPLETE"):
|
||||||
|
await executor.execute(
|
||||||
|
task_id="task-1",
|
||||||
|
ticker="AAPL",
|
||||||
|
date="2026-04-13",
|
||||||
|
request_context=build_request_context(api_key="ctx-key"),
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_kills_subprocess_on_timeout(monkeypatch):
|
||||||
|
process = _FakeProcess(_FakeStdout([], stall=True))
|
||||||
|
|
||||||
|
async def fake_create_subprocess_exec(*args, **kwargs):
|
||||||
|
return process
|
||||||
|
|
||||||
|
monkeypatch.setattr(asyncio, "create_subprocess_exec", fake_create_subprocess_exec)
|
||||||
|
|
||||||
|
executor = LegacySubprocessAnalysisExecutor(
|
||||||
|
analysis_python=Path("/usr/bin/python3"),
|
||||||
|
repo_root=Path("."),
|
||||||
|
api_key_resolver=lambda: "env-key",
|
||||||
|
stdout_timeout_secs=0.01,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
with pytest.raises(AnalysisExecutorError, match="timed out"):
|
||||||
|
await executor.execute(
|
||||||
|
task_id="task-2",
|
||||||
|
ticker="AAPL",
|
||||||
|
date="2026-04-13",
|
||||||
|
request_context=build_request_context(api_key="ctx-key"),
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
assert process.kill_called is True
|
||||||
|
assert process.wait_called is True
|
||||||
|
|
@ -2,6 +2,7 @@ import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from services.analysis_service import AnalysisService
|
from services.analysis_service import AnalysisService
|
||||||
|
from services.executor import AnalysisExecutionOutput
|
||||||
from services.job_service import JobService
|
from services.job_service import JobService
|
||||||
from services.migration_flags import load_migration_flags
|
from services.migration_flags import load_migration_flags
|
||||||
from services.request_context import build_request_context
|
from services.request_context import build_request_context
|
||||||
|
|
@ -48,6 +49,8 @@ def test_build_request_context_defaults():
|
||||||
|
|
||||||
assert context.api_key == "secret"
|
assert context.api_key == "secret"
|
||||||
assert context.request_id
|
assert context.request_id
|
||||||
|
assert context.contract_version == "v1alpha1"
|
||||||
|
assert context.executor_type == "legacy_subprocess"
|
||||||
assert context.metadata == {"source": "test"}
|
assert context.metadata == {"source": "test"}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -64,6 +67,23 @@ def test_result_store_round_trip(tmp_path):
|
||||||
assert positions == [{"ticker": "AAPL", "account": "模拟账户"}]
|
assert positions == [{"ticker": "AAPL", "account": "模拟账户"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_store_saves_result_contract(tmp_path):
|
||||||
|
gateway = DummyPortfolioGateway()
|
||||||
|
store = ResultStore(tmp_path / "task_status", gateway)
|
||||||
|
|
||||||
|
result_ref = store.save_result_contract(
|
||||||
|
"task-2",
|
||||||
|
{"status": "completed", "result": {"decision": "BUY"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
saved = json.loads((tmp_path / "task_status" / result_ref).read_text())
|
||||||
|
|
||||||
|
assert result_ref == "result_contracts/task-2.json"
|
||||||
|
assert saved["task_id"] == "task-2"
|
||||||
|
assert saved["contract_version"] == "v1alpha1"
|
||||||
|
assert saved["result"]["decision"] == "BUY"
|
||||||
|
|
||||||
|
|
||||||
def test_job_service_create_and_fail_job():
|
def test_job_service_create_and_fail_job():
|
||||||
task_results = {}
|
task_results = {}
|
||||||
analysis_tasks = {}
|
analysis_tasks = {}
|
||||||
|
|
@ -78,15 +98,48 @@ def test_job_service_create_and_fail_job():
|
||||||
delete_task=lambda task_id: persisted.pop(task_id, None),
|
delete_task=lambda task_id: persisted.pop(task_id, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
state = service.create_portfolio_job(task_id="port_1", total=2)
|
state = service.create_portfolio_job(
|
||||||
|
task_id="port_1",
|
||||||
|
total=2,
|
||||||
|
request_id="req-1",
|
||||||
|
executor_type="analysis_executor",
|
||||||
|
)
|
||||||
assert state["total"] == 2
|
assert state["total"] == 2
|
||||||
assert processes["port_1"] is None
|
assert processes["port_1"] is None
|
||||||
|
assert state["request_id"] == "req-1"
|
||||||
|
assert state["executor_type"] == "analysis_executor"
|
||||||
|
assert state["contract_version"] == "v1alpha1"
|
||||||
|
assert state["result_ref"] is None
|
||||||
|
|
||||||
|
attached = service.attach_result_contract(
|
||||||
|
"port_1",
|
||||||
|
result_ref="result_contracts/port_1.json",
|
||||||
|
)
|
||||||
|
assert attached["result_ref"] == "result_contracts/port_1.json"
|
||||||
|
|
||||||
failed = service.fail_job("port_1", "boom")
|
failed = service.fail_job("port_1", "boom")
|
||||||
assert failed["status"] == "failed"
|
assert failed["status"] == "failed"
|
||||||
assert persisted["port_1"]["error"] == "boom"
|
assert persisted["port_1"]["error"] == "boom"
|
||||||
|
|
||||||
|
|
||||||
|
def test_job_service_restores_legacy_tasks_with_contract_metadata():
|
||||||
|
service = JobService(
|
||||||
|
task_results={},
|
||||||
|
analysis_tasks={},
|
||||||
|
processes={},
|
||||||
|
persist_task=lambda task_id, data: None,
|
||||||
|
delete_task=lambda task_id: None,
|
||||||
|
)
|
||||||
|
|
||||||
|
service.restore_task_results({"legacy-task": {"task_id": "legacy-task", "status": "running"}})
|
||||||
|
|
||||||
|
restored = service.task_results["legacy-task"]
|
||||||
|
assert restored["request_id"] == "legacy-task"
|
||||||
|
assert restored["executor_type"] == "legacy_subprocess"
|
||||||
|
assert restored["contract_version"] == "v1alpha1"
|
||||||
|
assert restored["result_ref"] is None
|
||||||
|
|
||||||
|
|
||||||
def test_analysis_service_build_recommendation_record():
|
def test_analysis_service_build_recommendation_record():
|
||||||
rec = AnalysisService._build_recommendation_record(
|
rec = AnalysisService._build_recommendation_record(
|
||||||
stdout='\n'.join([
|
stdout='\n'.join([
|
||||||
|
|
@ -103,3 +156,74 @@ def test_analysis_service_build_recommendation_record():
|
||||||
assert rec["quant_signal"] == "BUY"
|
assert rec["quant_signal"] == "BUY"
|
||||||
assert rec["llm_signal"] == "HOLD"
|
assert rec["llm_signal"] == "HOLD"
|
||||||
assert rec["confidence"] == 0.75
|
assert rec["confidence"] == 0.75
|
||||||
|
|
||||||
|
|
||||||
|
class FakeExecutor:
|
||||||
|
async def execute(self, *, task_id, ticker, date, request_context, on_stage=None):
|
||||||
|
if on_stage is not None:
|
||||||
|
await on_stage("analysts")
|
||||||
|
await on_stage("research")
|
||||||
|
await on_stage("trading")
|
||||||
|
await on_stage("risk")
|
||||||
|
await on_stage("portfolio")
|
||||||
|
return AnalysisExecutionOutput(
|
||||||
|
decision="BUY",
|
||||||
|
quant_signal="OVERWEIGHT",
|
||||||
|
llm_signal="BUY",
|
||||||
|
confidence=0.82,
|
||||||
|
report_path=f"results/{ticker}/{date}/complete_report.md",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_analysis_service_start_analysis_uses_executor(tmp_path):
|
||||||
|
gateway = DummyPortfolioGateway()
|
||||||
|
store = ResultStore(tmp_path / "task_status", gateway)
|
||||||
|
task_results = {}
|
||||||
|
analysis_tasks = {}
|
||||||
|
processes = {}
|
||||||
|
service = JobService(
|
||||||
|
task_results=task_results,
|
||||||
|
analysis_tasks=analysis_tasks,
|
||||||
|
processes=processes,
|
||||||
|
persist_task=store.save_task_status,
|
||||||
|
delete_task=store.delete_task_status,
|
||||||
|
)
|
||||||
|
analysis_service = AnalysisService(
|
||||||
|
executor=FakeExecutor(),
|
||||||
|
result_store=store,
|
||||||
|
job_service=service,
|
||||||
|
)
|
||||||
|
broadcasts = []
|
||||||
|
|
||||||
|
async def _broadcast(task_id, payload):
|
||||||
|
broadcasts.append((task_id, payload["status"], payload.get("current_stage")))
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
response = await analysis_service.start_analysis(
|
||||||
|
task_id="task-1",
|
||||||
|
ticker="AAPL",
|
||||||
|
date="2026-04-13",
|
||||||
|
request_context=build_request_context(api_key="secret"),
|
||||||
|
broadcast_progress=_broadcast,
|
||||||
|
)
|
||||||
|
await analysis_tasks["task-1"]
|
||||||
|
return response
|
||||||
|
|
||||||
|
response = asyncio.run(scenario())
|
||||||
|
|
||||||
|
assert response == {
|
||||||
|
"contract_version": "v1alpha1",
|
||||||
|
"task_id": "task-1",
|
||||||
|
"ticker": "AAPL",
|
||||||
|
"date": "2026-04-13",
|
||||||
|
"status": "running",
|
||||||
|
}
|
||||||
|
assert task_results["task-1"]["status"] == "completed"
|
||||||
|
assert task_results["task-1"]["decision"] == "BUY"
|
||||||
|
assert task_results["task-1"]["result_ref"] == "result_contracts/task-1.json"
|
||||||
|
assert task_results["task-1"]["result"]["signals"]["llm"]["rating"] == "BUY"
|
||||||
|
saved_contract = json.loads((tmp_path / "task_status" / "result_contracts" / "task-1.json").read_text())
|
||||||
|
assert saved_contract["status"] == "completed"
|
||||||
|
assert saved_contract["result"]["signals"]["merged"]["rating"] == "BUY"
|
||||||
|
assert broadcasts[0] == ("task-1", "running", "analysts")
|
||||||
|
assert broadcasts[-1][1] == "completed"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue