TradingAgents/tests/api/test_middleware.py

567 lines
20 KiB
Python

"""
Test suite for FastAPI middleware and error handling.
This module tests Issue #48 middleware features:
1. Error handling middleware
2. HTTP exceptions (400, 401, 404, 422, 500)
3. Request logging middleware
4. CORS middleware (if implemented)
5. Request ID tracking
6. Error response format consistency
Tests follow TDD - written before implementation.
"""
import pytest
from typing import Dict, Any
pytestmark = pytest.mark.asyncio
# ============================================================================
# Integration Tests: Error Handling Middleware
# ============================================================================
class TestErrorHandlingMiddleware:
"""Test global error handling and exception formatting."""
async def test_404_not_found_format(self, client):
"""Test 404 error has consistent format."""
# Act
response = await client.get("/api/v1/nonexistent")
# Assert
assert response.status_code == 404
data = response.json()
assert "detail" in data
assert isinstance(data["detail"], str)
async def test_422_validation_error_format(self, client, auth_headers):
"""Test 422 validation error has detailed format."""
# Arrange: Send invalid data
invalid_data = {
"name": 123, # Should be string
}
# Act
response = await client.post(
"/api/v1/strategies",
json=invalid_data,
headers=auth_headers,
)
# Assert
assert response.status_code == 422
data = response.json()
assert "detail" in data
# FastAPI validation errors include location and message
if isinstance(data["detail"], list):
assert len(data["detail"]) > 0
error = data["detail"][0]
assert "loc" in error or "msg" in error
async def test_401_unauthorized_format(self, client):
"""Test 401 unauthorized error format."""
# Act: Access protected endpoint without token
response = await client.get("/api/v1/strategies")
# Assert
assert response.status_code == 401
data = response.json()
assert "detail" in data
async def test_500_internal_error_handling(self, client, test_user, auth_headers):
"""Test that 500 errors are caught and formatted consistently."""
# This test requires an endpoint that can trigger 500 error
# Will need to be implemented based on actual error scenarios
# For now, test that if 500 occurs, it has proper format
# (Implementation may need mock or special test endpoint)
pass
async def test_error_response_includes_timestamp(self, client):
"""Test that error responses may include timestamp."""
# Act
response = await client.get("/api/v1/nonexistent")
# Assert
assert response.status_code == 404
data = response.json()
# Timestamp may be included for debugging
# assert "timestamp" in data or "detail" in data
async def test_error_response_no_stack_trace(self, client):
"""Test that error responses don't leak stack traces in production."""
# Act
response = await client.get("/api/v1/nonexistent")
# Assert
assert response.status_code == 404
data = response.json()
response_text = str(data).lower()
# Should not contain stack trace keywords
assert "traceback" not in response_text
assert "line " not in response_text # "line 123" from stack traces
assert ".py" not in response_text # File paths
async def test_error_response_content_type(self, client):
"""Test that error responses have correct Content-Type."""
# Act
response = await client.get("/api/v1/nonexistent")
# Assert
assert response.status_code == 404
assert "application/json" in response.headers.get("content-type", "")
# ============================================================================
# Unit Tests: Exception Handlers
# ============================================================================
class TestExceptionHandlers:
"""Test custom exception handlers."""
async def test_http_exception_handler(self, client):
"""Test HTTPException is handled correctly."""
# This would test custom HTTPException handler if implemented
# Act: Trigger HTTPException
response = await client.get("/api/v1/strategies/invalid")
# Assert: Should be handled gracefully
assert response.status_code in [400, 404, 422]
async def test_validation_exception_handler(self, client, auth_headers):
"""Test RequestValidationError handler."""
# Arrange: Send malformed JSON
# Act
response = await client.post(
"/api/v1/strategies",
data="not valid json",
headers={**auth_headers, "Content-Type": "application/json"},
)
# Assert
assert response.status_code == 422
async def test_generic_exception_handler(self, client):
"""Test that unexpected exceptions are caught."""
# This requires an endpoint that can raise unexpected exception
# or mock implementation
pass
# ============================================================================
# Integration Tests: Request Logging
# ============================================================================
class TestRequestLogging:
"""Test request and response logging middleware."""
async def test_request_logging_on_success(self, client, test_user, auth_headers, caplog):
"""Test that successful requests are logged."""
# Act
response = await client.get("/api/v1/strategies", headers=auth_headers)
# Assert
assert response.status_code == 200
# Check logs for request info (if logging middleware implemented)
# log_messages = [record.message for record in caplog.records]
# assert any("GET" in msg and "/api/v1/strategies" in msg for msg in log_messages)
async def test_request_logging_on_error(self, client, caplog):
"""Test that failed requests are logged."""
# Act
response = await client.get("/api/v1/nonexistent")
# Assert
assert response.status_code == 404
# Errors should be logged
# log_messages = [record.message for record in caplog.records]
# assert any("404" in msg for msg in log_messages)
async def test_sensitive_data_not_logged(
self, client, test_user, test_user_data, caplog
):
"""Test that passwords/tokens are not logged."""
# Arrange
login_data = {
"username": test_user_data["username"],
"password": test_user_data["password"],
}
# Act
response = await client.post("/api/v1/auth/login", json=login_data)
# Assert: Password should not appear in logs
log_text = " ".join([record.message for record in caplog.records])
assert test_user_data["password"] not in log_text
if response.status_code == 200:
token = response.json().get("access_token", "")
# Token should not be fully logged (may log prefix)
if token:
assert token not in log_text
# ============================================================================
# Integration Tests: CORS Middleware
# ============================================================================
class TestCORSMiddleware:
"""Test CORS (Cross-Origin Resource Sharing) configuration."""
async def test_cors_preflight_request(self, client):
"""Test CORS preflight OPTIONS request."""
# Act
response = await client.options(
"/api/v1/strategies",
headers={
"Origin": "http://localhost:3000",
"Access-Control-Request-Method": "GET",
},
)
# Assert: May return 200 or 405 if CORS not configured
assert response.status_code in [200, 405]
# If CORS is configured, check headers
if response.status_code == 200:
assert "access-control-allow-origin" in [
h.lower() for h in response.headers.keys()
]
async def test_cors_headers_on_response(self, client, test_user, auth_headers):
"""Test that CORS headers are present on API responses."""
# Act
response = await client.get(
"/api/v1/strategies",
headers={**auth_headers, "Origin": "http://localhost:3000"},
)
# Assert
assert response.status_code == 200
# CORS headers may be present if configured
# assert "access-control-allow-origin" in [h.lower() for h in response.headers.keys()]
async def test_cors_credentials_allowed(self, client):
"""Test CORS credentials configuration."""
# This tests if cookies/credentials are allowed
# May not be applicable if using JWT bearer tokens only
pass
# ============================================================================
# Integration Tests: Request ID Tracking
# ============================================================================
class TestRequestIDTracking:
"""Test request ID generation and tracking."""
async def test_request_id_in_response_headers(self, client):
"""Test that responses include request ID header."""
# Act
response = await client.get("/api/v1/strategies")
# Assert: May include X-Request-ID header
# request_id = response.headers.get("X-Request-ID")
# if request_id:
# assert len(request_id) > 0
async def test_request_id_propagation(self, client):
"""Test that request ID from client is preserved."""
# Arrange
client_request_id = "client-req-123"
# Act
response = await client.get(
"/api/v1/strategies",
headers={"X-Request-ID": client_request_id},
)
# Assert: Server may preserve client's request ID
# response_request_id = response.headers.get("X-Request-ID")
# assert response_request_id == client_request_id
# ============================================================================
# Integration Tests: Rate Limiting
# ============================================================================
class TestRateLimiting:
"""Test rate limiting middleware (if implemented)."""
async def test_rate_limit_not_exceeded(self, client, test_user, auth_headers):
"""Test normal request rate is allowed."""
# Act: Make reasonable number of requests
for _ in range(5):
response = await client.get("/api/v1/strategies", headers=auth_headers)
assert response.status_code == 200
async def test_rate_limit_headers(self, client, test_user, auth_headers):
"""Test that rate limit headers are included."""
# Act
response = await client.get("/api/v1/strategies", headers=auth_headers)
# Assert: May include rate limit headers
# assert "X-RateLimit-Limit" in response.headers
# assert "X-RateLimit-Remaining" in response.headers
async def test_rate_limit_exceeded(self, client, test_user_data):
"""Test that excessive requests are rate limited."""
# Arrange: Login endpoint is good for rate limit testing
login_data = {
"username": test_user_data["username"],
"password": "wrong_password",
}
# Act: Make many rapid requests
responses = []
for _ in range(50):
response = await client.post("/api/v1/auth/login", json=login_data)
responses.append(response)
# Assert: Should eventually get rate limited (429)
status_codes = [r.status_code for r in responses]
# May include 429 Too Many Requests if rate limiting implemented
# assert 429 in status_codes or all(code == 401 for code in status_codes)
# ============================================================================
# Integration Tests: Content Negotiation
# ============================================================================
class TestContentNegotiation:
"""Test content type handling."""
async def test_json_content_type_accepted(self, client, test_user, auth_headers):
"""Test that application/json is accepted."""
# Arrange
strategy_data = {
"name": "Test Strategy",
"description": "Test",
}
# Act
response = await client.post(
"/api/v1/strategies",
json=strategy_data,
headers={**auth_headers, "Content-Type": "application/json"},
)
# Assert
assert response.status_code == 201
async def test_json_response_content_type(self, client, test_user, auth_headers):
"""Test that responses have JSON content type."""
# Act
response = await client.get("/api/v1/strategies", headers=auth_headers)
# Assert
assert response.status_code == 200
assert "application/json" in response.headers.get("content-type", "")
async def test_unsupported_content_type_rejected(self, client, auth_headers):
"""Test that unsupported content types are rejected."""
# Act: Send XML instead of JSON
response = await client.post(
"/api/v1/strategies",
data="<xml>data</xml>",
headers={**auth_headers, "Content-Type": "application/xml"},
)
# Assert: Should reject (415 Unsupported Media Type or 422)
assert response.status_code in [415, 422]
# ============================================================================
# Edge Cases: Middleware
# ============================================================================
class TestMiddlewareEdgeCases:
"""Test edge cases in middleware handling."""
async def test_very_large_request_body(self, client, test_user, auth_headers):
"""Test handling of very large request bodies."""
# Arrange: Create 1MB JSON
large_params = {"key": "x" * 1_000_000}
strategy_data = {
"name": "Large Body Test",
"description": "Testing large request",
"parameters": large_params,
}
# Act
response = await client.post(
"/api/v1/strategies",
json=strategy_data,
headers=auth_headers,
)
# Assert: Should either accept or reject gracefully
assert response.status_code in [201, 413, 422] # 413 = Payload Too Large
async def test_malformed_json_request(self, client, auth_headers):
"""Test handling of malformed JSON."""
# Act
response = await client.post(
"/api/v1/strategies",
data='{"name": "test", invalid json}',
headers={**auth_headers, "Content-Type": "application/json"},
)
# Assert
assert response.status_code == 422
async def test_empty_request_body(self, client, auth_headers):
"""Test handling of empty request body."""
# Act
response = await client.post(
"/api/v1/strategies",
data="",
headers={**auth_headers, "Content-Type": "application/json"},
)
# Assert
assert response.status_code == 422
async def test_null_request_body(self, client, auth_headers):
"""Test handling of null JSON body."""
# Act
response = await client.post(
"/api/v1/strategies",
json=None,
headers=auth_headers,
)
# Assert
assert response.status_code == 422
async def test_concurrent_requests_different_users(self, client, db_session):
"""Test middleware handles concurrent requests correctly."""
# Arrange
import asyncio
try:
from tradingagents.api.services.auth_service import create_access_token
user1_headers = {
"Authorization": f"Bearer {create_access_token({'sub': 'user1'})}"
}
user2_headers = {
"Authorization": f"Bearer {create_access_token({'sub': 'user2'})}"
}
# Act: Make concurrent requests for different users
tasks = [
client.get("/api/v1/strategies", headers=user1_headers),
client.get("/api/v1/strategies", headers=user2_headers),
client.get("/api/v1/strategies", headers=user1_headers),
client.get("/api/v1/strategies", headers=user2_headers),
]
responses = await asyncio.gather(*tasks, return_exceptions=True)
# Assert: All should complete without mixing user contexts
# (This tests request context isolation)
assert len(responses) == 4
except ImportError:
pytest.skip("Auth service not implemented yet")
async def test_special_characters_in_url(self, client, test_user, auth_headers):
"""Test URL encoding and special characters."""
# Act: Try various special characters in URL
special_chars = ["%20", "%2F", "..%2F", "%00"]
for char in special_chars:
response = await client.get(
f"/api/v1/strategies/{char}",
headers=auth_headers,
)
# Assert: Should handle gracefully (not crash)
assert response.status_code in [400, 404, 422]
async def test_very_long_url_path(self, client, test_user, auth_headers):
"""Test handling of very long URL paths."""
# Arrange
long_path = "a" * 10000
# Act
response = await client.get(
f"/api/v1/strategies/{long_path}",
headers=auth_headers,
)
# Assert: Should reject gracefully
assert response.status_code in [400, 404, 414, 422] # 414 = URI Too Long
async def test_header_injection_prevention(self, client):
"""Test that header injection is prevented."""
# Arrange: Try to inject headers via CRLF
malicious_header = "Bearer token\r\nX-Injected: malicious"
# Act
response = await client.get(
"/api/v1/strategies",
headers={"Authorization": malicious_header},
)
# Assert: Should reject or sanitize
assert response.status_code in [400, 401]
# ============================================================================
# Security Tests: Middleware
# ============================================================================
class TestMiddlewareSecurity:
"""Test security aspects of middleware."""
async def test_security_headers_present(self, client):
"""Test that security headers are set."""
# Act
response = await client.get("/api/v1/strategies")
# Assert: Check for common security headers
headers = {k.lower(): v for k, v in response.headers.items()}
# May include security headers like:
# - X-Content-Type-Options: nosniff
# - X-Frame-Options: DENY
# - X-XSS-Protection: 1; mode=block
# These are optional but recommended
async def test_no_server_version_leak(self, client):
"""Test that Server header doesn't leak version info."""
# Act
response = await client.get("/api/v1/strategies")
# Assert: Server header should be minimal
server_header = response.headers.get("Server", "")
# Should not contain version numbers or detailed info
assert "uvicorn" not in server_header.lower() or "/" not in server_header
async def test_error_messages_dont_leak_info(self, client):
"""Test that error messages don't leak sensitive information."""
# Act: Trigger various errors
response = await client.get("/api/v1/strategies/99999")
# Assert
assert response.status_code == 404
data = response.json()
error_text = str(data).lower()
# Should not leak database info
assert "sql" not in error_text
assert "database" not in error_text
assert "table" not in error_text
async def test_method_not_allowed_handling(self, client):
"""Test handling of unsupported HTTP methods."""
# Act: Try PATCH on endpoint that doesn't support it
response = await client.patch("/api/v1/strategies")
# Assert
assert response.status_code == 405 # Method Not Allowed
assert "Allow" in response.headers or "allow" in response.headers