# TradingAgents/app.py
"""FastAPI SSE backend for the structured equity ranking engine."""
from pathlib import Path
from dotenv import load_dotenv
load_dotenv(Path(__file__).parent / ".env")
import logging
import os
import re
import time
import uuid
import asyncio
import json
import traceback as _tb
from datetime import date
from fastapi import FastAPI, HTTPException, Request, Depends
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
logger = logging.getLogger(__name__)
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
# If using Groq (or other OpenAI-compatible), set OPENAI_API_KEY for langchain
if not os.environ.get("OPENAI_API_KEY"):
groq_key = os.environ.get("GROQ_API_KEY", "")
if groq_key:
os.environ["OPENAI_API_KEY"] = groq_key
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG

app = FastAPI(title="TradingAgents Structured Pipeline")

# Process start time, used by /api/status to report uptime.
_START_TIME = time.time()

# --- CORS ---
_cors_env = os.getenv("CORS_ORIGINS", "")
_cors_origins = [o.strip() for o in _cors_env.split(",") if o.strip()] if _cors_env else ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Auth ---
_API_KEY = os.getenv("AGENTS_API_KEY", "")


async def verify_api_key(request: Request):
    if not _API_KEY:
        return
    auth = request.headers.get("Authorization", "")
    if auth != f"Bearer {_API_KEY}":
        raise HTTPException(401, "Invalid or missing API key")

# --- Concurrency ---
MAX_CONCURRENT = int(os.getenv("MAX_CONCURRENT_ANALYSES", "3"))
_semaphore = asyncio.Semaphore(MAX_CONCURRENT)

# --- Event buffer cap ---
MAX_EVENTS_PER_ANALYSIS = 5000

analyses: dict[str, dict] = {}


def _append_event(state: dict, evt: dict):
    """Append an event to the analysis state, enforcing the buffer cap."""
    events = state["events"]
    events.append(evt)
    if len(events) > MAX_EVENTS_PER_ANALYSIS:
        # Drop oldest events, keep the last MAX_EVENTS_PER_ANALYSIS
        state["events"] = events[-MAX_EVENTS_PER_ANALYSIS:]

class AnalyzeRequest(BaseModel):
    ticker: str
    date: str | None = None

def build_config():
    """Build TradingAgents config from env vars."""
    config = DEFAULT_CONFIG.copy()
    config["llm_provider"] = os.getenv("LLM_PROVIDER", "openai")
    config["deep_think_llm"] = os.getenv("DEEP_THINK_MODEL", "deepseek-v3.1:671b-cloud")
    config["quick_think_llm"] = os.getenv("QUICK_THINK_MODEL", "deepseek-v3.1:671b-cloud")
    config["backend_url"] = os.getenv("LLM_BASE_URL", "https://ollama.com/v1")
    config["max_debate_rounds"] = 1
    config["max_risk_discuss_rounds"] = 1
    config["data_vendors"] = {
        "core_stock_apis": "yfinance",
        "technical_indicators": "yfinance",
        "fundamental_data": "yfinance",
        "news_data": "yfinance",
    }
    logger.info(
        "config_built provider=%s deep=%s quick=%s url=%s",
        config["llm_provider"], config["deep_think_llm"],
        config["quick_think_llm"], config["backend_url"],
    )
    return config
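
# Illustrative .env for local development (a sketch: the keys mirror the
# os.getenv() calls in this module, the model/URL values are the code defaults
# above, and the AGENTS_API_KEY / CORS_ORIGINS values are placeholders):
#
#   LLM_PROVIDER=openai
#   DEEP_THINK_MODEL=deepseek-v3.1:671b-cloud
#   QUICK_THINK_MODEL=deepseek-v3.1:671b-cloud
#   LLM_BASE_URL=https://ollama.com/v1
#   AGENTS_API_KEY=change-me
#   CORS_ORIGINS=http://localhost:3000
#   MAX_CONCURRENT_ANALYSES=3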

# ---------------------------------------------------------------------------
# Stage/agent mapping for SSE events
# ---------------------------------------------------------------------------
# Maps state field → (agent display name, pipeline stage)
FIELD_AGENT_MAP = {
    "validation": ("Validation", "validation"),
    "company_card": ("Company Card", "validation"),
    "macro": ("Macro Regime", "tier1"),
    "liquidity": ("Liquidity", "tier1"),
    "business_quality": ("Business Quality", "tier2"),
    "institutional_flow": ("Institutional Flow", "tier2"),
    "valuation": ("Valuation", "tier2"),
    "entry_timing": ("Entry Timing", "tier2"),
    "earnings_revisions": ("Earnings Revisions", "tier2"),
    "sector_rotation": ("Sector Rotation", "tier2"),
    "backlog": ("Backlog / Order Momentum", "tier2"),
    "crowding": ("Narrative Crowding", "tier2"),
    "archetype": ("Archetype", "scoring"),
    "master_score": ("Master Score", "scoring"),
    "theme_substitution": ("Theme Substitution", "portfolio"),
    "position_replacement": ("Position Replacement", "portfolio"),
    "bull_case": ("Bull Researcher", "debate"),
    "bear_case": ("Bear Researcher", "debate"),
    "debate": ("Debate Referee", "debate"),
    "risk": ("Risk / Invalidation", "decision"),
    "final_decision": ("Final Decision", "decision"),
}

ALL_AGENTS = [name for name, _ in FIELD_AGENT_MAP.values()]
ALL_STAGES = ["validation", "tier1", "tier2", "scoring", "portfolio", "debate", "decision"]

# ---------------------------------------------------------------------------
# Analysis runner
# ---------------------------------------------------------------------------
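# Event types pushed over SSE by the runner below (shapes are visible in the
# emitting code): "agent_update", "report", "debate", "score", "decision",
# "error", plus a "heartbeat" keep-alive emitted by the stream endpoint.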

async def _run_analysis_inner(analysis_id: str, ticker: str, trade_date: str):
    """Core analysis logic — streams structured pipeline state changes as SSE."""
    state = analyses[analysis_id]
    q = state["queue"]
    config = build_config()

    try:
        graph = TradingAgentsGraph(debug=False, config=config)
        logger.info(
            "analysis_init_ok deep_llm=%s quick_llm=%s analysis_id=%s",
            type(graph.deep_thinking_llm).__name__,
            type(graph.quick_thinking_llm).__name__,
            analysis_id,
        )
    except Exception as e:
        logger.error("analysis_init_failed analysis_id=%s error=%s\n%s", analysis_id, e, _tb.format_exc())
        await q.put({"type": "error", "message": f"Init failed: {e}"})
        await q.put(None)
        return

    init_state = graph._create_initial_state(ticker, trade_date)
    start_time = time.time()
    emitted_fields = set()
    prev_agent_statuses = {}
    final_state = None

    # Emit initial status: all agents pending
    for field, (agent_name, stage) in FIELD_AGENT_MAP.items():
        prev_agent_statuses[field] = "pending"
        evt = {
            "type": "agent_update",
            "agent": agent_name,
            "stage": stage,
            "status": "pending",
            "stats": _stats(start_time, emitted_fields),
        }
        _append_event(state, evt)
        await q.put(evt)

    try:
        async for chunk in graph.graph.astream(
            init_state,
            stream_mode="values",
            config={"recursion_limit": 25},
        ):
            final_state = chunk
            # Detect newly populated fields
            for field, (agent_name, stage) in FIELD_AGENT_MAP.items():
                if field in emitted_fields:
                    continue
                value = chunk.get(field)
                if value is None:
                    continue
                emitted_fields.add(field)
                st = _stats(start_time, emitted_fields)

                # Mark this agent completed
                prev_agent_statuses[field] = "completed"
                evt = {
                    "type": "agent_update",
                    "agent": agent_name,
                    "stage": stage,
                    "status": "completed",
                    "stats": st,
                }
                _append_event(state, evt)
                await q.put(evt)

                # Emit report data for key fields
                if field in ("validation", "company_card"):
                    evt = {
                        "type": "report",
                        "agent": agent_name,
                        "stage": stage,
                        "field": field,
                        "report": _format_report(field, value),
                        "stats": st,
                    }
                    _append_event(state, evt)
                    await q.put(evt)
                elif field == "debate":
                    bull = chunk.get("bull_case") or {}
                    bear = chunk.get("bear_case") or {}
                    evt = {
                        "type": "debate",
                        "stage": "debate",
                        "bull": bull.get("thesis", ""),
                        "bear": bear.get("thesis", ""),
                        "judge": (value or {}).get("reasoning", ""),
                        "winner": (value or {}).get("winner", ""),
                        "stats": st,
                    }
                    _append_event(state, evt)
                    await q.put(evt)
                elif field == "master_score":
                    evt = {
                        "type": "score",
                        "stage": "scoring",
                        "master_score": value,
                        "adjusted_score": chunk.get("adjusted_score"),
                        "position_role": chunk.get("position_role"),
                        "stats": st,
                    }
                    _append_event(state, evt)
                    await q.put(evt)

            # Mark in-progress agents for upcoming stages
            await _update_in_progress(chunk, emitted_fields, prev_agent_statuses, state, q, start_time)
    except Exception as e:
        logger.error("analysis_stream_error analysis_id=%s error=%s\n%s", analysis_id, e, _tb.format_exc())
        evt = {"type": "error", "message": str(e)}
        _append_event(state, evt)
        await q.put(evt)
        state["done"] = True
        await q.put(None)
        return

    # Final decision event
    if final_state:
        decision = final_state.get("final_decision") or {}
        st = _stats(start_time, emitted_fields)

        # Mark all remaining as completed
        for field in FIELD_AGENT_MAP:
            if prev_agent_statuses.get(field) != "completed":
                agent_name, stage = FIELD_AGENT_MAP[field]
                prev_agent_statuses[field] = "completed"
                evt = {
                    "type": "agent_update",
                    "agent": agent_name,
                    "stage": stage,
                    "status": "completed",
                    "stats": st,
                }
                _append_event(state, evt)
                await q.put(evt)

        evt = {
            "type": "decision",
            "stage": "decision",
            "signal": decision.get("action", "AVOID"),
            "decision_text": decision.get("narrative", ""),
            "master_score": final_state.get("master_score"),
            "adjusted_score": final_state.get("adjusted_score"),
            "position_role": final_state.get("position_role"),
            "final_decision": decision,
            "stats": st,
        }
        _append_event(state, evt)
        await q.put(evt)

    state["done"] = True
    await q.put(None)

async def _update_in_progress(chunk, emitted, statuses, state, q, start_time):
    """Heuristic: mark agents as in_progress based on stage progression."""
    # If validation is done, mark tier 1 as in_progress
    if "validation" in emitted:
        for field in ("macro", "liquidity"):
            if field not in emitted and statuses.get(field) == "pending":
                statuses[field] = "in_progress"
                agent_name, stage = FIELD_AGENT_MAP[field]
                evt = {
                    "type": "agent_update",
                    "agent": agent_name,
                    "stage": stage,
                    "status": "in_progress",
                    "stats": _stats(start_time, emitted),
                }
                _append_event(state, evt)
                await q.put(evt)

    # If tier 1 done, mark tier 2 in_progress
    if "macro" in emitted and "liquidity" in emitted:
        tier2_fields = [
            "business_quality", "institutional_flow", "valuation",
            "entry_timing", "earnings_revisions", "sector_rotation",
            "backlog", "crowding",
        ]
        for field in tier2_fields:
            if field not in emitted and statuses.get(field) == "pending":
                statuses[field] = "in_progress"
                agent_name, stage = FIELD_AGENT_MAP[field]
                evt = {
                    "type": "agent_update",
                    "agent": agent_name,
                    "stage": stage,
                    "status": "in_progress",
                    "stats": _stats(start_time, emitted),
                }
                _append_event(state, evt)
                await q.put(evt)

    # If scoring done, mark portfolio analysis in_progress
    if "master_score" in emitted:
        for field in ("theme_substitution", "position_replacement"):
            if field not in emitted and statuses.get(field) == "pending":
                statuses[field] = "in_progress"
                agent_name, stage = FIELD_AGENT_MAP[field]
                evt = {
                    "type": "agent_update",
                    "agent": agent_name,
                    "stage": stage,
                    "status": "in_progress",
                    "stats": _stats(start_time, emitted),
                }
                _append_event(state, evt)
                await q.put(evt)

def _stats(start_time: float, emitted_fields: set) -> dict:
    return {
        "agents_done": len(emitted_fields),
        "agents_total": len(FIELD_AGENT_MAP),
        "elapsed": round(time.time() - start_time, 1),
    }

def _format_report(field: str, value) -> str:
    """Format a state field value as a readable report string."""
    if isinstance(value, dict):
        if "summary_1_sentence" in value:
            return value["summary_1_sentence"]
        if "company_name" in value:
            return f"{value.get('company_name', '')} ({value.get('ticker', '')}) — {value.get('sector', '')} / {value.get('industry', '')}"
        return json.dumps(value, indent=2, default=str)[:500]
    return str(value)[:500]

async def run_analysis(analysis_id: str, ticker: str, trade_date: str):
    """Background task with semaphore and timeout."""
    state = analyses[analysis_id]
    q = state["queue"]
    async with _semaphore:
        try:
            await asyncio.wait_for(
                _run_analysis_inner(analysis_id, ticker, trade_date),
                timeout=3600,
            )
        except asyncio.TimeoutError:
            logger.warning("analysis_timeout analysis_id=%s", analysis_id)
            evt = {"type": "error", "message": "Analysis timed out after 60 minutes"}
            _append_event(state, evt)
            await q.put(evt)
            state["done"] = True
            await q.put(None)

# --- Cleanup ---
async def _cleanup_loop():
    while True:
        await asyncio.sleep(300)
        now = time.time()
        expired = [aid for aid, s in analyses.items() if now - s["created_at"] > 1800]
        for aid in expired:
            analyses.pop(aid, None)
        if expired:
            logger.info("cleanup_expired count=%d", len(expired))


@app.on_event("startup")
async def _start_cleanup():
    asyncio.create_task(_cleanup_loop())

# --- Routes ---
@app.post("/analyze", dependencies=[Depends(verify_api_key)])
async def start_analysis(req: AnalyzeRequest):
    ticker = req.ticker.upper().strip()
    if not ticker:
        raise HTTPException(400, "Ticker must not be empty")
    if len(ticker) > 10:
        raise HTTPException(400, f"Ticker too long ({len(ticker)} chars, max 10)")
    if not re.match(r'^[A-Z0-9.\-]{1,10}$', ticker):
        raise HTTPException(400, "Invalid ticker — only letters, digits, dots, and hyphens allowed")
    trade_date = req.date or str(date.today())
    analysis_id = str(uuid.uuid4())
    analyses[analysis_id] = {
        "queue": asyncio.Queue(),
        "events": [],
        "done": False,
        "created_at": time.time(),
    }
    asyncio.create_task(run_analysis(analysis_id, ticker, trade_date))
    return {"id": analysis_id, "ticker": ticker, "date": trade_date}
@app.get("/analyze/{analysis_id}/stream", dependencies=[Depends(verify_api_key)])
async def stream_analysis(analysis_id: str, last_event: int = 0):
"""Stream SSE events. Supports reconnection via ?last_event=N."""
if analysis_id not in analyses:
raise HTTPException(404, "Analysis not found")
state = analyses[analysis_id]
async def event_generator():
idx = last_event
while idx < len(state["events"]):
evt = state["events"][idx]
idx += 1
yield {"id": str(idx), "data": json.dumps(evt)}
if state["done"]:
return
q = state["queue"]
while True:
try:
event = await asyncio.wait_for(q.get(), timeout=15)
except asyncio.TimeoutError:
yield {"event": "heartbeat", "data": json.dumps({"type": "heartbeat"})}
continue
if event is None:
break
idx += 1
yield {"id": str(idx), "data": json.dumps(event)}
return EventSourceResponse(event_generator())
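
# Illustrative client flow for the two routes above (a sketch: the host, port,
# and example ticker are assumptions, not part of this module):
#
#   curl -X POST http://localhost:8000/analyze \
#        -H "Authorization: Bearer $AGENTS_API_KEY" \
#        -H "Content-Type: application/json" \
#        -d '{"ticker": "AAPL"}'
#   # -> {"id": "<analysis_id>", "ticker": "AAPL", "date": "YYYY-MM-DD"}
#
#   curl -N "http://localhost:8000/analyze/<analysis_id>/stream?last_event=0" \
#        -H "Authorization: Bearer $AGENTS_API_KEY"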
@app.get("/health")
async def health():
return {"status": "ok", "engine": "structured_pipeline"}
@app.get("/api/status")
async def get_status():
"""Structured pipeline status — no auth required."""
from datetime import datetime
active_count = len(analyses)
return {
"service": "structured-pipeline",
"engine": "TradingAgents",
"active_analyses": active_count,
"analyses": {k: {"created": v["created"], "done": v["done"]} for k, v in analyses.items()},
"pid": __import__("os").getpid(),
"uptime": time.time() - __import__("os").getpid(),
}
@app.get("/api/health")
async def api_health():
return {"status": "ok", "service": "structured-pipeline"}