Harden security, fix memory leak, clean up deps
- Add API key auth (AGENTS_API_KEY env var) on /analyze endpoints - Add CORS_ORIGINS env var instead of hardcoded wildcard - Add memory cleanup (30min TTL) and concurrency semaphore (max 3) - Add 10-minute analysis timeout - Fix ticker validation (alphanumeric check) - Remove unused deps (redis, backtrader, parsel, rich, typer, questionary) - Fix pyproject.toml: replace chainlit with actual FastAPI deps - Add .dockerignore, add eval_results/ to .gitignore Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ba39a81e82
commit
3ac1c5ad3d
|
|
@ -0,0 +1,8 @@
|
||||||
|
.git
|
||||||
|
eval_results
|
||||||
|
*.txt
|
||||||
|
docs
|
||||||
|
uv.lock
|
||||||
|
__pycache__
|
||||||
|
.env
|
||||||
|
.env.example
|
||||||
|
|
@ -217,3 +217,4 @@ __marimo__/
|
||||||
|
|
||||||
# Cache
|
# Cache
|
||||||
**/data_cache/
|
**/data_cache/
|
||||||
|
eval_results/
|
||||||
|
|
|
||||||
91
app.py
91
app.py
|
|
@ -5,9 +5,10 @@ import time
|
||||||
import uuid
|
import uuid
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import traceback as _tb
|
||||||
from datetime import date
|
from datetime import date
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException, Request, Depends
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
@ -22,14 +23,34 @@ from cli.main import (
|
||||||
)
|
)
|
||||||
|
|
||||||
app = FastAPI(title="TradingAgents API")
|
app = FastAPI(title="TradingAgents API")
|
||||||
|
|
||||||
|
# --- CORS ---
|
||||||
|
_cors_env = os.getenv("CORS_ORIGINS", "")
|
||||||
|
_cors_origins = [o.strip() for o in _cors_env.split(",") if o.strip()] if _cors_env else ["*"]
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"],
|
allow_origins=_cors_origins,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Active analysis state: id -> {queue, events (replay buffer), done}
|
# --- Auth dependency ---
|
||||||
|
_API_KEY = os.getenv("AGENTS_API_KEY", "")
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_api_key(request: Request):
|
||||||
|
if not _API_KEY:
|
||||||
|
return # dev mode — no auth
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
if auth != f"Bearer {_API_KEY}":
|
||||||
|
raise HTTPException(401, "Invalid or missing API key")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Concurrency ---
|
||||||
|
MAX_CONCURRENT = int(os.getenv("MAX_CONCURRENT_ANALYSES", "3"))
|
||||||
|
_semaphore = asyncio.Semaphore(MAX_CONCURRENT)
|
||||||
|
|
||||||
|
# Active analysis state: id -> {queue, events (replay buffer), done, created_at}
|
||||||
analyses: dict[str, dict] = {}
|
analyses: dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -90,9 +111,8 @@ def _agent_stage(agent_name):
|
||||||
return "unknown"
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
async def _run_analysis_inner(analysis_id: str, ticker: str, trade_date: str):
|
||||||
"""Background task that runs the TradingAgents pipeline and pushes SSE events."""
|
"""Core analysis logic."""
|
||||||
import traceback as _tb
|
|
||||||
state = analyses[analysis_id]
|
state = analyses[analysis_id]
|
||||||
q = state["queue"]
|
q = state["queue"]
|
||||||
config = build_config()
|
config = build_config()
|
||||||
|
|
@ -205,8 +225,7 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
state["events"].append(evt)
|
state["events"].append(evt)
|
||||||
await q.put(evt)
|
await q.put(evt)
|
||||||
|
|
||||||
# Research debate (guard with research_emitted to avoid resetting
|
# Research debate
|
||||||
# statuses on subsequent chunks in stream_mode="values")
|
|
||||||
if chunk.get("investment_debate_state") and not research_emitted:
|
if chunk.get("investment_debate_state") and not research_emitted:
|
||||||
debate = chunk["investment_debate_state"]
|
debate = chunk["investment_debate_state"]
|
||||||
bull = debate.get("bull_history", "").strip()
|
bull = debate.get("bull_history", "").strip()
|
||||||
|
|
@ -251,8 +270,7 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
state["events"].append(evt)
|
state["events"].append(evt)
|
||||||
await q.put(evt)
|
await q.put(evt)
|
||||||
|
|
||||||
# Risk debate (guard with risk_emitted to avoid resetting
|
# Risk debate
|
||||||
# statuses on subsequent chunks in stream_mode="values")
|
|
||||||
if chunk.get("risk_debate_state") and not risk_emitted:
|
if chunk.get("risk_debate_state") and not risk_emitted:
|
||||||
risk = chunk["risk_debate_state"]
|
risk = chunk["risk_debate_state"]
|
||||||
agg = risk.get("aggressive_history", "").strip()
|
agg = risk.get("aggressive_history", "").strip()
|
||||||
|
|
@ -302,7 +320,6 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
for agent in buf.agent_status:
|
for agent in buf.agent_status:
|
||||||
buf.update_agent_status(agent, "completed")
|
buf.update_agent_status(agent, "completed")
|
||||||
st = get_stats_dict(stats_handler, buf, start_time)
|
st = get_stats_dict(stats_handler, buf, start_time)
|
||||||
# Emit agent_update for any agents not yet shown as completed on the client
|
|
||||||
for agent, status in buf.agent_status.items():
|
for agent, status in buf.agent_status.items():
|
||||||
if prev_statuses.get(agent) != "completed":
|
if prev_statuses.get(agent) != "completed":
|
||||||
prev_statuses[agent] = "completed"
|
prev_statuses[agent] = "completed"
|
||||||
|
|
@ -329,19 +346,63 @@ async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
await q.put(None) # sentinel — stream done
|
await q.put(None) # sentinel — stream done
|
||||||
|
|
||||||
|
|
||||||
@app.post("/analyze")
|
async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
|
||||||
|
"""Background task: acquires semaphore, runs analysis with timeout."""
|
||||||
|
state = analyses[analysis_id]
|
||||||
|
q = state["queue"]
|
||||||
|
async with _semaphore:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
_run_analysis_inner(analysis_id, ticker, trade_date),
|
||||||
|
timeout=600, # 10 minutes
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print(f"[ANALYSIS] Timeout for {analysis_id}", flush=True)
|
||||||
|
evt = {"type": "error", "message": "Analysis timed out after 10 minutes"}
|
||||||
|
state["events"].append(evt)
|
||||||
|
await q.put(evt)
|
||||||
|
state["done"] = True
|
||||||
|
await q.put(None)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Memory cleanup background task ---
|
||||||
|
async def _cleanup_loop():
|
||||||
|
"""Remove analyses older than 30 minutes every 5 minutes."""
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(300)
|
||||||
|
now = time.time()
|
||||||
|
expired = [aid for aid, s in analyses.items() if now - s["created_at"] > 1800]
|
||||||
|
for aid in expired:
|
||||||
|
analyses.pop(aid, None)
|
||||||
|
if expired:
|
||||||
|
print(f"[CLEANUP] Removed {len(expired)} expired analyses", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def _start_cleanup():
|
||||||
|
asyncio.create_task(_cleanup_loop())
|
||||||
|
|
||||||
|
|
||||||
|
# --- Routes ---
|
||||||
|
|
||||||
|
@app.post("/analyze", dependencies=[Depends(verify_api_key)])
|
||||||
async def start_analysis(req: AnalyzeRequest):
|
async def start_analysis(req: AnalyzeRequest):
|
||||||
ticker = req.ticker.upper().strip()
|
ticker = req.ticker.upper().strip()
|
||||||
if not ticker or len(ticker) > 5:
|
if not ticker or len(ticker) > 5 or not ticker.isalpha():
|
||||||
raise HTTPException(400, "Invalid ticker")
|
raise HTTPException(400, "Invalid ticker")
|
||||||
trade_date = req.date or str(date.today())
|
trade_date = req.date or str(date.today())
|
||||||
analysis_id = str(uuid.uuid4())
|
analysis_id = str(uuid.uuid4())
|
||||||
analyses[analysis_id] = {"queue": asyncio.Queue(), "events": [], "done": False}
|
analyses[analysis_id] = {
|
||||||
|
"queue": asyncio.Queue(),
|
||||||
|
"events": [],
|
||||||
|
"done": False,
|
||||||
|
"created_at": time.time(),
|
||||||
|
}
|
||||||
asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
|
asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
|
||||||
return {"id": analysis_id, "ticker": ticker, "date": trade_date}
|
return {"id": analysis_id, "ticker": ticker, "date": trade_date}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/analyze/{analysis_id}/stream")
|
@app.get("/analyze/{analysis_id}/stream", dependencies=[Depends(verify_api_key)])
|
||||||
async def stream_analysis(analysis_id: str, last_event: int = 0):
|
async def stream_analysis(analysis_id: str, last_event: int = 0):
|
||||||
"""Stream SSE events. Supports reconnection via ?last_event=N to replay missed events."""
|
"""Stream SSE events. Supports reconnection via ?last_event=N to replay missed events."""
|
||||||
if analysis_id not in analyses:
|
if analysis_id not in analyses:
|
||||||
|
|
|
||||||
|
|
@ -10,27 +10,23 @@ readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"langchain-core>=0.3.81",
|
"langchain-core>=0.3.81",
|
||||||
"backtrader>=1.9.78.123",
|
|
||||||
"chainlit>=2.5.5",
|
|
||||||
"langchain-anthropic>=0.3.15",
|
"langchain-anthropic>=0.3.15",
|
||||||
"langchain-experimental>=0.3.4",
|
"langchain-experimental>=0.3.4",
|
||||||
"langchain-google-genai>=2.1.5",
|
"langchain-google-genai>=2.1.5",
|
||||||
"langchain-openai>=0.3.23",
|
"langchain-openai>=0.3.23",
|
||||||
"langgraph>=0.4.8",
|
"langgraph>=0.4.8",
|
||||||
"pandas>=2.3.0",
|
"pandas>=2.3.0",
|
||||||
"parsel>=1.10.0",
|
|
||||||
"pytz>=2025.2",
|
"pytz>=2025.2",
|
||||||
"questionary>=2.1.0",
|
|
||||||
"rank-bm25>=0.2.2",
|
"rank-bm25>=0.2.2",
|
||||||
"redis>=6.2.0",
|
|
||||||
"requests>=2.32.4",
|
"requests>=2.32.4",
|
||||||
"rich>=14.0.0",
|
|
||||||
"typer>=0.21.0",
|
|
||||||
"setuptools>=80.9.0",
|
"setuptools>=80.9.0",
|
||||||
"stockstats>=0.6.5",
|
"stockstats>=0.6.5",
|
||||||
"tqdm>=4.67.1",
|
"tqdm>=4.67.1",
|
||||||
"typing-extensions>=4.14.0",
|
"typing-extensions>=4.14.0",
|
||||||
"yfinance>=0.2.63",
|
"yfinance>=0.2.63",
|
||||||
|
"fastapi>=0.115.0",
|
||||||
|
"uvicorn[standard]>=0.30.0",
|
||||||
|
"sse-starlette>=2.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
|
||||||
|
|
@ -8,17 +8,11 @@ stockstats
|
||||||
langgraph
|
langgraph
|
||||||
rank-bm25
|
rank-bm25
|
||||||
setuptools
|
setuptools
|
||||||
backtrader
|
|
||||||
parsel
|
|
||||||
requests
|
requests
|
||||||
tqdm
|
tqdm
|
||||||
pytz
|
pytz
|
||||||
redis
|
|
||||||
fastapi
|
fastapi
|
||||||
uvicorn[standard]
|
uvicorn[standard]
|
||||||
sse-starlette
|
sse-starlette
|
||||||
rich
|
|
||||||
typer
|
|
||||||
questionary
|
|
||||||
langchain_anthropic
|
langchain_anthropic
|
||||||
langchain-google-genai
|
langchain-google-genai
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue