670 lines
18 KiB
Python
670 lines
18 KiB
Python
"""
|
|
Shared pytest fixtures for API tests.

This module provides fixtures for testing the FastAPI backend:
- Test database with SQLAlchemy async engine
- Test FastAPI client with httpx.AsyncClient
- Test users and JWT tokens
- Mock authentication dependencies
- Database session fixtures

All fixtures follow TDD principles - they define the expected API
before implementation exists.
|
|
"""
|
|
|
|
import os
|
|
import pytest
|
|
import asyncio
|
|
from typing import AsyncGenerator, Generator, Dict, Any
|
|
from unittest.mock import Mock, patch, AsyncMock
|
|
from datetime import datetime, timedelta
|
|
|
|
|
|
# ============================================================================
|
|
# Pytest Configuration
|
|
# ============================================================================
|
|
|
|
@pytest.fixture(scope="session")
def event_loop_policy():
    """
    Provide the asyncio event loop policy used by the async tests.

    Returns:
        A freshly constructed ``asyncio.DefaultEventLoopPolicy``.
    """
    policy = asyncio.DefaultEventLoopPolicy()
    return policy
|
|
|
|
|
|
@pytest.fixture(scope="session")
def event_loop(event_loop_policy):
    """
    Yield a single event loop for the whole test session.

    Args:
        event_loop_policy: Policy fixture used to build the loop.

    Yields:
        The loop returned by ``event_loop_policy.new_event_loop()``.
    """
    session_loop = event_loop_policy.new_event_loop()
    yield session_loop
    # Teardown: the loop is closed once the session-scoped fixture finishes.
    session_loop.close()
|
|
|
|
|
|
# ============================================================================
|
|
# Database Fixtures
|
|
# ============================================================================
|
|
|
|
@pytest.fixture
async def db_engine():
    """
    Create async SQLAlchemy engine for testing.

    Uses an in-memory SQLite database (``sqlite+aiosqlite:///:memory:``).
    When ``tradingagents.api.models`` is importable, every table registered
    on ``Base.metadata`` is created before the engine is handed to the test.
    On teardown the engine is disposed; no explicit ``drop_all`` is issued.

    Yields:
        AsyncEngine: SQLAlchemy async engine

    Example:
        async def test_database(db_engine):
            async with db_engine.begin() as conn:
                result = await conn.execute(text("SELECT 1"))
                assert result.scalar() == 1
    """
    from sqlalchemy.ext.asyncio import create_async_engine

    # Create in-memory SQLite database
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        echo=False,
        future=True,
    )

    # try/finally guarantees the engine is disposed even when table creation
    # fails before the test ever receives the engine.
    try:
        # Only the import may fail quietly; an ImportError raised while
        # creating tables is a real error and must surface.
        try:
            from tradingagents.api.models import Base
        except ImportError:
            # Models don't exist yet (TDD - tests written first)
            pass
        else:
            # Create all tables
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)

        yield engine
    finally:
        # Cleanup
        await engine.dispose()
|
|
|
|
|
|
@pytest.fixture
async def db_session(db_engine):
    """
    Provide an ``AsyncSession`` bound to the test engine.

    The session is opened for the duration of the test. Once the test
    finishes, ``rollback()`` is called on it before the session is closed.
    The session is created with ``expire_on_commit=False``.

    Args:
        db_engine: Test database engine fixture

    Yields:
        AsyncSession: SQLAlchemy async session

    Example:
        async def test_create_user(db_session):
            user = User(username="test", email="test@example.com")
            db_session.add(user)
            await db_session.commit()
            assert user.id is not None
    """
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

    # Factory configuration kept in one place for readability.
    factory_options = {
        "class_": AsyncSession,
        "expire_on_commit": False,
    }
    session_factory = async_sessionmaker(db_engine, **factory_options)

    session = session_factory()
    async with session:
        yield session
        # Discard whatever the test left uncommitted.
        await session.rollback()
