"""FastAPI SSE backend for the structured equity ranking engine.""" from pathlib import Path from dotenv import load_dotenv load_dotenv(Path(__file__).parent / ".env") import logging import os import re import time import uuid import asyncio import json import traceback as _tb from datetime import date from fastapi import FastAPI, HTTPException, Request, Depends logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s", ) logger = logging.getLogger(__name__) from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from sse_starlette.sse import EventSourceResponse # If using Groq (or other OpenAI-compatible), set OPENAI_API_KEY for langchain if not os.environ.get("OPENAI_API_KEY"): groq_key = os.environ.get("GROQ_API_KEY", "") if groq_key: os.environ["OPENAI_API_KEY"] = groq_key from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG app = FastAPI(title="TradingAgents Structured Pipeline") # --- CORS --- _cors_env = os.getenv("CORS_ORIGINS", "") _cors_origins = [o.strip() for o in _cors_env.split(",") if o.strip()] if _cors_env else ["*"] app.add_middleware( CORSMiddleware, allow_origins=_cors_origins, allow_methods=["*"], allow_headers=["*"], ) # --- Auth --- _API_KEY = os.getenv("AGENTS_API_KEY", "") async def verify_api_key(request: Request): if not _API_KEY: return auth = request.headers.get("Authorization", "") if auth != f"Bearer {_API_KEY}": raise HTTPException(401, "Invalid or missing API key") # --- Concurrency --- MAX_CONCURRENT = int(os.getenv("MAX_CONCURRENT_ANALYSES", "3")) _semaphore = asyncio.Semaphore(MAX_CONCURRENT) # --- Event buffer cap --- MAX_EVENTS_PER_ANALYSIS = 5000 analyses: dict[str, dict] = {} def _append_event(state: dict, evt: dict): """Append an event to the analysis state, enforcing the buffer cap.""" events = state["events"] events.append(evt) if len(events) > MAX_EVENTS_PER_ANALYSIS: # Drop oldest events, keep the last MAX_EVENTS_PER_ANALYSIS state["events"] = events[-MAX_EVENTS_PER_ANALYSIS:] class AnalyzeRequest(BaseModel): ticker: str date: str | None = None def build_config(): """Build TradingAgents config from env vars.""" config = DEFAULT_CONFIG.copy() config["llm_provider"] = os.getenv("LLM_PROVIDER", "openai") config["deep_think_llm"] = os.getenv("DEEP_THINK_MODEL", "deepseek-v3.1:671b-cloud") config["quick_think_llm"] = os.getenv("QUICK_THINK_MODEL", "deepseek-v3.1:671b-cloud") config["backend_url"] = os.getenv("LLM_BASE_URL", "https://ollama.com/v1") config["max_debate_rounds"] = 1 config["max_risk_discuss_rounds"] = 1 config["data_vendors"] = { "core_stock_apis": "yfinance", "technical_indicators": "yfinance", "fundamental_data": "yfinance", "news_data": "yfinance", } logger.info( "config_built provider=%s deep=%s quick=%s url=%s", config['llm_provider'], config['deep_think_llm'], config['quick_think_llm'], config['backend_url'], ) return config # --------------------------------------------------------------------------- # Stage/agent mapping for SSE events # --------------------------------------------------------------------------- # Maps state field → (agent display name, pipeline stage) FIELD_AGENT_MAP = { "validation": ("Validation", "validation"), "company_card": ("Company Card", "validation"), "macro": ("Macro Regime", "tier1"), "liquidity": ("Liquidity", "tier1"), "business_quality": ("Business Quality", "tier2"), "institutional_flow": ("Institutional Flow", "tier2"), "valuation": ("Valuation", "tier2"), "entry_timing": ("Entry Timing", "tier2"), "earnings_revisions": ("Earnings Revisions", "tier2"), "sector_rotation": ("Sector Rotation", "tier2"), "backlog": ("Backlog / Order Momentum", "tier2"), "crowding": ("Narrative Crowding", "tier2"), "archetype": ("Archetype", "scoring"), "master_score": ("Master Score", "scoring"), "theme_substitution": ("Theme Substitution", "portfolio"), "position_replacement": ("Position Replacement", "portfolio"), "bull_case": ("Bull Researcher", "debate"), "bear_case": ("Bear Researcher", "debate"), "debate": ("Debate Referee", "debate"), "risk": ("Risk / Invalidation", "decision"), "final_decision": ("Final Decision", "decision"), } ALL_AGENTS = [name for name, _ in FIELD_AGENT_MAP.values()] ALL_STAGES = ["validation", "tier1", "tier2", "scoring", "portfolio", "debate", "decision"] # --------------------------------------------------------------------------- # Analysis runner # --------------------------------------------------------------------------- async def _run_analysis_inner(analysis_id: str, ticker: str, trade_date: str): """Core analysis logic — streams structured pipeline state changes as SSE.""" state = analyses[analysis_id] q = state["queue"] config = build_config() try: graph = TradingAgentsGraph(debug=False, config=config) logger.info( "analysis_init_ok deep_llm=%s quick_llm=%s analysis_id=%s", type(graph.deep_thinking_llm).__name__, type(graph.quick_thinking_llm).__name__, analysis_id, ) except Exception as e: logger.error("analysis_init_failed analysis_id=%s error=%s\n%s", analysis_id, e, _tb.format_exc()) await q.put({"type": "error", "message": f"Init failed: {e}"}) await q.put(None) return init_state = graph._create_initial_state(ticker, trade_date) start_time = time.time() emitted_fields = set() prev_agent_statuses = {} final_state = None # Emit initial status: all agents pending for field, (agent_name, stage) in FIELD_AGENT_MAP.items(): prev_agent_statuses[field] = "pending" evt = { "type": "agent_update", "agent": agent_name, "stage": stage, "status": "pending", "stats": _stats(start_time, emitted_fields), } _append_event(state, evt) await q.put(evt) try: async for chunk in graph.graph.astream( init_state, stream_mode="values", config={"recursion_limit": 25}, ): final_state = chunk # Detect newly populated fields for field, (agent_name, stage) in FIELD_AGENT_MAP.items(): if field in emitted_fields: continue value = chunk.get(field) if value is None: continue emitted_fields.add(field) st = _stats(start_time, emitted_fields) # Mark this agent completed prev_agent_statuses[field] = "completed" evt = { "type": "agent_update", "agent": agent_name, "stage": stage, "status": "completed", "stats": st, } _append_event(state, evt) await q.put(evt) # Emit report data for key fields if field in ("validation", "company_card"): evt = { "type": "report", "agent": agent_name, "stage": stage, "field": field, "report": _format_report(field, value), "stats": st, } _append_event(state, evt) await q.put(evt) elif field == "debate": bull = chunk.get("bull_case") or {} bear = chunk.get("bear_case") or {} evt = { "type": "debate", "stage": "debate", "bull": bull.get("thesis", ""), "bear": bear.get("thesis", ""), "judge": (value or {}).get("reasoning", ""), "winner": (value or {}).get("winner", ""), "stats": st, } _append_event(state, evt) await q.put(evt) elif field == "master_score": evt = { "type": "score", "stage": "scoring", "master_score": value, "adjusted_score": chunk.get("adjusted_score"), "position_role": chunk.get("position_role"), "stats": st, } _append_event(state, evt) await q.put(evt) # Mark in-progress agents for upcoming stages await _update_in_progress(chunk, emitted_fields, prev_agent_statuses, state, q, start_time) except Exception as e: logger.error("analysis_stream_error analysis_id=%s error=%s\n%s", analysis_id, e, _tb.format_exc()) evt = {"type": "error", "message": str(e)} _append_event(state, evt) await q.put(evt) state["done"] = True await q.put(None) return # Final decision event if final_state: decision = final_state.get("final_decision") or {} st = _stats(start_time, emitted_fields) # Mark all remaining as completed for field in FIELD_AGENT_MAP: if prev_agent_statuses.get(field) != "completed": agent_name, stage = FIELD_AGENT_MAP[field] prev_agent_statuses[field] = "completed" evt = { "type": "agent_update", "agent": agent_name, "stage": stage, "status": "completed", "stats": st, } _append_event(state, evt) await q.put(evt) evt = { "type": "decision", "stage": "decision", "signal": decision.get("action", "AVOID"), "decision_text": decision.get("narrative", ""), "master_score": final_state.get("master_score"), "adjusted_score": final_state.get("adjusted_score"), "position_role": final_state.get("position_role"), "final_decision": decision, "stats": st, } _append_event(state, evt) await q.put(evt) state["done"] = True await q.put(None) async def _update_in_progress(chunk, emitted, statuses, state, q, start_time): """Heuristic: mark agents as in_progress based on stage progression.""" # If validation is done, mark tier 1 as in_progress if "validation" in emitted: for field in ("macro", "liquidity"): if field not in emitted and statuses.get(field) == "pending": statuses[field] = "in_progress" agent_name, stage = FIELD_AGENT_MAP[field] evt = { "type": "agent_update", "agent": agent_name, "stage": stage, "status": "in_progress", "stats": _stats(start_time, emitted), } _append_event(state, evt) await q.put(evt) # If tier 1 done, mark tier 2 in_progress if "macro" in emitted and "liquidity" in emitted: tier2_fields = [ "business_quality", "institutional_flow", "valuation", "entry_timing", "earnings_revisions", "sector_rotation", "backlog", "crowding", ] for field in tier2_fields: if field not in emitted and statuses.get(field) == "pending": statuses[field] = "in_progress" agent_name, stage = FIELD_AGENT_MAP[field] evt = { "type": "agent_update", "agent": agent_name, "stage": stage, "status": "in_progress", "stats": _stats(start_time, emitted), } _append_event(state, evt) await q.put(evt) # If scoring done, mark portfolio analysis in_progress if "master_score" in emitted: for field in ("theme_substitution", "position_replacement"): if field not in emitted and statuses.get(field) == "pending": statuses[field] = "in_progress" agent_name, stage = FIELD_AGENT_MAP[field] evt = { "type": "agent_update", "agent": agent_name, "stage": stage, "status": "in_progress", "stats": _stats(start_time, emitted), } _append_event(state, evt) await q.put(evt) def _stats(start_time: float, emitted_fields: set) -> dict: return { "agents_done": len(emitted_fields), "agents_total": len(FIELD_AGENT_MAP), "elapsed": round(time.time() - start_time, 1), } def _format_report(field: str, value) -> str: """Format a state field value as a readable report string.""" if isinstance(value, dict): if "summary_1_sentence" in value: return value["summary_1_sentence"] if "company_name" in value: return f"{value.get('company_name', '')} ({value.get('ticker', '')}) — {value.get('sector', '')} / {value.get('industry', '')}" return json.dumps(value, indent=2, default=str)[:500] return str(value)[:500] async def run_analysis(analysis_id: str, ticker: str, trade_date: str): """Background task with semaphore and timeout.""" state = analyses[analysis_id] q = state["queue"] async with _semaphore: try: await asyncio.wait_for( _run_analysis_inner(analysis_id, ticker, trade_date), timeout=3600, ) except asyncio.TimeoutError: logger.warning("analysis_timeout analysis_id=%s", analysis_id) evt = {"type": "error", "message": "Analysis timed out after 60 minutes"} _append_event(state, evt) await q.put(evt) state["done"] = True await q.put(None) # --- Cleanup --- async def _cleanup_loop(): while True: await asyncio.sleep(300) now = time.time() expired = [aid for aid, s in analyses.items() if now - s["created_at"] > 1800] for aid in expired: analyses.pop(aid, None) if expired: logger.info("cleanup_expired count=%d", len(expired)) @app.on_event("startup") async def _start_cleanup(): asyncio.create_task(_cleanup_loop()) # --- Routes --- @app.post("/analyze", dependencies=[Depends(verify_api_key)]) async def start_analysis(req: AnalyzeRequest): ticker = req.ticker.upper().strip() if not ticker: raise HTTPException(400, "Ticker must not be empty") if len(ticker) > 10: raise HTTPException(400, f"Ticker too long ({len(ticker)} chars, max 10)") if not re.match(r'^[A-Z0-9.\-]{1,10}$', ticker): raise HTTPException(400, "Invalid ticker — only letters, digits, dots, and hyphens allowed") trade_date = req.date or str(date.today()) analysis_id = str(uuid.uuid4()) analyses[analysis_id] = { "queue": asyncio.Queue(), "events": [], "done": False, "created_at": time.time(), } asyncio.create_task(run_analysis(analysis_id, ticker, trade_date)) return {"id": analysis_id, "ticker": ticker, "date": trade_date} @app.get("/analyze/{analysis_id}/stream", dependencies=[Depends(verify_api_key)]) async def stream_analysis(analysis_id: str, last_event: int = 0): """Stream SSE events. Supports reconnection via ?last_event=N.""" if analysis_id not in analyses: raise HTTPException(404, "Analysis not found") state = analyses[analysis_id] async def event_generator(): idx = last_event while idx < len(state["events"]): evt = state["events"][idx] idx += 1 yield {"id": str(idx), "data": json.dumps(evt)} if state["done"]: return q = state["queue"] while True: try: event = await asyncio.wait_for(q.get(), timeout=15) except asyncio.TimeoutError: yield {"event": "heartbeat", "data": json.dumps({"type": "heartbeat"})} continue if event is None: break idx += 1 yield {"id": str(idx), "data": json.dumps(event)} return EventSourceResponse(event_generator()) @app.get("/health") async def health(): return {"status": "ok", "engine": "structured_pipeline"} @app.get("/api/status") async def get_status(): """Structured pipeline status — no auth required.""" from datetime import datetime active_count = len(analyses) return { "service": "structured-pipeline", "engine": "TradingAgents", "active_analyses": active_count, "analyses": {k: {"created": v["created"], "done": v["done"]} for k, v in analyses.items()}, "pid": __import__("os").getpid(), "uptime": time.time() - __import__("os").getpid(), } @app.get("/api/health") async def api_health(): return {"status": "ok", "service": "structured-pipeline"}