""" FastAPI application entry point for TradingAgentsX Backend """ from fastapi import FastAPI, Request from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware import logging import sys import time from pathlib import Path from collections import defaultdict from datetime import datetime, timedelta from backend.app.core.config import settings from backend.app.core.cors import setup_cors from backend.app.api.routes import router from backend.app.api.auth import router as auth_router from backend.app.api.user import router as user_router # Configure logging logging.basicConfig( level=logging.INFO if settings.debug else logging.WARNING, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) class SecurityHeadersMiddleware(BaseHTTPMiddleware): """Add security headers to all responses""" async def dispatch(self, request: Request, call_next): response = await call_next(request) # Security headers response.headers["X-Content-Type-Options"] = "nosniff" response.headers["X-Frame-Options"] = "DENY" response.headers["X-XSS-Protection"] = "1; mode=block" response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" response.headers["Permissions-Policy"] = "camera=(), microphone=(), geolocation=()" return response class RateLimitMiddleware(BaseHTTPMiddleware): """Simple in-memory rate limiting middleware""" def __init__(self, app, max_requests: int = 30, window_seconds: int = 60): super().__init__(app) self.max_requests = max_requests self.window_seconds = window_seconds self.requests: dict[str, list[float]] = defaultdict(list) async def dispatch(self, request: Request, call_next): # Skip rate limiting for health checks if request.url.path == "/api/health": return await call_next(request) # Get client IP client_ip = request.client.host if request.client else "unknown" forwarded_for = request.headers.get("X-Forwarded-For") if forwarded_for: client_ip = forwarded_for.split(",")[0].strip() # Clean old requests now = time.time() cutoff = now - self.window_seconds self.requests[client_ip] = [ t for t in self.requests[client_ip] if t > cutoff ] # Check rate limit if len(self.requests[client_ip]) >= self.max_requests: retry_after = int(self.window_seconds - (now - self.requests[client_ip][0])) return JSONResponse( status_code=429, content={ "error": "Too many requests", "message": f"Rate limit exceeded. Please wait {retry_after} seconds.", "retry_after": retry_after, }, headers={"Retry-After": str(retry_after)}, ) # Record this request self.requests[client_ip].append(now) # Process request response = await call_next(request) # Add rate limit headers remaining = self.max_requests - len(self.requests[client_ip]) response.headers["X-RateLimit-Limit"] = str(self.max_requests) response.headers["X-RateLimit-Remaining"] = str(max(0, remaining)) response.headers["X-RateLimit-Reset"] = str(int(now + self.window_seconds)) return response class SensitiveDataFilter(logging.Filter): """Filter to mask API keys in log messages""" SENSITIVE_PATTERNS = ["api_key", "apikey", "api-key", "token", "secret", "password"] def filter(self, record): if hasattr(record, 'msg') and isinstance(record.msg, str): msg = record.msg for pattern in self.SENSITIVE_PATTERNS: if pattern.lower() in msg.lower(): # Mask the value after the pattern import re msg = re.sub( rf'({pattern}["\']?\s*[=:]\s*["\']?)([^"\'\s,}}]+)', r'\1**********', msg, flags=re.IGNORECASE ) record.msg = msg return True # Add sensitive data filter to all loggers for handler in logging.root.handlers: handler.addFilter(SensitiveDataFilter()) # Create FastAPI application app = FastAPI( title=settings.app_name, version=settings.app_version, description="Multi-Agent LLM Financial Trading Framework - REST API", docs_url="/docs", redoc_url="/redoc", ) # Add security middleware (order matters - added first, executed last) app.add_middleware(SecurityHeadersMiddleware) app.add_middleware(RateLimitMiddleware, max_requests=30, window_seconds=60) # Setup CORS setup_cors(app) # Include API routes app.include_router(router) app.include_router(auth_router) app.include_router(user_router) # Database initialization on startup @app.on_event("startup") async def startup_event(): """Initialize database on startup""" try: from backend.app.db import init_db, check_db_connection # Check if database is configured if await check_db_connection(): logger.info("Database connection successful") await init_db() logger.info("Database tables initialized") else: logger.warning("Database not configured or connection failed - running without database") except Exception as e: logger.warning(f"Database initialization failed: {e} - running without database") @app.get("/") async def root(): """Root endpoint - redirect to API documentation""" from fastapi.responses import RedirectResponse return RedirectResponse(url="/docs") @app.exception_handler(Exception) async def global_exception_handler(request: Request, exc: Exception): """Global exception handler - returns appropriate status codes and masks sensitive data""" error_msg = str(exc) # Mask any API keys that might be in the error message import re patterns = ["sk-", "sk-ant-", "xai-", "AIza"] for pattern in patterns: error_msg = re.sub( rf'{pattern}[a-zA-Z0-9_-]+', f'{pattern}**********', error_msg ) # Determine appropriate status code based on exception type error_type = type(exc).__name__ status_code = 500 if isinstance(exc, ValueError): status_code = 400 elif isinstance(exc, (ConnectionError, ConnectionRefusedError, TimeoutError)): status_code = 503 elif isinstance(exc, PermissionError): status_code = 403 elif isinstance(exc, FileNotFoundError): status_code = 404 logger.error( f"Unhandled exception on {request.method} {request.url.path}: " f"[{error_type}] {error_msg}", exc_info=True ) return JSONResponse( status_code=status_code, content={ "error": "Internal server error" if status_code == 500 else error_type, "detail": error_msg, "type": error_type, "path": str(request.url.path), }, ) if __name__ == "__main__": import uvicorn uvicorn.run( "app.main:app", host="0.0.0.0", port=8000, reload=settings.debug, log_level="info" if settings.debug else "warning", )