|
|
|
|
|
|
@pytest.fixture
async def clean_db(db_session):
    """
    Ensure the Strategy and User tables are empty before the test.

    Deletes every row from ``Strategy`` and then ``User`` and commits.
    Only these two models are cleared; other tables are not touched.
    If the models cannot be imported, the fixture does nothing.

    Args:
        db_session: Database session fixture

    Example:
        async def test_with_clean_db(clean_db, db_session):
            # User table is guaranteed to be empty
            result = await db_session.execute(select(User))
            assert len(result.scalars().all()) == 0
    """
    # Only the imports may fail quietly; an ImportError raised while the
    # deletes run must not be mistaken for "models missing".
    try:
        from tradingagents.api.models import User, Strategy
        from sqlalchemy import delete
    except ImportError:
        # Models don't exist yet
        pass
    else:
        # Delete all strategies first (foreign key constraint)
        await db_session.execute(delete(Strategy))
        await db_session.execute(delete(User))
        await db_session.commit()

    yield
|
|
|
|
|
|
# ============================================================================
|
|
# FastAPI Client Fixtures
|
|
# ============================================================================
|
|
|
|
@pytest.fixture
async def test_app():
    """
    Create FastAPI test application.

    Yields ``tradingagents.api.main.app`` when it can be imported. Otherwise
    a minimal stand-in FastAPI app with a single ``GET /`` route is built.
    This fixture installs no dependency overrides itself; the ``client``
    fixture is responsible for overriding the database dependency.

    Yields:
        FastAPI: Test application instance

    Example:
        async def test_root_endpoint(test_app):
            assert test_app is not None
            assert hasattr(test_app, "routes")
    """
    # Resolve the app first so the fixture has exactly one yield point, and
    # the ImportError handler cannot be triggered from the yield itself.
    try:
        from tradingagents.api.main import app
    except ImportError:
        # App doesn't exist yet (TDD)
        from fastapi import FastAPI

        # Create minimal app for testing
        app = FastAPI(title="TradingAgents API (Test)", version="0.1.0")

        @app.get("/")
        async def root():
            return {"message": "TradingAgents API"}

    yield app
|
|
|
|
|
|
@pytest.fixture
async def client(test_app, db_session):
    """
    Create async HTTP client for API testing.

    Uses httpx.AsyncClient with an ASGI transport bound to ``test_app``.
    When ``tradingagents.api.dependencies.get_db`` is importable, it is
    overridden so that requests receive the test ``db_session``.
    All dependency overrides on the app are cleared on teardown.

    Args:
        test_app: FastAPI test application
        db_session: Test database session

    Yields:
        AsyncClient: HTTP client for making requests

    Example:
        async def test_api_endpoint(client):
            response = await client.get("/api/v1/strategies")
            assert response.status_code == 200
    """
    import httpx
    from httpx import AsyncClient

    # Override database dependency
    async def override_get_db():
        yield db_session

    try:
        from tradingagents.api.dependencies import get_db
    except ImportError:
        # Dependency doesn't exist yet
        pass
    else:
        test_app.dependency_overrides[get_db] = override_get_db

    # try/finally ensures the override never leaks onto the app, even when
    # building or entering the client raises.
    try:
        async with AsyncClient(
            transport=httpx.ASGITransport(app=test_app),
            base_url="http://test",
        ) as ac:
            yield ac
    finally:
        # Clear overrides
        test_app.dependency_overrides.clear()
|
|
|
|
|
|
# ============================================================================
|
|
# Authentication Fixtures
|
|
# ============================================================================
|
|
|
|
@pytest.fixture
def test_user_data() -> Dict[str, Any]:
    """
    Supply the field values of the primary test user.

    Returns:
        dict: ``username``, ``email``, ``password`` and ``full_name`` entries

    Example:
        def test_user_creation(test_user_data):
            assert test_user_data["username"] == "testuser"
            assert "password" in test_user_data
    """
    primary_user: Dict[str, Any] = dict(
        username="testuser",
        email="test@example.com",
        password="SecurePassword123!",
        full_name="Test User",
    )
    return primary_user
|
|
|
|
|
|
@pytest.fixture
def second_user_data() -> Dict[str, Any]:
    """
    Supply the field values of a second, distinct test user.

    Intended for tests that need two different users (user isolation).

    Returns:
        dict: ``username``, ``email``, ``password`` and ``full_name`` entries
    """
    other_user: Dict[str, Any] = dict(
        username="otheruser",
        email="other@example.com",
        password="AnotherPassword456!",
        full_name="Other User",
    )
    return other_user
