191 lines
6.0 KiB
Python
191 lines
6.0 KiB
Python
"""
|
|
FastAPI application entry point for TradingAgentsX Backend
|
|
"""
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
import logging
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
from collections import defaultdict
|
|
from datetime import datetime, timedelta
|
|
|
|
from backend.app.core.config import settings
|
|
from backend.app.core.cors import setup_cors
|
|
from backend.app.api.routes import router
|
|
|
|
# Configure logging
|
|
logging.basicConfig(
|
|
level=logging.INFO if settings.debug else logging.WARNING,
|
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|
"""Add security headers to all responses"""
|
|
|
|
async def dispatch(self, request: Request, call_next):
|
|
response = await call_next(request)
|
|
|
|
# Security headers
|
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
|
response.headers["X-Frame-Options"] = "DENY"
|
|
response.headers["X-XSS-Protection"] = "1; mode=block"
|
|
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
|
response.headers["Permissions-Policy"] = "camera=(), microphone=(), geolocation=()"
|
|
|
|
return response
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
|
"""Simple in-memory rate limiting middleware"""
|
|
|
|
def __init__(self, app, max_requests: int = 30, window_seconds: int = 60):
|
|
super().__init__(app)
|
|
self.max_requests = max_requests
|
|
self.window_seconds = window_seconds
|
|
self.requests: dict[str, list[float]] = defaultdict(list)
|
|
|
|
async def dispatch(self, request: Request, call_next):
|
|
# Skip rate limiting for health checks
|
|
if request.url.path == "/api/health":
|
|
return await call_next(request)
|
|
|
|
# Get client IP
|
|
client_ip = request.client.host if request.client else "unknown"
|
|
forwarded_for = request.headers.get("X-Forwarded-For")
|
|
if forwarded_for:
|
|
client_ip = forwarded_for.split(",")[0].strip()
|
|
|
|
# Clean old requests
|
|
now = time.time()
|
|
cutoff = now - self.window_seconds
|
|
self.requests[client_ip] = [
|
|
t for t in self.requests[client_ip] if t > cutoff
|
|
]
|
|
|
|
# Check rate limit
|
|
if len(self.requests[client_ip]) >= self.max_requests:
|
|
retry_after = int(self.window_seconds - (now - self.requests[client_ip][0]))
|
|
return JSONResponse(
|
|
status_code=429,
|
|
content={
|
|
"error": "Too many requests",
|
|
"message": f"Rate limit exceeded. Please wait {retry_after} seconds.",
|
|
"retry_after": retry_after,
|
|
},
|
|
headers={"Retry-After": str(retry_after)},
|
|
)
|
|
|
|
# Record this request
|
|
self.requests[client_ip].append(now)
|
|
|
|
# Process request
|
|
response = await call_next(request)
|
|
|
|
# Add rate limit headers
|
|
remaining = self.max_requests - len(self.requests[client_ip])
|
|
response.headers["X-RateLimit-Limit"] = str(self.max_requests)
|
|
response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
|
|
response.headers["X-RateLimit-Reset"] = str(int(now + self.window_seconds))
|
|
|
|
return response
|
|
|
|
|
|
class SensitiveDataFilter(logging.Filter):
|
|
"""Filter to mask API keys in log messages"""
|
|
|
|
SENSITIVE_PATTERNS = ["api_key", "apikey", "api-key", "token", "secret", "password"]
|
|
|
|
def filter(self, record):
|
|
if hasattr(record, 'msg') and isinstance(record.msg, str):
|
|
msg = record.msg
|
|
for pattern in self.SENSITIVE_PATTERNS:
|
|
if pattern.lower() in msg.lower():
|
|
# Mask the value after the pattern
|
|
import re
|
|
msg = re.sub(
|
|
rf'({pattern}["\']?\s*[=:]\s*["\']?)([^"\'\s,}}]+)',
|
|
r'\1**********',
|
|
msg,
|
|
flags=re.IGNORECASE
|
|
)
|
|
record.msg = msg
|
|
return True
|
|
|
|
|
|
# Add sensitive data filter to all loggers
|
|
for handler in logging.root.handlers:
|
|
handler.addFilter(SensitiveDataFilter())
|
|
|
|
|
|
# Create FastAPI application
|
|
app = FastAPI(
|
|
title=settings.app_name,
|
|
version=settings.app_version,
|
|
description="Multi-Agent LLM Financial Trading Framework - REST API",
|
|
docs_url="/docs",
|
|
redoc_url="/redoc",
|
|
)
|
|
|
|
# Add security middleware (order matters - added first, executed last)
|
|
app.add_middleware(SecurityHeadersMiddleware)
|
|
app.add_middleware(RateLimitMiddleware, max_requests=30, window_seconds=60)
|
|
|
|
# Setup CORS
|
|
setup_cors(app)
|
|
|
|
# Include API routes
|
|
app.include_router(router)
|
|
|
|
|
|
@app.get("/")
|
|
async def root():
|
|
"""Root endpoint"""
|
|
return {
|
|
"message": "Welcome to TradingAgentsX API",
|
|
"version": settings.app_version,
|
|
"docs": "/docs",
|
|
"health": "/api/health",
|
|
}
|
|
|
|
|
|
@app.exception_handler(Exception)
|
|
async def global_exception_handler(request, exc):
|
|
"""Global exception handler - masks sensitive data in errors"""
|
|
error_msg = str(exc)
|
|
|
|
# Mask any API keys that might be in the error message
|
|
import re
|
|
patterns = ["sk-", "sk-ant-", "xai-", "AIza"]
|
|
for pattern in patterns:
|
|
error_msg = re.sub(
|
|
rf'{pattern}[a-zA-Z0-9_-]+',
|
|
f'{pattern}**********',
|
|
error_msg
|
|
)
|
|
|
|
logger.error(f"Unhandled exception: {error_msg}", exc_info=True)
|
|
return JSONResponse(
|
|
status_code=500,
|
|
content={
|
|
"error": "Internal server error",
|
|
"detail": error_msg,
|
|
"type": type(exc).__name__,
|
|
},
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
|
|
uvicorn.run(
|
|
"app.main:app",
|
|
host="0.0.0.0",
|
|
port=8000,
|
|
reload=settings.debug,
|
|
log_level="info" if settings.debug else "warning",
|
|
)
|