fix(dashboard): secure API key handling and add stage progress streaming
- Pass ANTHROPIC_API_KEY via env dict instead of CLI args (P1 security fix)
- Add monitor_subprocess() coroutine with fcntl non-blocking reads
- Inject STAGE markers (analysts/research/trading/risk/portfolio) into script stdout
- Update task stage state and broadcast WebSocket progress at each stage boundary
- Add asyncio.Event for monitor cancellation on task completion/cancel

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ddf34222e3
commit
30db76b9b1
|
|
@ -3,6 +3,7 @@ TradingAgents Web Dashboard Backend
|
|||
FastAPI REST API + WebSocket for real-time analysis progress
|
||||
"""
|
||||
import asyncio
|
||||
import fcntl
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
|
|
@ -130,17 +131,15 @@ async def screen_stocks(mode: str = Query("china_strict"), refresh: bool = Query
|
|||
# ============== Analysis Execution ==============
|
||||
|
||||
# Script template for subprocess-based analysis
|
||||
# ticker and date are passed as command-line args to avoid injection
|
||||
# api_key is passed via environment variable (not CLI) for security
|
||||
ANALYSIS_SCRIPT_TEMPLATE = """
|
||||
import sys
|
||||
import os
|
||||
ticker = sys.argv[1]
|
||||
date = sys.argv[2]
|
||||
repo_root = sys.argv[3]
|
||||
api_key = sys.argv[4]
|
||||
|
||||
sys.path.insert(0, repo_root)
|
||||
import os
|
||||
os.environ["ANTHROPIC_API_KEY"] = api_key
|
||||
os.environ["ANTHROPIC_BASE_URL"] = "https://api.minimaxi.com/anthropic"
|
||||
import py_mini_racer
|
||||
sys.modules["mini_racer"] = py_mini_racer
|
||||
|
|
@ -148,6 +147,8 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from pathlib import Path
|
||||
|
||||
print("STAGE:analysts", flush=True)
|
||||
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["llm_provider"] = "anthropic"
|
||||
config["deep_think_llm"] = "MiniMax-M2.7-highspeed"
|
||||
|
|
@ -156,9 +157,15 @@ config["backend_url"] = "https://api.minimaxi.com/anthropic"
|
|||
config["max_debate_rounds"] = 1
|
||||
config["max_risk_discuss_rounds"] = 1
|
||||
|
||||
print("STAGE:research", flush=True)
|
||||
|
||||
ta = TradingAgentsGraph(debug=False, config=config)
|
||||
print("STAGE:trading", flush=True)
|
||||
|
||||
final_state, decision = ta.propagate(ticker, date)
|
||||
|
||||
print("STAGE:risk", flush=True)
|
||||
|
||||
results_dir = Path(repo_root) / "results" / ticker / date
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
@ -178,7 +185,8 @@ report_content = (
|
|||
report_path = results_dir / "complete_report.md"
|
||||
report_path.write_text(report_content)
|
||||
|
||||
print("ANALYSIS_COMPLETE:" + signal)
|
||||
print("STAGE:portfolio", flush=True)
|
||||
print("ANALYSIS_COMPLETE:" + signal, flush=True)
|
||||
"""
|
||||
|
||||
|
||||
|
|
@ -225,6 +233,60 @@ async def start_analysis(request: AnalysisRequest):
|
|||
app.state.processes = getattr(app.state, 'processes', {})
|
||||
app.state.processes[task_id] = None
|
||||
|
||||
# Cancellation event for the monitor coroutine
|
||||
cancel_event = asyncio.Event()
|
||||
|
||||
# Stage name to index mapping
|
||||
STAGE_NAMES = ["analysts", "research", "trading", "risk", "portfolio"]
|
||||
|
||||
def _update_task_stage(stage_name: str):
    """Mark *stage_name* and all earlier stages completed, promote the next
    stage to running, and refresh the task's overall progress percentage.

    Unknown stage names are ignored, so a stray ``STAGE:`` line in the
    subprocess output cannot corrupt task state.
    """
    try:
        idx = STAGE_NAMES.index(stage_name)
    except ValueError:
        # Not a recognised stage marker -- ignore it.
        return
    task = app.state.task_results[task_id]
    stages = task["stages"]
    total = len(STAGE_NAMES)
    # Mark this stage and every earlier one completed. Stamp completed_at
    # only on the first transition so an already-completed stage keeps its
    # original timestamp if markers arrive out of order.
    for i in range(idx + 1):
        if stages[i]["status"] != "completed":
            stages[i]["status"] = "completed"
            stages[i]["completed_at"] = datetime.now().strftime("%H:%M:%S")
    # Promote the next stage from pending to running, if any remain.
    if idx + 1 < total and stages[idx + 1]["status"] == "pending":
        stages[idx + 1]["status"] = "running"
    # Progress is the fraction of stages completed so far.
    task["progress"] = int((idx + 1) / total * 100)
    task["current_stage"] = stage_name
|
||||
|
||||
async def monitor_subprocess(task_id: str, proc: asyncio.subprocess.Process, cancel_evt: asyncio.Event):
    """Poll the subprocess's stdout pipe for ``STAGE:<name>`` markers and
    broadcast task progress over WebSocket after each marker seen.

    Runs until *cancel_evt* is set or the process has exited.

    NOTE(review): asyncio ``StreamReader`` objects do not normally expose
    ``fileno()`` -- confirm that ``proc.stdout`` here is really an object
    backed by a raw pipe fd, otherwise the first line raises AttributeError.
    NOTE(review): reading the fd directly races with the caller's
    ``proc.communicate()``, which consumes the same pipe; bytes read here
    are not seen by ``communicate()`` and vice versa.
    """
    # Set stdout to non-blocking so os.read() returns immediately when the
    # pipe is empty instead of stalling the event loop.
    fd = proc.stdout.fileno()
    fl = fcntl.fcntl(fd, fcntl.GETFL)
    fcntl.fcntl(fd, fcntl.SETFL, fl | os.O_NONBLOCK)

    while not cancel_evt.is_set():
        if proc.returncode is not None:
            # Process already exited; leave any unread output to the caller.
            break
        await asyncio.sleep(5)  # poll interval: stage updates lag by up to 5s
        if cancel_evt.is_set():
            break
        try:
            # NOTE(review): a STAGE marker split across two 32 KiB reads would
            # be missed; markers are short, so unlikely but possible.
            chunk = os.read(fd, 32768)
            if chunk:
                for line in chunk.decode().splitlines():
                    if line.startswith("STAGE:"):
                        stage = line.split(":", 1)[1].strip()
                        _update_task_stage(stage)
                        await broadcast_progress(task_id, app.state.task_results[task_id])
        except (BlockingIOError, OSError):
            # No data available yet (EAGAIN) -- try again on the next cycle.
            pass
|
||||
|
||||
async def run_analysis():
|
||||
"""Run analysis subprocess and broadcast progress"""
|
||||
try:
|
||||
|
|
@ -240,15 +302,24 @@ async def start_analysis(request: AnalysisRequest):
|
|||
request.ticker,
|
||||
date,
|
||||
str(REPO_ROOT),
|
||||
api_key,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=clean_env,
|
||||
)
|
||||
app.state.processes[task_id] = proc
|
||||
|
||||
# Start monitor coroutine alongside subprocess
|
||||
monitor_task = asyncio.create_task(monitor_subprocess(task_id, proc, cancel_event))
|
||||
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
# Signal monitor to stop and wait for it
|
||||
cancel_event.set()
|
||||
try:
|
||||
await asyncio.wait_for(monitor_task, timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
monitor_task.cancel()
|
||||
|
||||
# Clean up script file
|
||||
try:
|
||||
script_path.unlink()
|
||||
|
|
@ -258,7 +329,7 @@ async def start_analysis(request: AnalysisRequest):
|
|||
if proc.returncode == 0:
|
||||
output = stdout.decode()
|
||||
decision = "HOLD"
|
||||
for line in output.split("\n"):
|
||||
for line in output.splitlines():
|
||||
if line.startswith("ANALYSIS_COMPLETE:"):
|
||||
decision = line.split(":", 1)[1].strip()
|
||||
|
||||
|
|
@ -268,13 +339,15 @@ async def start_analysis(request: AnalysisRequest):
|
|||
app.state.task_results[task_id]["current_stage"] = "portfolio"
|
||||
for i in range(5):
|
||||
app.state.task_results[task_id]["stages"][i]["status"] = "completed"
|
||||
app.state.task_results[task_id]["stages"][i]["completed_at"] = datetime.now().strftime("%H:%M:%S")
|
||||
if not app.state.task_results[task_id]["stages"][i].get("completed_at"):
|
||||
app.state.task_results[task_id]["stages"][i]["completed_at"] = datetime.now().strftime("%H:%M:%S")
|
||||
else:
|
||||
error_msg = stderr.decode()[-1000:] if stderr else "Unknown error"
|
||||
app.state.task_results[task_id]["status"] = "failed"
|
||||
app.state.task_results[task_id]["error"] = error_msg
|
||||
|
||||
except Exception as e:
|
||||
cancel_event.set()
|
||||
app.state.task_results[task_id]["status"] = "failed"
|
||||
app.state.task_results[task_id]["error"] = str(e)
|
||||
try:
|
||||
|
|
|
|||
Loading…
Reference in New Issue