|
|
|
|
|
|
@pytest.fixture
async def test_user(db_session, test_user_data):
    """
    Create test user in database.

    Builds a ``User`` from ``test_user_data``, storing the result of
    ``hash_password`` as ``hashed_password``, then commits and refreshes it.
    Yields ``None`` when the model or auth service cannot be imported.

    Args:
        db_session: Database session
        test_user_data: Test user data

    Yields:
        User: Created user model instance, or None if imports are unavailable

    Example:
        async def test_with_user(test_user):
            assert test_user.username == "testuser"
            assert test_user.id is not None
    """
    # Only the imports may fall back to None. An ImportError raised later
    # (e.g. inside hash_password) is a genuine failure and must propagate.
    try:
        from tradingagents.api.models import User
        from tradingagents.api.services.auth_service import hash_password
    except ImportError:
        # Models/services don't exist yet
        yield None
        return

    user = User(
        username=test_user_data["username"],
        email=test_user_data["email"],
        hashed_password=hash_password(test_user_data["password"]),
        full_name=test_user_data.get("full_name"),
    )

    db_session.add(user)
    await db_session.commit()
    await db_session.refresh(user)

    yield user
|
|
|
|
|
|
@pytest.fixture
async def second_user(db_session, second_user_data):
    """
    Create second test user in database.

    Used for testing user isolation and authorization. Builds a ``User``
    from ``second_user_data``, storing the result of ``hash_password`` as
    ``hashed_password``, then commits and refreshes it. Yields ``None``
    when the model or auth service cannot be imported.

    Args:
        db_session: Database session
        second_user_data: Second user data

    Yields:
        User: Created user model instance, or None if imports are unavailable
    """
    # Only the imports may fall back to None. An ImportError raised later
    # (e.g. inside hash_password) is a genuine failure and must propagate.
    try:
        from tradingagents.api.models import User
        from tradingagents.api.services.auth_service import hash_password
    except ImportError:
        # Models/services don't exist yet
        yield None
        return

    user = User(
        username=second_user_data["username"],
        email=second_user_data["email"],
        hashed_password=hash_password(second_user_data["password"]),
        full_name=second_user_data.get("full_name"),
    )

    db_session.add(user)
    await db_session.commit()
    await db_session.refresh(user)

    yield user
|
|
|
|
|
|
@pytest.fixture
def jwt_token(test_user_data) -> str:
    """
    Generate valid JWT token for testing.

    Calls ``create_access_token`` with ``{"sub": <username>}``. When the
    auth service cannot be imported, a fixed placeholder string is returned
    instead.

    Args:
        test_user_data: Test user data

    Returns:
        str: JWT access token, or "test-jwt-token-placeholder"

    Example:
        async def test_protected_endpoint(client, jwt_token):
            response = await client.get(
                "/api/v1/strategies",
                headers={"Authorization": f"Bearer {jwt_token}"}
            )
            assert response.status_code == 200
    """
    # Only a missing auth service selects the placeholder. An ImportError
    # raised while the token is being created must not be masked by it.
    try:
        from tradingagents.api.services.auth_service import create_access_token
    except ImportError:
        # Auth service doesn't exist yet
        return "test-jwt-token-placeholder"

    token_data = {"sub": test_user_data["username"]}
    return create_access_token(token_data)
|
|
|
|
|
|
@pytest.fixture
def expired_jwt_token(test_user_data) -> str:
    """
    Generate expired JWT token for testing.

    Calls ``create_access_token`` with ``{"sub": <username>}`` and
    ``expires_delta=timedelta(hours=-1)`` to test token expiration handling.
    When the auth service cannot be imported, a fixed placeholder string is
    returned instead.

    Args:
        test_user_data: Test user data

    Returns:
        str: Expired JWT access token, or "expired-jwt-token-placeholder"

    Example:
        async def test_expired_token(client, expired_jwt_token):
            response = await client.get(
                "/api/v1/strategies",
                headers={"Authorization": f"Bearer {expired_jwt_token}"}
            )
            assert response.status_code == 401
    """
    # Only a missing auth service selects the placeholder. An ImportError
    # raised while the token is being created must not be masked by it.
    try:
        from tradingagents.api.services.auth_service import create_access_token
    except ImportError:
        return "expired-jwt-token-placeholder"

    token_data = {"sub": test_user_data["username"]}
    # Create token that expired 1 hour ago
    return create_access_token(
        token_data,
        expires_delta=timedelta(hours=-1),
    )
|
|
|
|
|
|
@pytest.fixture
def invalid_jwt_token() -> str:
    """
    Supply a fixed string that is not a real JWT.

    Returns:
        str: The literal "invalid.jwt.token"

    Example:
        async def test_invalid_token(client, invalid_jwt_token):
            response = await client.get(
                "/api/v1/strategies",
                headers={"Authorization": f"Bearer {invalid_jwt_token}"}
            )
            assert response.status_code == 401
    """
    malformed_token = "invalid.jwt.token"
    return malformed_token
|
|
|
|
|
|
@pytest.fixture
def auth_headers(jwt_token) -> Dict[str, str]:
    """
    Build request headers carrying the JWT as a bearer token.

    Args:
        jwt_token: Token placed after the "Bearer " prefix

    Returns:
        dict: A single ``Authorization`` header

    Example:
        async def test_authenticated_request(client, auth_headers):
            response = await client.get("/api/v1/strategies", headers=auth_headers)
            assert response.status_code == 200
    """
    bearer_value = f"Bearer {jwt_token}"
    return {"Authorization": bearer_value}
|
|
|
|
|
|
# ============================================================================
|
|
# Strategy Fixtures
|
|
# ============================================================================
|
|
|
|
@pytest.fixture
def strategy_data() -> Dict[str, Any]:
    """
    Supply a complete strategy payload for creation tests.

    Returns:
        dict: ``name``, ``description``, ``parameters`` and ``is_active``

    Example:
        async def test_create_strategy(client, auth_headers, strategy_data):
            response = await client.post(
                "/api/v1/strategies",
                json=strategy_data,
                headers=auth_headers
            )
            assert response.status_code == 201
    """
    crossover_parameters = {
        "fast_period": 10,
        "slow_period": 20,
        "symbol": "AAPL",
    }
    payload: Dict[str, Any] = dict(
        name="Moving Average Crossover",
        description="Simple moving average crossover strategy",
        parameters=crossover_parameters,
        is_active=True,
    )
    return payload
|
|
|
|
|
|
@pytest.fixture
def strategy_data_minimal() -> Dict[str, Any]:
    """
    Supply a strategy payload holding only ``name`` and ``description``.

    Returns:
        dict: Minimal strategy data
    """
    payload: Dict[str, Any] = dict(
        name="Minimal Strategy",
        description="A minimal test strategy",
    )
    return payload
|
|
|
|
|
|
@pytest.fixture
async def test_strategy(db_session, test_user, strategy_data):
    """
    Create test strategy in database.

    Creates a ``Strategy`` from ``strategy_data`` with ``user_id`` set to
    ``test_user.id``, then commits and refreshes it. Missing ``parameters``
    default to ``{}`` and missing ``is_active`` defaults to ``True``.
    Yields ``None`` when ``test_user`` is ``None`` or the model cannot be
    imported.

    Args:
        db_session: Database session
        test_user: Owner user
        strategy_data: Strategy data

    Yields:
        Strategy: Created strategy model instance, or None

    Example:
        async def test_with_strategy(test_strategy):
            assert test_strategy.name == "Moving Average Crossover"
            assert test_strategy.user_id is not None
    """
    if test_user is None:
        yield None
        return

    # Only the import may fall back to None. An ImportError raised while
    # building or persisting the strategy must propagate.
    try:
        from tradingagents.api.models import Strategy
    except ImportError:
        yield None
        return

    strategy = Strategy(
        name=strategy_data["name"],
        description=strategy_data["description"],
        parameters=strategy_data.get("parameters", {}),
        is_active=strategy_data.get("is_active", True),
        user_id=test_user.id,
    )

    db_session.add(strategy)
    await db_session.commit()
    await db_session.refresh(strategy)

    yield strategy
|
|
|
|
|
|
@pytest.fixture
async def multiple_strategies(db_session, test_user):
    """
    Create multiple test strategies for list/pagination testing.

    Creates 5 strategies named "Strategy 1" .. "Strategy 5", all with
    ``user_id`` set to ``test_user.id``. Strategies at even indexes (0, 2, 4)
    are active and the others are inactive. Yields an empty list when
    ``test_user`` is ``None`` or the model cannot be imported.

    Args:
        db_session: Database session
        test_user: Owner user

    Yields:
        list[Strategy]: List of created strategies
    """
    if test_user is None:
        yield []
        return

    # Only the import may fall back to an empty list. An ImportError raised
    # while building or persisting the strategies must propagate.
    try:
        from tradingagents.api.models import Strategy
    except ImportError:
        yield []
        return

    strategies = [
        Strategy(
            name=f"Strategy {i+1}",
            description=f"Test strategy number {i+1}",
            parameters={"index": i},
            is_active=i % 2 == 0,  # Alternate active/inactive
            user_id=test_user.id,
        )
        for i in range(5)
    ]

    db_session.add_all(strategies)
    await db_session.commit()

    # Refresh all strategies
    for strategy in strategies:
        await db_session.refresh(strategy)

    yield strategies
|
|
|
|
|
|
# ============================================================================
|
|
# Mock Environment Fixtures
|
|
# ============================================================================
|
|
|
|
@pytest.fixture
def mock_env_jwt_secret():
    """
    Patch ``os.environ`` with JWT settings for the duration of a test.

    Sets ``JWT_SECRET_KEY``, ``JWT_ALGORITHM`` and ``JWT_EXPIRATION_MINUTES``.
    Existing environment variables are kept (``clear=False``).

    Yields:
        None

    Example:
        def test_jwt_config(mock_env_jwt_secret):
            assert os.getenv("JWT_SECRET_KEY") is not None
    """
    jwt_settings = {
        "JWT_SECRET_KEY": "test-secret-key-for-jwt-signing-very-secure-123",
        "JWT_ALGORITHM": "HS256",
        "JWT_EXPIRATION_MINUTES": "30",
    }
    environment_patch = patch.dict(os.environ, jwt_settings, clear=False)
    with environment_patch:
        yield
|
|
|
|
|
|
@pytest.fixture
def mock_env_database():
    """
    Patch ``os.environ`` with a ``DATABASE_URL`` for the duration of a test.

    Existing environment variables are kept (``clear=False``).

    Yields:
        None
    """
    database_settings = {
        "DATABASE_URL": "sqlite+aiosqlite:///:memory:",
    }
    environment_patch = patch.dict(os.environ, database_settings, clear=False)
    with environment_patch:
        yield
|
|
|
|
|
|
# ============================================================================
|
|
# Utility Fixtures
|
|
# ============================================================================
|
|
|
|
@pytest.fixture
def sample_sql_injection_payloads() -> list[str]:
    """
    Supply SQL injection strings for security tests.

    Returns:
        list[str]: A new list of six injection patterns

    Example:
        async def test_sql_injection_prevention(client, sample_sql_injection_payloads):
            for payload in sample_sql_injection_payloads:
                response = await client.get(f"/api/v1/strategies/{payload}")
                assert response.status_code in [400, 404]  # Not 500
    """
    injection_patterns = (
        "1' OR '1'='1",
        "1; DROP TABLE users--",
        "' OR 1=1--",
        "admin'--",
        "' UNION SELECT * FROM users--",
        "1' AND '1'='1",
    )
    return list(injection_patterns)
|
|
|
|
|
|
@pytest.fixture
def sample_xss_payloads() -> list[str]:
    """
    Supply XSS strings for security tests.

    Returns:
        list[str]: A new list of four XSS patterns
    """
    xss_patterns = (
        "<script>alert('XSS')</script>",
        "javascript:alert('XSS')",
        "<img src=x onerror=alert('XSS')>",
        "<svg onload=alert('XSS')>",
    )
    return list(xss_patterns)
|