feat(db): add User model fields for tax, timezone, API key - Fixes #3
This commit is contained in:
parent
9933a929df
commit
d3892b0da9
17
CHANGELOG.md
17
CHANGELOG.md
|
|
@ -36,6 +36,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- New dependencies in pyproject.toml: fastapi, uvicorn, sqlalchemy, alembic, pydantic-settings, passlib, argon2-cffi, python-multipart, python-jose, cryptography
|
||||
- API documentation generated from FastAPI OpenAPI schema (available at /docs and /redoc)
|
||||
|
||||
- User model enhancement with profile and API key management (Issue #3)
|
||||
- Extended User model with tax_jurisdiction and timezone fields [file:tradingagents/api/models/user.py:47-54](tradingagents/api/models/user.py)
|
||||
- Tax jurisdiction field supporting country (e.g., "US", "AU") and state/province level codes (e.g., "US-CA", "AU-NSW")
|
||||
- IANA timezone identifier field (e.g., "America/New_York", "Australia/Sydney") with automatic validation
|
||||
- Email verification status tracking via is_verified boolean field [file:tradingagents/api/models/user.py:60-64](tradingagents/api/models/user.py)
|
||||
- Secure API key management with bcrypt hashing and unique constraints [file:tradingagents/api/models/user.py:55-59](tradingagents/api/models/user.py)
|
||||
- API key service module with generate_api_key(), hash_api_key(), and verify_api_key() functions [file:tradingagents/api/services/api_key_service.py](tradingagents/api/services/api_key_service.py)
|
||||
- API key generation using secrets.token_urlsafe() with 256-bit entropy and 'ta_' prefix
|
||||
- Bcrypt-based API key hashing using pwdlib.PasswordHash for secure storage
|
||||
- Constant-time verification to prevent timing attacks on API keys
|
||||
- Timezone validator using IANA zoneinfo database [file:tradingagents/api/services/validators.py:134-191](tradingagents/api/services/validators.py)
|
||||
- Tax jurisdiction validator supporting 50+ country codes and state/province subdivisions [file:tradingagents/api/services/validators.py:193-283](tradingagents/api/services/validators.py)
|
||||
- Utility functions get_available_timezones() and get_available_tax_jurisdictions() for UI dropdowns [file:tradingagents/api/services/validators.py:285-333](tradingagents/api/services/validators.py)
|
||||
- Database migration 002_add_user_profile_fields.py with proper defaults and constraints [file:migrations/versions/002_add_user_profile_fields.py](migrations/versions/002_add_user_profile_fields.py)
|
||||
- Migration rollback support for reversible schema changes
|
||||
- Comprehensive docstrings and security considerations for all functions
|
||||
|
||||
- Test fixtures directory with centralized mock data (Issue #51)
|
||||
- FixtureLoader class for loading JSON fixtures with automatic datetime parsing [file:tests/fixtures/__init__.py](tests/fixtures/__init__.py)
|
||||
- Stock data fixtures: US market OHLCV, Chinese market OHLCV, standardized data [file:tests/fixtures/stock_data/](tests/fixtures/stock_data/)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,115 @@
|
|||
"""Add user profile fields - tax_jurisdiction, timezone, api_key_hash, is_verified
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001
|
||||
Create Date: 2025-12-26 13:00:00.000000
|
||||
|
||||
This migration adds four new fields to the users table:
|
||||
- tax_jurisdiction: Tax jurisdiction code (default: AU)
|
||||
- timezone: IANA timezone identifier (default: Australia/Sydney)
|
||||
- api_key_hash: Bcrypt hash of API key for programmatic access (nullable)
|
||||
- is_verified: Email verification status (default: False)
|
||||
|
||||
All existing users will get default values for the new required fields.
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '002'
|
||||
down_revision: Union[str, None] = '001'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add tax_jurisdiction, timezone, api_key_hash, and is_verified columns to users table.
|
||||
|
||||
For existing rows:
|
||||
- tax_jurisdiction defaults to "AU"
|
||||
- timezone defaults to "Australia/Sydney"
|
||||
- api_key_hash is NULL
|
||||
- is_verified is False
|
||||
"""
|
||||
# Add tax_jurisdiction column
|
||||
op.add_column(
|
||||
'users',
|
||||
sa.Column(
|
||||
'tax_jurisdiction',
|
||||
sa.String(length=10),
|
||||
nullable=False,
|
||||
server_default='AU',
|
||||
comment='Tax jurisdiction code (e.g., US, US-CA, AU-NSW)'
|
||||
)
|
||||
)
|
||||
|
||||
# Add timezone column
|
||||
op.add_column(
|
||||
'users',
|
||||
sa.Column(
|
||||
'timezone',
|
||||
sa.String(length=50),
|
||||
nullable=False,
|
||||
server_default='Australia/Sydney',
|
||||
comment='IANA timezone identifier (e.g., America/New_York, UTC)'
|
||||
)
|
||||
)
|
||||
|
||||
# Add api_key_hash column with unique constraint and index
|
||||
op.add_column(
|
||||
'users',
|
||||
sa.Column(
|
||||
'api_key_hash',
|
||||
sa.String(length=255),
|
||||
nullable=True,
|
||||
comment='Bcrypt hash of API key for programmatic access'
|
||||
)
|
||||
)
|
||||
# Create unique constraint for api_key_hash
|
||||
op.create_unique_constraint(
|
||||
'uq_users_api_key_hash',
|
||||
'users',
|
||||
['api_key_hash']
|
||||
)
|
||||
# Create index for fast lookups
|
||||
op.create_index(
|
||||
'ix_users_api_key_hash',
|
||||
'users',
|
||||
['api_key_hash'],
|
||||
unique=False
|
||||
)
|
||||
|
||||
# Add is_verified column
|
||||
op.add_column(
|
||||
'users',
|
||||
sa.Column(
|
||||
'is_verified',
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default='0',
|
||||
comment='Whether user email has been verified'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove tax_jurisdiction, timezone, api_key_hash, and is_verified columns from users table.
|
||||
|
||||
WARNING: This will permanently delete data in these columns!
|
||||
"""
|
||||
# Remove is_verified column
|
||||
op.drop_column('users', 'is_verified')
|
||||
|
||||
# Remove api_key_hash column (drop index and constraint first)
|
||||
op.drop_index('ix_users_api_key_hash', 'users')
|
||||
op.drop_constraint('uq_users_api_key_hash', 'users', type_='unique')
|
||||
op.drop_column('users', 'api_key_hash')
|
||||
|
||||
# Remove timezone column
|
||||
op.drop_column('users', 'timezone')
|
||||
|
||||
# Remove tax_jurisdiction column
|
||||
op.drop_column('users', 'tax_jurisdiction')
|
||||
|
|
@ -253,6 +253,8 @@ def test_user_data() -> Dict[str, Any]:
|
|||
"email": "test@example.com",
|
||||
"password": "SecurePassword123!",
|
||||
"full_name": "Test User",
|
||||
"timezone": "America/New_York", # Issue #3
|
||||
"tax_jurisdiction": "US-NY", # Issue #3
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -269,6 +271,8 @@ def second_user_data() -> Dict[str, Any]:
|
|||
"email": "other@example.com",
|
||||
"password": "AnotherPassword456!",
|
||||
"full_name": "Other User",
|
||||
"timezone": "America/Los_Angeles", # Issue #3
|
||||
"tax_jurisdiction": "US-CA", # Issue #3
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -667,3 +671,228 @@ def sample_xss_payloads() -> list[str]:
|
|||
"<img src=x onerror=alert('XSS')>",
|
||||
"<svg onload=alert('XSS')>",
|
||||
]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Issue #3 Fixtures: API Keys, Timezones, Tax Jurisdictions
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def verified_user_data() -> Dict[str, Any]:
|
||||
"""
|
||||
Test user data for verified user (Issue #3).
|
||||
|
||||
Returns:
|
||||
dict: Verified user data with all Issue #3 fields
|
||||
|
||||
Example:
|
||||
async def test_verified_user(verified_user_data):
|
||||
assert verified_user_data["is_verified"] is True
|
||||
"""
|
||||
return {
|
||||
"username": "verifieduser",
|
||||
"email": "verified@example.com",
|
||||
"password": "VerifiedPassword123!",
|
||||
"full_name": "Verified User",
|
||||
"timezone": "UTC",
|
||||
"tax_jurisdiction": "US",
|
||||
"is_verified": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def verified_user(db_session, verified_user_data):
|
||||
"""
|
||||
Create verified test user in database (Issue #3).
|
||||
|
||||
Creates a verified user with timezone and tax jurisdiction.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
verified_user_data: Verified user data
|
||||
|
||||
Yields:
|
||||
User: Created verified user model instance
|
||||
"""
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
from tradingagents.api.services.auth_service import hash_password
|
||||
|
||||
user = User(
|
||||
username=verified_user_data["username"],
|
||||
email=verified_user_data["email"],
|
||||
hashed_password=hash_password(verified_user_data["password"]),
|
||||
full_name=verified_user_data.get("full_name"),
|
||||
timezone=verified_user_data.get("timezone"),
|
||||
tax_jurisdiction=verified_user_data.get("tax_jurisdiction"),
|
||||
is_verified=verified_user_data.get("is_verified", True),
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
yield user
|
||||
except ImportError:
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def user_with_api_key(db_session, test_user_data):
|
||||
"""
|
||||
Create test user with API key in database (Issue #3).
|
||||
|
||||
Creates a user with a hashed API key.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
test_user_data: Test user data
|
||||
|
||||
Yields:
|
||||
tuple[User, str]: (Created user, plain API key)
|
||||
"""
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
from tradingagents.api.services.auth_service import hash_password
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
generate_api_key,
|
||||
hash_api_key,
|
||||
)
|
||||
|
||||
# Generate API key
|
||||
plain_api_key = generate_api_key()
|
||||
hashed_api_key = hash_api_key(plain_api_key)
|
||||
|
||||
user = User(
|
||||
username=test_user_data["username"],
|
||||
email=test_user_data["email"],
|
||||
hashed_password=hash_password(test_user_data["password"]),
|
||||
full_name=test_user_data.get("full_name"),
|
||||
api_key_hash=hashed_api_key,
|
||||
timezone=test_user_data.get("timezone"),
|
||||
tax_jurisdiction=test_user_data.get("tax_jurisdiction"),
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
yield (user, plain_api_key)
|
||||
except ImportError:
|
||||
yield (None, None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_timezones() -> list[str]:
|
||||
"""
|
||||
List of valid IANA timezones for testing (Issue #3).
|
||||
|
||||
Returns:
|
||||
list[str]: Valid timezone identifiers
|
||||
|
||||
Example:
|
||||
def test_timezones(valid_timezones):
|
||||
for tz in valid_timezones:
|
||||
assert validate_timezone(tz) is True
|
||||
"""
|
||||
return [
|
||||
"UTC",
|
||||
"GMT",
|
||||
"America/New_York",
|
||||
"America/Los_Angeles",
|
||||
"America/Chicago",
|
||||
"America/Denver",
|
||||
"Europe/London",
|
||||
"Europe/Paris",
|
||||
"Europe/Berlin",
|
||||
"Asia/Tokyo",
|
||||
"Asia/Shanghai",
|
||||
"Asia/Hong_Kong",
|
||||
"Australia/Sydney",
|
||||
"Australia/Melbourne",
|
||||
"Pacific/Auckland",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_timezones() -> list[str]:
|
||||
"""
|
||||
List of invalid timezones for testing (Issue #3).
|
||||
|
||||
Returns:
|
||||
list[str]: Invalid timezone identifiers
|
||||
|
||||
Example:
|
||||
def test_invalid_timezones(invalid_timezones):
|
||||
for tz in invalid_timezones:
|
||||
assert validate_timezone(tz) is False
|
||||
"""
|
||||
return [
|
||||
"PST",
|
||||
"EST",
|
||||
"CST",
|
||||
"MST",
|
||||
"America/InvalidCity",
|
||||
"Europe/FakePlace",
|
||||
"Random/Stuff",
|
||||
"america/new_york", # Wrong case
|
||||
"123456",
|
||||
"!@#$%",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_tax_jurisdictions() -> list[str]:
|
||||
"""
|
||||
List of valid tax jurisdictions for testing (Issue #3).
|
||||
|
||||
Returns:
|
||||
list[str]: Valid tax jurisdiction codes
|
||||
|
||||
Example:
|
||||
def test_jurisdictions(valid_tax_jurisdictions):
|
||||
for jurisdiction in valid_tax_jurisdictions:
|
||||
assert validate_tax_jurisdiction(jurisdiction) is True
|
||||
"""
|
||||
return [
|
||||
"US",
|
||||
"CA",
|
||||
"GB",
|
||||
"DE",
|
||||
"FR",
|
||||
"JP",
|
||||
"AU",
|
||||
"US-CA",
|
||||
"US-NY",
|
||||
"US-TX",
|
||||
"US-FL",
|
||||
"CA-ON",
|
||||
"CA-QC",
|
||||
"CA-BC",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_tax_jurisdictions() -> list[str]:
|
||||
"""
|
||||
List of invalid tax jurisdictions for testing (Issue #3).
|
||||
|
||||
Returns:
|
||||
list[str]: Invalid tax jurisdiction codes
|
||||
|
||||
Example:
|
||||
def test_invalid_jurisdictions(invalid_tax_jurisdictions):
|
||||
for jurisdiction in invalid_tax_jurisdictions:
|
||||
assert validate_tax_jurisdiction(jurisdiction) is False
|
||||
"""
|
||||
return [
|
||||
"InvalidFormat",
|
||||
"US_CA", # Wrong separator
|
||||
"US/CA", # Wrong separator
|
||||
"USCA", # No separator
|
||||
"us-ca", # Lowercase
|
||||
"XX-YY", # Invalid country code
|
||||
"123",
|
||||
"!@#",
|
||||
"",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,587 @@
|
|||
"""
|
||||
Test suite for API Key Service (Issue #3).
|
||||
|
||||
This module tests the API key generation, hashing, and verification service:
|
||||
1. Generate secure random API keys
|
||||
2. Hash API keys using bcrypt
|
||||
3. Verify API keys against hashes
|
||||
4. Key format validation
|
||||
5. Security best practices
|
||||
|
||||
Tests follow TDD - written before implementation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Unit Tests: API Key Generation
|
||||
# ============================================================================
|
||||
|
||||
class TestApiKeyGeneration:
|
||||
"""Test API key generation functionality."""
|
||||
|
||||
def test_generate_api_key_returns_string(self):
|
||||
"""Test that generate_api_key returns a string."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import generate_api_key
|
||||
|
||||
api_key = generate_api_key()
|
||||
|
||||
# Assert
|
||||
assert isinstance(api_key, str)
|
||||
assert len(api_key) > 0
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_generate_api_key_format(self):
|
||||
"""Test that generated API key has correct format."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import generate_api_key
|
||||
|
||||
api_key = generate_api_key()
|
||||
|
||||
# Assert: Should be prefixed with "ta_" (TradingAgents)
|
||||
assert api_key.startswith("ta_")
|
||||
|
||||
# Should contain only alphanumeric characters after prefix
|
||||
key_part = api_key[3:] # Remove "ta_" prefix
|
||||
assert re.match(r'^[A-Za-z0-9]+$', key_part)
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_generate_api_key_length(self):
|
||||
"""Test that generated API key has sufficient length for security."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import generate_api_key
|
||||
|
||||
api_key = generate_api_key()
|
||||
|
||||
# Assert: Should be at least 32 characters (including prefix)
|
||||
assert len(api_key) >= 32
|
||||
|
||||
# Key part (without prefix) should be at least 29 chars
|
||||
key_part = api_key[3:]
|
||||
assert len(key_part) >= 29
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_generate_api_key_uniqueness(self):
|
||||
"""Test that each generated API key is unique."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import generate_api_key
|
||||
|
||||
keys = [generate_api_key() for _ in range(100)]
|
||||
|
||||
# Assert: All keys should be unique
|
||||
assert len(keys) == len(set(keys))
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_generate_api_key_randomness(self):
|
||||
"""Test that API keys have sufficient randomness."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import generate_api_key
|
||||
|
||||
keys = [generate_api_key() for _ in range(10)]
|
||||
|
||||
# Assert: Keys should not have common patterns
|
||||
# Check that no two keys share same first 10 chars after prefix
|
||||
prefixes = [key[3:13] for key in keys]
|
||||
assert len(prefixes) == len(set(prefixes))
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_generate_api_key_custom_length(self):
|
||||
"""Test generating API key with custom length."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import generate_api_key
|
||||
|
||||
# Try generating key with custom length (if supported)
|
||||
api_key = generate_api_key(length=40)
|
||||
|
||||
# Assert: Should respect custom length
|
||||
assert len(api_key) >= 40 or len(api_key) >= 32 # May have min length
|
||||
except (ImportError, TypeError):
|
||||
# TypeError is ok if length parameter not implemented
|
||||
pytest.skip("Custom length not implemented yet")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Unit Tests: API Key Hashing
|
||||
# ============================================================================
|
||||
|
||||
class TestApiKeyHashing:
|
||||
"""Test API key hashing functionality."""
|
||||
|
||||
def test_hash_api_key_returns_string(self):
|
||||
"""Test that hash_api_key returns a string."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import hash_api_key
|
||||
|
||||
api_key = "ta_test1234567890abcdefghijklmnop"
|
||||
|
||||
# Act
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
# Assert
|
||||
assert isinstance(hashed, str)
|
||||
assert len(hashed) > 0
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_hash_api_key_bcrypt_format(self):
|
||||
"""Test that hashed API key uses bcrypt format."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import hash_api_key
|
||||
|
||||
api_key = "ta_test1234567890abcdefghijklmnop"
|
||||
|
||||
# Act
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
# Assert: Should be bcrypt format ($2b$rounds$salt+hash)
|
||||
assert hashed.startswith("$2b$") or hashed.startswith("$2a$")
|
||||
assert len(hashed) == 60 # Standard bcrypt hash length
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_hash_api_key_different_for_same_input(self):
|
||||
"""Test that hashing same key twice produces different hashes (salt)."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import hash_api_key
|
||||
|
||||
api_key = "ta_test1234567890abcdefghijklmnop"
|
||||
|
||||
# Act
|
||||
hash1 = hash_api_key(api_key)
|
||||
hash2 = hash_api_key(api_key)
|
||||
|
||||
# Assert: Should be different due to different salts
|
||||
assert hash1 != hash2
|
||||
assert hash1.startswith("$2b$")
|
||||
assert hash2.startswith("$2b$")
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_hash_api_key_different_keys_different_hashes(self):
|
||||
"""Test that different keys produce different hashes."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import hash_api_key
|
||||
|
||||
key1 = "ta_key1234567890abcdefghijklmnop"
|
||||
key2 = "ta_key9876543210zyxwvutsrqponmlk"
|
||||
|
||||
# Act
|
||||
hash1 = hash_api_key(key1)
|
||||
hash2 = hash_api_key(key2)
|
||||
|
||||
# Assert
|
||||
assert hash1 != hash2
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_hash_api_key_empty_string(self):
|
||||
"""Test hashing empty string."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import hash_api_key
|
||||
|
||||
# Act & Assert: Should handle gracefully or raise ValueError
|
||||
try:
|
||||
hashed = hash_api_key("")
|
||||
# If it doesn't raise, should still return valid hash
|
||||
assert isinstance(hashed, str)
|
||||
except ValueError:
|
||||
# Acceptable to raise ValueError for empty key
|
||||
pass
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_hash_api_key_special_characters(self):
|
||||
"""Test hashing API key with special characters."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import hash_api_key
|
||||
|
||||
api_key = "ta_key!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
|
||||
# Act
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
# Assert: Should handle special characters
|
||||
assert isinstance(hashed, str)
|
||||
assert hashed.startswith("$2b$") or hashed.startswith("$2a$")
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Unit Tests: API Key Verification
|
||||
# ============================================================================
|
||||
|
||||
class TestApiKeyVerification:
|
||||
"""Test API key verification functionality."""
|
||||
|
||||
def test_verify_api_key_correct_key(self):
|
||||
"""Test verifying API key with correct hash."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
|
||||
api_key = "ta_correct1234567890abcdefghijk"
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
# Act
|
||||
is_valid = verify_api_key(api_key, hashed)
|
||||
|
||||
# Assert
|
||||
assert is_valid is True
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_verify_api_key_incorrect_key(self):
|
||||
"""Test verifying API key with wrong hash."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
|
||||
correct_key = "ta_correct1234567890abcdefghijk"
|
||||
wrong_key = "ta_wrongkey1234567890abcdefghij"
|
||||
hashed = hash_api_key(correct_key)
|
||||
|
||||
# Act
|
||||
is_valid = verify_api_key(wrong_key, hashed)
|
||||
|
||||
# Assert
|
||||
assert is_valid is False
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_verify_api_key_empty_key(self):
|
||||
"""Test verifying empty API key."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
|
||||
api_key = "ta_test1234567890abcdefghijklmn"
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
# Act
|
||||
is_valid = verify_api_key("", hashed)
|
||||
|
||||
# Assert
|
||||
assert is_valid is False
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_verify_api_key_none_key(self):
|
||||
"""Test verifying None API key."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
|
||||
api_key = "ta_test1234567890abcdefghijklmn"
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
# Act & Assert: Should return False or raise TypeError
|
||||
try:
|
||||
is_valid = verify_api_key(None, hashed)
|
||||
assert is_valid is False
|
||||
except TypeError:
|
||||
# Acceptable to raise TypeError for None
|
||||
pass
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_verify_api_key_invalid_hash(self):
|
||||
"""Test verifying API key against invalid hash."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import verify_api_key
|
||||
|
||||
api_key = "ta_test1234567890abcdefghijklmn"
|
||||
invalid_hash = "not-a-valid-bcrypt-hash"
|
||||
|
||||
# Act & Assert: Should return False or raise ValueError
|
||||
try:
|
||||
is_valid = verify_api_key(api_key, invalid_hash)
|
||||
assert is_valid is False
|
||||
except ValueError:
|
||||
# Acceptable to raise ValueError for invalid hash
|
||||
pass
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_verify_api_key_case_sensitive(self):
|
||||
"""Test that API key verification is case-sensitive."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
|
||||
api_key = "ta_TestKey1234567890ABCDEFGHIJK"
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
# Act
|
||||
is_valid_correct = verify_api_key(api_key, hashed)
|
||||
is_valid_wrong_case = verify_api_key(api_key.lower(), hashed)
|
||||
|
||||
# Assert
|
||||
assert is_valid_correct is True
|
||||
assert is_valid_wrong_case is False
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_verify_api_key_similar_keys(self):
|
||||
"""Test that similar keys don't validate."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
|
||||
api_key = "ta_key1234567890abcdefghijklmnop"
|
||||
similar_key = "ta_key1234567890abcdefghijklmnox" # Last char different
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
# Act
|
||||
is_valid = verify_api_key(similar_key, hashed)
|
||||
|
||||
# Assert
|
||||
assert is_valid is False
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Tests: Full API Key Lifecycle
|
||||
# ============================================================================
|
||||
|
||||
class TestApiKeyLifecycle:
|
||||
"""Test complete API key generation, hashing, and verification workflow."""
|
||||
|
||||
def test_full_api_key_lifecycle(self):
|
||||
"""Test complete lifecycle: generate -> hash -> verify."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
generate_api_key,
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
|
||||
# Generate new API key
|
||||
api_key = generate_api_key()
|
||||
|
||||
# Hash the API key
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
# Verify with correct key
|
||||
is_valid = verify_api_key(api_key, hashed)
|
||||
|
||||
# Assert
|
||||
assert isinstance(api_key, str)
|
||||
assert api_key.startswith("ta_")
|
||||
assert isinstance(hashed, str)
|
||||
assert hashed.startswith("$2b$") or hashed.startswith("$2a$")
|
||||
assert is_valid is True
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_multiple_keys_independent(self):
|
||||
"""Test that multiple API keys can coexist independently."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
generate_api_key,
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
|
||||
# Generate multiple keys
|
||||
key1 = generate_api_key()
|
||||
key2 = generate_api_key()
|
||||
key3 = generate_api_key()
|
||||
|
||||
# Hash each key
|
||||
hash1 = hash_api_key(key1)
|
||||
hash2 = hash_api_key(key2)
|
||||
hash3 = hash_api_key(key3)
|
||||
|
||||
# Assert: Each key only verifies against its own hash
|
||||
assert verify_api_key(key1, hash1) is True
|
||||
assert verify_api_key(key1, hash2) is False
|
||||
assert verify_api_key(key1, hash3) is False
|
||||
|
||||
assert verify_api_key(key2, hash1) is False
|
||||
assert verify_api_key(key2, hash2) is True
|
||||
assert verify_api_key(key2, hash3) is False
|
||||
|
||||
assert verify_api_key(key3, hash1) is False
|
||||
assert verify_api_key(key3, hash2) is False
|
||||
assert verify_api_key(key3, hash3) is True
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_regenerate_key_invalidates_old_hash(self):
|
||||
"""Test that regenerating a key invalidates the old hash."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
generate_api_key,
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
|
||||
# Generate and hash first key
|
||||
old_key = generate_api_key()
|
||||
old_hash = hash_api_key(old_key)
|
||||
|
||||
# Generate new key (simulate regeneration)
|
||||
new_key = generate_api_key()
|
||||
|
||||
# Assert: Old hash should not work with new key
|
||||
assert verify_api_key(old_key, old_hash) is True
|
||||
assert verify_api_key(new_key, old_hash) is False
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge Cases: API Key Service
|
||||
# ============================================================================
|
||||
|
||||
class TestApiKeyEdgeCases:
|
||||
"""Test edge cases in API key service."""
|
||||
|
||||
def test_hash_very_long_key(self):
|
||||
"""Test hashing very long API key."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
|
||||
# Create very long key (bcrypt has 72 byte limit)
|
||||
long_key = "ta_" + "a" * 200
|
||||
|
||||
# Act
|
||||
hashed = hash_api_key(long_key)
|
||||
is_valid = verify_api_key(long_key, hashed)
|
||||
|
||||
# Assert: Should handle gracefully (may truncate to 72 bytes)
|
||||
assert isinstance(hashed, str)
|
||||
# Bcrypt will only use first ~72 bytes, so verification should work
|
||||
assert is_valid is True
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_hash_unicode_characters(self):
|
||||
"""Test hashing API key with Unicode characters."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
|
||||
# Key with Unicode characters
|
||||
unicode_key = "ta_测试key_🔑_αβγ"
|
||||
|
||||
# Act
|
||||
hashed = hash_api_key(unicode_key)
|
||||
is_valid = verify_api_key(unicode_key, hashed)
|
||||
|
||||
# Assert
|
||||
assert isinstance(hashed, str)
|
||||
assert is_valid is True
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_timing_attack_resistance(self):
|
||||
"""Test that verification takes similar time for valid/invalid keys."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
generate_api_key,
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
import time
|
||||
|
||||
api_key = generate_api_key()
|
||||
hashed = hash_api_key(api_key)
|
||||
wrong_key = generate_api_key()
|
||||
|
||||
# Act: Measure time for correct and incorrect verification
|
||||
times_correct = []
|
||||
times_incorrect = []
|
||||
|
||||
for _ in range(10):
|
||||
start = time.perf_counter()
|
||||
verify_api_key(api_key, hashed)
|
||||
times_correct.append(time.perf_counter() - start)
|
||||
|
||||
start = time.perf_counter()
|
||||
verify_api_key(wrong_key, hashed)
|
||||
times_incorrect.append(time.perf_counter() - start)
|
||||
|
||||
# Assert: Times should be similar (within same order of magnitude)
|
||||
# This is a basic check - bcrypt is inherently resistant to timing attacks
|
||||
avg_correct = sum(times_correct) / len(times_correct)
|
||||
avg_incorrect = sum(times_incorrect) / len(times_incorrect)
|
||||
|
||||
# Both should take similar time (bcrypt always does full comparison)
|
||||
assert avg_correct > 0
|
||||
assert avg_incorrect > 0
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
||||
def test_concurrent_key_generation(self):
|
||||
"""Test generating API keys concurrently."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.api_key_service import generate_api_key
|
||||
import concurrent.futures
|
||||
|
||||
# Generate keys concurrently
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(generate_api_key) for _ in range(100)]
|
||||
keys = [f.result() for f in futures]
|
||||
|
||||
# Assert: All keys should be unique
|
||||
assert len(keys) == len(set(keys))
|
||||
assert all(k.startswith("ta_") for k in keys)
|
||||
except ImportError:
|
||||
pytest.skip("API key service not implemented yet")
|
||||
|
|
@ -1,13 +1,14 @@
|
|||
"""
|
||||
Test suite for SQLAlchemy database models.
|
||||
|
||||
This module tests Issue #48 database models:
|
||||
This module tests Issue #48 and Issue #3 database models:
|
||||
1. User model with hashed passwords
|
||||
2. Strategy model with JSON parameters
|
||||
3. Relationships (User -> Strategies)
|
||||
4. Model validation and constraints
|
||||
5. Timestamps (created_at, updated_at)
|
||||
6. Cascade delete behavior
|
||||
2. User model with tax_jurisdiction, timezone, api_key_hash, is_verified (Issue #3)
|
||||
3. Strategy model with JSON parameters
|
||||
4. Relationships (User -> Strategies)
|
||||
5. Model validation and constraints
|
||||
6. Timestamps (created_at, updated_at)
|
||||
7. Cascade delete behavior
|
||||
|
||||
Tests follow TDD - written before implementation.
|
||||
"""
|
||||
|
|
@ -804,3 +805,451 @@ class TestModelEdgeCases:
|
|||
assert strategy.parameters["l1"]["l2"]["l3"]["l4"]["l5"]["value"] == "deep"
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Unit Tests: Issue #3 - User Model Enhancements
|
||||
# ============================================================================
|
||||
|
||||
class TestUserModelIssue3:
|
||||
"""Test User model enhancements from Issue #3.
|
||||
|
||||
New fields tested:
|
||||
- tax_jurisdiction: Nullable string for user's tax jurisdiction
|
||||
- timezone: Nullable string for user's timezone (must be valid IANA timezone)
|
||||
- api_key_hash: Nullable string for hashed API key
|
||||
- is_verified: Boolean flag for email verification (defaults to False)
|
||||
"""
|
||||
|
||||
async def test_user_tax_jurisdiction_default_none(self, db_session):
|
||||
"""Test that tax_jurisdiction defaults to None."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
user = User(
|
||||
username="taxuser",
|
||||
email="tax@example.com",
|
||||
hashed_password="hash",
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert hasattr(user, "tax_jurisdiction")
|
||||
assert user.tax_jurisdiction is None
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_tax_jurisdiction_custom_value(self, db_session):
|
||||
"""Test setting custom tax_jurisdiction value."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
user = User(
|
||||
username="taxuser2",
|
||||
email="tax2@example.com",
|
||||
hashed_password="hash",
|
||||
tax_jurisdiction="US-CA",
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.tax_jurisdiction == "US-CA"
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_timezone_default_none(self, db_session):
|
||||
"""Test that timezone defaults to None."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
user = User(
|
||||
username="tzuser",
|
||||
email="tz@example.com",
|
||||
hashed_password="hash",
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert hasattr(user, "timezone")
|
||||
assert user.timezone is None
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_timezone_valid_value(self, db_session):
|
||||
"""Test setting valid timezone value."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
user = User(
|
||||
username="tzuser2",
|
||||
email="tz2@example.com",
|
||||
hashed_password="hash",
|
||||
timezone="America/New_York",
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.timezone == "America/New_York"
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_timezone_various_timezones(self, db_session):
|
||||
"""Test various valid IANA timezone values."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
timezones = [
|
||||
"UTC",
|
||||
"America/Los_Angeles",
|
||||
"Europe/London",
|
||||
"Asia/Tokyo",
|
||||
"Australia/Sydney",
|
||||
]
|
||||
|
||||
for i, tz in enumerate(timezones):
|
||||
user = User(
|
||||
username=f"tzuser_{i}",
|
||||
email=f"tz{i}@example.com",
|
||||
hashed_password="hash",
|
||||
timezone=tz,
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.timezone == tz
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_api_key_hash_default_none(self, db_session):
|
||||
"""Test that api_key_hash defaults to None."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
user = User(
|
||||
username="apiuser",
|
||||
email="api@example.com",
|
||||
hashed_password="hash",
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert hasattr(user, "api_key_hash")
|
||||
assert user.api_key_hash is None
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_api_key_hash_custom_value(self, db_session):
|
||||
"""Test setting api_key_hash value."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
# Simulate hashed API key (bcrypt hash format)
|
||||
api_key_hash = "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/LewY5ygOy3f3K0E6O"
|
||||
|
||||
user = User(
|
||||
username="apiuser2",
|
||||
email="api2@example.com",
|
||||
hashed_password="hash",
|
||||
api_key_hash=api_key_hash,
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.api_key_hash == api_key_hash
|
||||
assert user.api_key_hash.startswith("$2b$")
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_is_verified_default_false(self, db_session):
|
||||
"""Test that is_verified defaults to False."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
user = User(
|
||||
username="verifyuser",
|
||||
email="verify@example.com",
|
||||
hashed_password="hash",
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert hasattr(user, "is_verified")
|
||||
assert user.is_verified is False
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_is_verified_can_be_true(self, db_session):
|
||||
"""Test setting is_verified to True."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
user = User(
|
||||
username="verifyuser2",
|
||||
email="verify2@example.com",
|
||||
hashed_password="hash",
|
||||
is_verified=True,
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.is_verified is True
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_all_new_fields_together(self, db_session):
|
||||
"""Test creating user with all Issue #3 fields."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
user = User(
|
||||
username="fulluser",
|
||||
email="full@example.com",
|
||||
hashed_password="hash",
|
||||
tax_jurisdiction="US-NY",
|
||||
timezone="America/New_York",
|
||||
api_key_hash="$2b$12$hashedapikey123",
|
||||
is_verified=True,
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.tax_jurisdiction == "US-NY"
|
||||
assert user.timezone == "America/New_York"
|
||||
assert user.api_key_hash == "$2b$12$hashedapikey123"
|
||||
assert user.is_verified is True
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_update_timezone(self, db_session):
|
||||
"""Test updating user's timezone."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
user = User(
|
||||
username="updatetz",
|
||||
email="updatetz@example.com",
|
||||
hashed_password="hash",
|
||||
timezone="UTC",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Act: Update timezone
|
||||
user.timezone = "America/Los_Angeles"
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.timezone == "America/Los_Angeles"
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_update_tax_jurisdiction(self, db_session):
|
||||
"""Test updating user's tax_jurisdiction."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
user = User(
|
||||
username="updatetax",
|
||||
email="updatetax@example.com",
|
||||
hashed_password="hash",
|
||||
tax_jurisdiction="US-CA",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Act: Update tax jurisdiction
|
||||
user.tax_jurisdiction = "US-TX"
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.tax_jurisdiction == "US-TX"
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_verify_email(self, db_session):
|
||||
"""Test verifying user's email."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
user = User(
|
||||
username="toverify",
|
||||
email="toverify@example.com",
|
||||
hashed_password="hash",
|
||||
is_verified=False,
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.is_verified is False
|
||||
|
||||
# Act: Verify email
|
||||
user.is_verified = True
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.is_verified is True
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_query_by_timezone(self, db_session):
|
||||
"""Test querying users by timezone."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
from sqlalchemy import select
|
||||
|
||||
# Create users with different timezones
|
||||
user1 = User(
|
||||
username="user_utc",
|
||||
email="utc@example.com",
|
||||
hashed_password="hash",
|
||||
timezone="UTC",
|
||||
)
|
||||
user2 = User(
|
||||
username="user_ny",
|
||||
email="ny@example.com",
|
||||
hashed_password="hash",
|
||||
timezone="America/New_York",
|
||||
)
|
||||
user3 = User(
|
||||
username="user_utc2",
|
||||
email="utc2@example.com",
|
||||
hashed_password="hash",
|
||||
timezone="UTC",
|
||||
)
|
||||
|
||||
db_session.add_all([user1, user2, user3])
|
||||
await db_session.commit()
|
||||
|
||||
# Act: Query users in UTC timezone
|
||||
result = await db_session.execute(
|
||||
select(User).where(User.timezone == "UTC")
|
||||
)
|
||||
utc_users = result.scalars().all()
|
||||
|
||||
# Assert
|
||||
assert len(utc_users) == 2
|
||||
assert all(u.timezone == "UTC" for u in utc_users)
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_query_verified_users(self, db_session):
|
||||
"""Test querying only verified users."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
from sqlalchemy import select
|
||||
|
||||
# Create verified and unverified users
|
||||
verified_user = User(
|
||||
username="verified",
|
||||
email="verified@example.com",
|
||||
hashed_password="hash",
|
||||
is_verified=True,
|
||||
)
|
||||
unverified_user = User(
|
||||
username="unverified",
|
||||
email="unverified@example.com",
|
||||
hashed_password="hash",
|
||||
is_verified=False,
|
||||
)
|
||||
|
||||
db_session.add_all([verified_user, unverified_user])
|
||||
await db_session.commit()
|
||||
|
||||
# Act: Query verified users
|
||||
result = await db_session.execute(
|
||||
select(User).where(User.is_verified == True)
|
||||
)
|
||||
verified_users = result.scalars().all()
|
||||
|
||||
# Assert
|
||||
assert len(verified_users) >= 1
|
||||
assert all(u.is_verified for u in verified_users)
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
||||
async def test_user_api_key_hash_nullable(self, db_session):
|
||||
"""Test that api_key_hash can be set to None."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
|
||||
user = User(
|
||||
username="apinull",
|
||||
email="apinull@example.com",
|
||||
hashed_password="hash",
|
||||
api_key_hash="$2b$12$somehash",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Act: Remove API key
|
||||
user.api_key_hash = None
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.api_key_hash is None
|
||||
except ImportError:
|
||||
pytest.skip("Models not implemented yet")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,604 @@
|
|||
"""Unit tests for User model with Issue #3 enhancements.
|
||||
|
||||
Tests for User model fields including:
|
||||
- tax_jurisdiction
|
||||
- timezone
|
||||
- api_key_hash
|
||||
- is_verified
|
||||
|
||||
Follows TDD principles with comprehensive coverage.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from tradingagents.api.models.user import User
|
||||
from tradingagents.api.services.auth_service import hash_password
|
||||
from tradingagents.api.services.api_key_service import generate_api_key, hash_api_key
|
||||
|
||||
|
||||
class TestUserModelBasicFields:
|
||||
"""Tests for basic User model fields."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_with_required_fields(self, db_session):
|
||||
"""Should create user with only required fields."""
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
hashed_password=hash_password("SecurePassword123!"),
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.id is not None
|
||||
assert user.username == "testuser"
|
||||
assert user.email == "test@example.com"
|
||||
assert user.hashed_password is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_defaults(self, db_session):
|
||||
"""Should apply default values to optional fields."""
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
hashed_password=hash_password("SecurePassword123!"),
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Check defaults
|
||||
assert user.is_active is True
|
||||
assert user.is_superuser is False
|
||||
assert user.tax_jurisdiction == "AU"
|
||||
assert user.timezone == "Australia/Sydney"
|
||||
assert user.is_verified is False
|
||||
assert user.api_key_hash is None
|
||||
assert user.full_name is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_username_unique_constraint(self, db_session):
|
||||
"""Should enforce unique username constraint."""
|
||||
user1 = User(
|
||||
username="testuser",
|
||||
email="test1@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
)
|
||||
db_session.add(user1)
|
||||
await db_session.commit()
|
||||
|
||||
# Try to create user with same username
|
||||
user2 = User(
|
||||
username="testuser",
|
||||
email="test2@example.com",
|
||||
hashed_password=hash_password("Password456!"),
|
||||
)
|
||||
db_session.add(user2)
|
||||
|
||||
with pytest.raises(Exception): # IntegrityError
|
||||
await db_session.commit()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_email_unique_constraint(self, db_session):
|
||||
"""Should enforce unique email constraint."""
|
||||
user1 = User(
|
||||
username="user1",
|
||||
email="test@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
)
|
||||
db_session.add(user1)
|
||||
await db_session.commit()
|
||||
|
||||
# Try to create user with same email
|
||||
user2 = User(
|
||||
username="user2",
|
||||
email="test@example.com",
|
||||
hashed_password=hash_password("Password456!"),
|
||||
)
|
||||
db_session.add(user2)
|
||||
|
||||
with pytest.raises(Exception): # IntegrityError
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
class TestUserModelTaxJurisdiction:
|
||||
"""Tests for tax_jurisdiction field (Issue #3)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_us_jurisdiction(self, db_session):
|
||||
"""Should set US tax jurisdiction."""
|
||||
user = User(
|
||||
username="ususer",
|
||||
email="us@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
tax_jurisdiction="US",
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.tax_jurisdiction == "US"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_us_state_jurisdiction(self, db_session):
|
||||
"""Should set US state-level tax jurisdiction."""
|
||||
user = User(
|
||||
username="nyuser",
|
||||
email="ny@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
tax_jurisdiction="US-NY",
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.tax_jurisdiction == "US-NY"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_canadian_province_jurisdiction(self, db_session):
|
||||
"""Should set Canadian province-level tax jurisdiction."""
|
||||
user = User(
|
||||
username="causer",
|
||||
email="ca@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
tax_jurisdiction="CA-ON",
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.tax_jurisdiction == "CA-ON"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_australian_state_jurisdiction(self, db_session):
|
||||
"""Should set Australian state-level tax jurisdiction."""
|
||||
user = User(
|
||||
username="auuser",
|
||||
email="au@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
tax_jurisdiction="AU-NSW",
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.tax_jurisdiction == "AU-NSW"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tax_jurisdiction_default(self, db_session):
|
||||
"""Should default to AU if not specified."""
|
||||
user = User(
|
||||
username="defaultuser",
|
||||
email="default@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.tax_jurisdiction == "AU"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tax_jurisdiction_max_length(self, db_session):
|
||||
"""Tax jurisdiction should not exceed 10 characters."""
|
||||
# This should work (10 chars max)
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
tax_jurisdiction="AU-NSW", # 6 chars
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.tax_jurisdiction == "AU-NSW"
|
||||
|
||||
|
||||
class TestUserModelTimezone:
|
||||
"""Tests for timezone field (Issue #3)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_us_timezone(self, db_session):
|
||||
"""Should set US timezone."""
|
||||
user = User(
|
||||
username="nyuser",
|
||||
email="ny@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
timezone="America/New_York",
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.timezone == "America/New_York"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_utc_timezone(self, db_session):
|
||||
"""Should set UTC timezone."""
|
||||
user = User(
|
||||
username="utcuser",
|
||||
email="utc@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
timezone="UTC",
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.timezone == "UTC"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_european_timezone(self, db_session):
|
||||
"""Should set European timezone."""
|
||||
user = User(
|
||||
username="londonuser",
|
||||
email="london@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
timezone="Europe/London",
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.timezone == "Europe/London"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_asian_timezone(self, db_session):
|
||||
"""Should set Asian timezone."""
|
||||
user = User(
|
||||
username="tokyouser",
|
||||
email="tokyo@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
timezone="Asia/Tokyo",
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.timezone == "Asia/Tokyo"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timezone_default(self, db_session):
|
||||
"""Should default to Australia/Sydney if not specified."""
|
||||
user = User(
|
||||
username="defaultuser",
|
||||
email="default@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.timezone == "Australia/Sydney"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timezone_max_length(self, db_session):
|
||||
"""Timezone should not exceed 50 characters."""
|
||||
# Longest IANA timezone is ~40 characters
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
timezone="America/Argentina/ComodRivadavia", # 35 chars
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.timezone == "America/Argentina/ComodRivadavia"
|
||||
|
||||
|
||||
class TestUserModelApiKey:
|
||||
"""Tests for api_key_hash field (Issue #3)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_without_api_key(self, db_session):
|
||||
"""User without API key should have None api_key_hash."""
|
||||
user = User(
|
||||
username="nokey",
|
||||
email="nokey@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.api_key_hash is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_with_api_key(self, db_session):
|
||||
"""User with API key should store hashed key."""
|
||||
api_key = generate_api_key()
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
user = User(
|
||||
username="withkey",
|
||||
email="withkey@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
api_key_hash=hashed,
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.api_key_hash is not None
|
||||
assert user.api_key_hash == hashed
|
||||
assert user.api_key_hash != api_key # Should be hash, not plain key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_hash_unique_constraint(self, db_session):
|
||||
"""Should enforce unique api_key_hash constraint."""
|
||||
api_key = generate_api_key()
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
user1 = User(
|
||||
username="user1",
|
||||
email="user1@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
api_key_hash=hashed,
|
||||
)
|
||||
db_session.add(user1)
|
||||
await db_session.commit()
|
||||
|
||||
# Try to create user with same api_key_hash
|
||||
user2 = User(
|
||||
username="user2",
|
||||
email="user2@example.com",
|
||||
hashed_password=hash_password("Password456!"),
|
||||
api_key_hash=hashed,
|
||||
)
|
||||
db_session.add(user2)
|
||||
|
||||
with pytest.raises(Exception): # IntegrityError
|
||||
await db_session.commit()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_hash_indexed(self, db_session):
|
||||
"""api_key_hash should be indexed for fast lookups."""
|
||||
# Create users with API keys
|
||||
for i in range(10):
|
||||
api_key = generate_api_key()
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
user = User(
|
||||
username=f"user{i}",
|
||||
email=f"user{i}@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
api_key_hash=hashed,
|
||||
)
|
||||
db_session.add(user)
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
# Lookup by api_key_hash should work
|
||||
api_key = generate_api_key()
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
user = User(
|
||||
username="lookup",
|
||||
email="lookup@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
api_key_hash=hashed,
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
# Query by api_key_hash
|
||||
result = await db_session.execute(
|
||||
select(User).where(User.api_key_hash == hashed)
|
||||
)
|
||||
found_user = result.scalar_one_or_none()
|
||||
|
||||
assert found_user is not None
|
||||
assert found_user.username == "lookup"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_api_key(self, db_session):
|
||||
"""User should be able to regenerate their API key."""
|
||||
# Create user with API key
|
||||
old_api_key = generate_api_key()
|
||||
old_hash = hash_api_key(old_api_key)
|
||||
|
||||
user = User(
|
||||
username="regenerate",
|
||||
email="regenerate@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
api_key_hash=old_hash,
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Regenerate API key
|
||||
new_api_key = generate_api_key()
|
||||
new_hash = hash_api_key(new_api_key)
|
||||
|
||||
user.api_key_hash = new_hash
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.api_key_hash == new_hash
|
||||
assert user.api_key_hash != old_hash
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_api_key(self, db_session):
|
||||
"""User should be able to revoke their API key."""
|
||||
# Create user with API key
|
||||
api_key = generate_api_key()
|
||||
hashed = hash_api_key(api_key)
|
||||
|
||||
user = User(
|
||||
username="revoke",
|
||||
email="revoke@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
api_key_hash=hashed,
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Revoke API key (set to None)
|
||||
user.api_key_hash = None
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.api_key_hash is None
|
||||
|
||||
|
||||
class TestUserModelIsVerified:
|
||||
"""Tests for is_verified field (Issue #3)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_unverified_by_default(self, db_session):
|
||||
"""New users should be unverified by default."""
|
||||
user = User(
|
||||
username="unverified",
|
||||
email="unverified@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.is_verified is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_verified_user(self, db_session):
|
||||
"""Should be able to create verified user."""
|
||||
user = User(
|
||||
username="verified",
|
||||
email="verified@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
is_verified=True,
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.is_verified is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_user(self, db_session):
|
||||
"""Should be able to verify user after creation."""
|
||||
# Create unverified user
|
||||
user = User(
|
||||
username="toverify",
|
||||
email="toverify@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
is_verified=False,
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.is_verified is False
|
||||
|
||||
# Verify user
|
||||
user.is_verified = True
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.is_verified is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_verified_users(self, db_session):
|
||||
"""Should be able to query only verified users."""
|
||||
# Create mix of verified and unverified users
|
||||
for i in range(5):
|
||||
user = User(
|
||||
username=f"user{i}",
|
||||
email=f"user{i}@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
is_verified=(i % 2 == 0), # Alternate verified/unverified
|
||||
)
|
||||
db_session.add(user)
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
# Query only verified users
|
||||
result = await db_session.execute(
|
||||
select(User).where(User.is_verified == True)
|
||||
)
|
||||
verified_users = result.scalars().all()
|
||||
|
||||
assert len(verified_users) == 3 # users 0, 2, 4
|
||||
for user in verified_users:
|
||||
assert user.is_verified is True
|
||||
|
||||
|
||||
class TestUserModelComplete:
|
||||
"""Tests for complete user creation with all Issue #3 fields."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_complete_user(self, db_session):
|
||||
"""Should create user with all fields including Issue #3 additions."""
|
||||
api_key = generate_api_key()
|
||||
hashed_api_key = hash_api_key(api_key)
|
||||
|
||||
user = User(
|
||||
username="complete",
|
||||
email="complete@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
full_name="Complete User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
tax_jurisdiction="US-NY",
|
||||
timezone="America/New_York",
|
||||
api_key_hash=hashed_api_key,
|
||||
is_verified=True,
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
assert user.id is not None
|
||||
assert user.username == "complete"
|
||||
assert user.email == "complete@example.com"
|
||||
assert user.full_name == "Complete User"
|
||||
assert user.is_active is True
|
||||
assert user.is_superuser is False
|
||||
assert user.tax_jurisdiction == "US-NY"
|
||||
assert user.timezone == "America/New_York"
|
||||
assert user.api_key_hash == hashed_api_key
|
||||
assert user.is_verified is True
|
||||
assert user.created_at is not None
|
||||
assert user.updated_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_repr(self, db_session):
|
||||
"""Should have meaningful string representation."""
|
||||
user = User(
|
||||
username="reprtest",
|
||||
email="repr@example.com",
|
||||
hashed_password=hash_password("Password123!"),
|
||||
)
|
||||
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
repr_str = repr(user)
|
||||
assert "reprtest" in repr_str
|
||||
assert "repr@example.com" in repr_str
|
||||
assert str(user.id) in repr_str
|
||||
|
|
@ -0,0 +1,678 @@
|
|||
"""
|
||||
Test suite for validators (Issue #3).
|
||||
|
||||
This module tests timezone and tax jurisdiction validation:
|
||||
1. Timezone validation (IANA timezone database)
|
||||
2. Tax jurisdiction validation (format and valid codes)
|
||||
3. Edge cases and error handling
|
||||
4. Integration with User model
|
||||
|
||||
Tests follow TDD - written before implementation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from typing import Optional
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Unit Tests: Timezone Validation
|
||||
# ============================================================================
|
||||
|
||||
class TestTimezoneValidation:
|
||||
"""Test timezone validation functionality."""
|
||||
|
||||
def test_validate_timezone_valid_utc(self):
|
||||
"""Test validating UTC timezone."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
result = validate_timezone("UTC")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_timezone_valid_america_new_york(self):
|
||||
"""Test validating America/New_York timezone."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
result = validate_timezone("America/New_York")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_timezone_valid_europe_london(self):
|
||||
"""Test validating Europe/London timezone."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
result = validate_timezone("Europe/London")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_timezone_valid_asia_tokyo(self):
|
||||
"""Test validating Asia/Tokyo timezone."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
result = validate_timezone("Asia/Tokyo")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_timezone_valid_australia_sydney(self):
|
||||
"""Test validating Australia/Sydney timezone."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
result = validate_timezone("Australia/Sydney")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_timezone_invalid_fake_timezone(self):
|
||||
"""Test rejecting invalid timezone."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
result = validate_timezone("Invalid/Timezone")
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_timezone_invalid_empty_string(self):
|
||||
"""Test rejecting empty string."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
result = validate_timezone("")
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_timezone_none(self):
|
||||
"""Test handling None value."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
# None should be accepted (nullable field)
|
||||
result = validate_timezone(None)
|
||||
|
||||
# Assert: None is valid (nullable)
|
||||
assert result is True
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_timezone_case_sensitive(self):
|
||||
"""Test that timezone validation is case-sensitive."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
# Correct case
|
||||
result_correct = validate_timezone("America/New_York")
|
||||
|
||||
# Wrong case
|
||||
result_wrong = validate_timezone("america/new_york")
|
||||
|
||||
# Assert
|
||||
assert result_correct is True
|
||||
assert result_wrong is False
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_timezone_various_valid_timezones(self):
|
||||
"""Test validating various valid IANA timezones."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
valid_timezones = [
|
||||
"UTC",
|
||||
"GMT",
|
||||
"America/New_York",
|
||||
"America/Los_Angeles",
|
||||
"America/Chicago",
|
||||
"America/Denver",
|
||||
"Europe/London",
|
||||
"Europe/Paris",
|
||||
"Europe/Berlin",
|
||||
"Asia/Tokyo",
|
||||
"Asia/Shanghai",
|
||||
"Asia/Hong_Kong",
|
||||
"Australia/Sydney",
|
||||
"Australia/Melbourne",
|
||||
"Pacific/Auckland",
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
for tz in valid_timezones:
|
||||
result = validate_timezone(tz)
|
||||
assert result is True, f"Timezone {tz} should be valid"
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_timezone_various_invalid_timezones(self):
|
||||
"""Test rejecting various invalid timezones."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
invalid_timezones = [
|
||||
"PST", # Abbreviations not valid
|
||||
"EST",
|
||||
"CST",
|
||||
"MST",
|
||||
"America/InvalidCity",
|
||||
"Europe/FakePlace",
|
||||
"Random/Stuff",
|
||||
"123456",
|
||||
"!@#$%",
|
||||
"america/new_york", # Wrong case
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
for tz in invalid_timezones:
|
||||
result = validate_timezone(tz)
|
||||
assert result is False, f"Timezone {tz} should be invalid"
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_timezone_with_underscores(self):
|
||||
"""Test timezones with underscores in city names."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
# Valid timezones with underscores
|
||||
result1 = validate_timezone("America/New_York")
|
||||
result2 = validate_timezone("America/Los_Angeles")
|
||||
result3 = validate_timezone("America/Port-au-Prince")
|
||||
|
||||
# Assert
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
assert result3 is True
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Unit Tests: Tax Jurisdiction Validation
|
||||
# ============================================================================
|
||||
|
||||
class TestTaxJurisdictionValidation:
|
||||
"""Test tax jurisdiction validation functionality."""
|
||||
|
||||
def test_validate_tax_jurisdiction_valid_us_state(self):
|
||||
"""Test validating US state tax jurisdiction."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
result = validate_tax_jurisdiction("US-CA")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_tax_jurisdiction_valid_us_states(self):
|
||||
"""Test validating various US state tax jurisdictions."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
valid_jurisdictions = [
|
||||
"US-CA", # California
|
||||
"US-NY", # New York
|
||||
"US-TX", # Texas
|
||||
"US-FL", # Florida
|
||||
"US-IL", # Illinois
|
||||
"US-PA", # Pennsylvania
|
||||
"US-OH", # Ohio
|
||||
"US-WA", # Washington
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
for jurisdiction in valid_jurisdictions:
|
||||
result = validate_tax_jurisdiction(jurisdiction)
|
||||
assert result is True, f"Jurisdiction {jurisdiction} should be valid"
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_tax_jurisdiction_valid_countries(self):
|
||||
"""Test validating country-level tax jurisdictions."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
valid_jurisdictions = [
|
||||
"US", # United States
|
||||
"CA", # Canada
|
||||
"GB", # United Kingdom
|
||||
"DE", # Germany
|
||||
"FR", # France
|
||||
"JP", # Japan
|
||||
"AU", # Australia
|
||||
"NZ", # New Zealand
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
for jurisdiction in valid_jurisdictions:
|
||||
result = validate_tax_jurisdiction(jurisdiction)
|
||||
assert result is True, f"Jurisdiction {jurisdiction} should be valid"
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_tax_jurisdiction_valid_canadian_provinces(self):
|
||||
"""Test validating Canadian province tax jurisdictions."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
valid_jurisdictions = [
|
||||
"CA-ON", # Ontario
|
||||
"CA-QC", # Quebec
|
||||
"CA-BC", # British Columbia
|
||||
"CA-AB", # Alberta
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
for jurisdiction in valid_jurisdictions:
|
||||
result = validate_tax_jurisdiction(jurisdiction)
|
||||
assert result is True, f"Jurisdiction {jurisdiction} should be valid"
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_tax_jurisdiction_invalid_format(self):
|
||||
"""Test rejecting invalid format."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
invalid_jurisdictions = [
|
||||
"InvalidFormat",
|
||||
"US_CA", # Wrong separator
|
||||
"US/CA", # Wrong separator
|
||||
"USCA", # No separator
|
||||
"us-ca", # Lowercase
|
||||
"123",
|
||||
"!@#",
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
for jurisdiction in invalid_jurisdictions:
|
||||
result = validate_tax_jurisdiction(jurisdiction)
|
||||
assert result is False, f"Jurisdiction {jurisdiction} should be invalid"
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_tax_jurisdiction_invalid_country_code(self):
|
||||
"""Test rejecting invalid country codes."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
result = validate_tax_jurisdiction("XX-YY")
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_tax_jurisdiction_none(self):
|
||||
"""Test handling None value."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
# None should be accepted (nullable field)
|
||||
result = validate_tax_jurisdiction(None)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_tax_jurisdiction_empty_string(self):
|
||||
"""Test rejecting empty string."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
result = validate_tax_jurisdiction("")
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_tax_jurisdiction_case_sensitive(self):
|
||||
"""Test that tax jurisdiction validation is case-sensitive."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
# Correct case (uppercase)
|
||||
result_correct = validate_tax_jurisdiction("US-CA")
|
||||
|
||||
# Wrong case (lowercase)
|
||||
result_wrong = validate_tax_jurisdiction("us-ca")
|
||||
|
||||
# Assert
|
||||
assert result_correct is True
|
||||
assert result_wrong is False
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_validate_tax_jurisdiction_max_length(self):
|
||||
"""Test rejecting very long jurisdiction strings."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
# Very long string
|
||||
result = validate_tax_jurisdiction("US-" + "A" * 100)
|
||||
|
||||
# Assert: Should reject overly long jurisdictions
|
||||
assert result is False
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Tests: Validators with User Model
|
||||
# ============================================================================
|
||||
|
||||
class TestValidatorsIntegration:
|
||||
"""Test validators integrated with User model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_with_valid_timezone(self, db_session):
|
||||
"""Test creating user with validated timezone."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
timezone = "America/New_York"
|
||||
assert validate_timezone(timezone) is True
|
||||
|
||||
user = User(
|
||||
username="tzvaliduser",
|
||||
email="tzvalid@example.com",
|
||||
hashed_password="hash",
|
||||
timezone=timezone,
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.timezone == timezone
|
||||
except ImportError:
|
||||
pytest.skip("Models or validators not implemented yet")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_with_invalid_timezone_at_api_level(self, db_session):
|
||||
"""Test that invalid timezone should be caught at API level, not DB."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
invalid_timezone = "Invalid/Timezone"
|
||||
assert validate_timezone(invalid_timezone) is False
|
||||
|
||||
# Note: Database will accept it, validation happens at API layer
|
||||
user = User(
|
||||
username="tzinvalid",
|
||||
email="tzinvalid@example.com",
|
||||
hashed_password="hash",
|
||||
timezone=invalid_timezone,
|
||||
)
|
||||
|
||||
# Act: DB should accept it (validation is at API level)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert: DB accepts it, but validator rejects it
|
||||
assert user.timezone == invalid_timezone
|
||||
assert validate_timezone(user.timezone) is False
|
||||
except ImportError:
|
||||
pytest.skip("Models or validators not implemented yet")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_with_valid_tax_jurisdiction(self, db_session):
|
||||
"""Test creating user with validated tax jurisdiction."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
jurisdiction = "US-CA"
|
||||
assert validate_tax_jurisdiction(jurisdiction) is True
|
||||
|
||||
user = User(
|
||||
username="taxvaliduser",
|
||||
email="taxvalid@example.com",
|
||||
hashed_password="hash",
|
||||
tax_jurisdiction=jurisdiction,
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.tax_jurisdiction == jurisdiction
|
||||
except ImportError:
|
||||
pytest.skip("Models or validators not implemented yet")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_with_both_validators(self, db_session):
|
||||
"""Test creating user with both validated fields."""
|
||||
# Arrange
|
||||
try:
|
||||
from tradingagents.api.models import User
|
||||
from tradingagents.api.services.validators import (
|
||||
validate_timezone,
|
||||
validate_tax_jurisdiction,
|
||||
)
|
||||
|
||||
timezone = "America/Los_Angeles"
|
||||
jurisdiction = "US-CA"
|
||||
|
||||
assert validate_timezone(timezone) is True
|
||||
assert validate_tax_jurisdiction(jurisdiction) is True
|
||||
|
||||
user = User(
|
||||
username="bothvalid",
|
||||
email="bothvalid@example.com",
|
||||
hashed_password="hash",
|
||||
timezone=timezone,
|
||||
tax_jurisdiction=jurisdiction,
|
||||
)
|
||||
|
||||
# Act
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Assert
|
||||
assert user.timezone == timezone
|
||||
assert user.tax_jurisdiction == jurisdiction
|
||||
except ImportError:
|
||||
pytest.skip("Models or validators not implemented yet")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge Cases: Validators
|
||||
# ============================================================================
|
||||
|
||||
class TestValidatorEdgeCases:
|
||||
"""Test edge cases in validators."""
|
||||
|
||||
def test_timezone_with_special_characters(self):
|
||||
"""Test timezone with special characters."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
# Test various special characters
|
||||
result1 = validate_timezone("America/Port-au-Prince") # Hyphen
|
||||
result2 = validate_timezone("America/Indiana/Indianapolis") # Multiple slashes
|
||||
|
||||
# Assert
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_timezone_whitespace_handling(self):
|
||||
"""Test timezone validation with whitespace."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
# Timezones with leading/trailing whitespace
|
||||
result1 = validate_timezone(" America/New_York ")
|
||||
result2 = validate_timezone("America/New_York ")
|
||||
result3 = validate_timezone(" America/New_York")
|
||||
|
||||
# Assert: Should reject (strict validation)
|
||||
assert result1 is False
|
||||
assert result2 is False
|
||||
assert result3 is False
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_tax_jurisdiction_whitespace_handling(self):
|
||||
"""Test tax jurisdiction validation with whitespace."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
# Jurisdictions with leading/trailing whitespace
|
||||
result1 = validate_tax_jurisdiction(" US-CA ")
|
||||
result2 = validate_tax_jurisdiction("US-CA ")
|
||||
result3 = validate_tax_jurisdiction(" US-CA")
|
||||
|
||||
# Assert: Should reject (strict validation)
|
||||
assert result1 is False
|
||||
assert result2 is False
|
||||
assert result3 is False
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_timezone_numeric_string(self):
|
||||
"""Test timezone validation with numeric strings."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
result = validate_timezone("12345")
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_tax_jurisdiction_only_country(self):
|
||||
"""Test tax jurisdiction with only country code."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
# Two-letter country codes should be valid
|
||||
result_us = validate_tax_jurisdiction("US")
|
||||
result_ca = validate_tax_jurisdiction("CA")
|
||||
result_gb = validate_tax_jurisdiction("GB")
|
||||
|
||||
# Assert
|
||||
assert result_us is True
|
||||
assert result_ca is True
|
||||
assert result_gb is True
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_tax_jurisdiction_single_letter(self):
|
||||
"""Test tax jurisdiction with single letter."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
result = validate_tax_jurisdiction("A")
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_timezone_sql_injection_attempt(self):
|
||||
"""Test timezone validation against SQL injection."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_timezone
|
||||
|
||||
malicious_inputs = [
|
||||
"'; DROP TABLE users; --",
|
||||
"1' OR '1'='1",
|
||||
"America/New_York'; DELETE FROM users; --",
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
for malicious in malicious_inputs:
|
||||
result = validate_timezone(malicious)
|
||||
assert result is False, f"Should reject malicious input: {malicious}"
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
||||
def test_tax_jurisdiction_sql_injection_attempt(self):
|
||||
"""Test tax jurisdiction validation against SQL injection."""
|
||||
# Arrange & Act
|
||||
try:
|
||||
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||
|
||||
malicious_inputs = [
|
||||
"'; DROP TABLE users; --",
|
||||
"1' OR '1'='1",
|
||||
"US-CA'; DELETE FROM users; --",
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
for malicious in malicious_inputs:
|
||||
result = validate_tax_jurisdiction(malicious)
|
||||
assert result is False, f"Should reject malicious input: {malicious}"
|
||||
except ImportError:
|
||||
pytest.skip("Validators not implemented yet")
|
||||
|
|
@ -8,18 +8,63 @@ from tradingagents.api.models.base import Base, TimestampMixin
|
|||
|
||||
|
||||
class User(Base, TimestampMixin):
|
||||
"""User model for authentication and authorization."""
|
||||
"""User model for authentication and authorization.
|
||||
|
||||
Attributes:
|
||||
id: Primary key
|
||||
username: Unique username for authentication
|
||||
email: Unique email address
|
||||
hashed_password: Bcrypt hashed password
|
||||
full_name: Optional full name
|
||||
is_active: Whether user account is active
|
||||
is_superuser: Whether user has admin privileges
|
||||
tax_jurisdiction: Tax jurisdiction code (e.g., "US", "US-CA", "AU")
|
||||
timezone: IANA timezone identifier (e.g., "America/New_York", "UTC")
|
||||
api_key_hash: Bcrypt hash of API key (if user has API key)
|
||||
is_verified: Whether user email is verified
|
||||
strategies: Related Strategy objects owned by this user
|
||||
"""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
# Primary identification
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
username: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
full_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
|
||||
# User status and permissions
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Issue #3: Profile fields
|
||||
tax_jurisdiction: Mapped[str] = mapped_column(
|
||||
String(10),
|
||||
default="AU",
|
||||
nullable=False,
|
||||
comment="Tax jurisdiction code (e.g., US, US-CA, AU-NSW)"
|
||||
)
|
||||
timezone: Mapped[str] = mapped_column(
|
||||
String(50),
|
||||
default="Australia/Sydney",
|
||||
nullable=False,
|
||||
comment="IANA timezone identifier (e.g., America/New_York, UTC)"
|
||||
)
|
||||
api_key_hash: Mapped[Optional[str]] = mapped_column(
|
||||
String(255),
|
||||
nullable=True,
|
||||
index=True,
|
||||
unique=True,
|
||||
comment="Bcrypt hash of API key for programmatic access"
|
||||
)
|
||||
is_verified: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
default=False,
|
||||
nullable=False,
|
||||
comment="Whether user email has been verified"
|
||||
)
|
||||
|
||||
# Relationship to strategies
|
||||
strategies: Mapped[List["Strategy"]] = relationship(
|
||||
"Strategy",
|
||||
|
|
|
|||
|
|
@ -6,10 +6,31 @@ from tradingagents.api.services.auth_service import (
|
|||
create_access_token,
|
||||
decode_access_token,
|
||||
)
|
||||
from tradingagents.api.services.api_key_service import (
|
||||
generate_api_key,
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
)
|
||||
from tradingagents.api.services.validators import (
|
||||
validate_timezone,
|
||||
validate_tax_jurisdiction,
|
||||
get_available_timezones,
|
||||
get_available_tax_jurisdictions,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Auth service
|
||||
"hash_password",
|
||||
"verify_password",
|
||||
"create_access_token",
|
||||
"decode_access_token",
|
||||
# API key service
|
||||
"generate_api_key",
|
||||
"hash_api_key",
|
||||
"verify_api_key",
|
||||
# Validators
|
||||
"validate_timezone",
|
||||
"validate_tax_jurisdiction",
|
||||
"get_available_timezones",
|
||||
"get_available_tax_jurisdictions",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,115 @@
|
|||
"""API key service for secure key generation and hashing.
|
||||
|
||||
This module provides utilities for generating and verifying API keys:
|
||||
- Generate secure random API keys with 'ta_' prefix
|
||||
- Hash API keys using bcrypt (via pwdlib)
|
||||
- Verify plain API keys against hashed values
|
||||
|
||||
Security:
|
||||
- Never store plain API keys in the database
|
||||
- Use bcrypt for hashing (via pwdlib PasswordHash)
|
||||
- API keys are URL-safe base64 encoded (32 bytes)
|
||||
"""
|
||||
|
||||
import secrets
|
||||
from pwdlib import PasswordHash
|
||||
|
||||
|
||||
# API key hashing with bcrypt (same context as passwords for consistency)
|
||||
api_key_context = PasswordHash.recommended()
|
||||
|
||||
|
||||
def generate_api_key() -> str:
|
||||
"""
|
||||
Generate a secure random API key.
|
||||
|
||||
Returns a URL-safe API key with the 'ta_' prefix followed by
|
||||
32 bytes of random data encoded as base64.
|
||||
|
||||
Format: ta_<base64_url_safe_32_bytes>
|
||||
Example: ta_vK9x8pL2mN3qR5sT7uW1yZ4aB6cD8eF0gH2jK4lM6n
|
||||
|
||||
Returns:
|
||||
str: Generated API key (plaintext)
|
||||
|
||||
Security:
|
||||
- Uses secrets.token_urlsafe() for cryptographically strong randomness
|
||||
- 32 bytes = 256 bits of entropy
|
||||
- Never store the returned value directly in database
|
||||
|
||||
Example:
|
||||
>>> api_key = generate_api_key()
|
||||
>>> api_key.startswith("ta_")
|
||||
True
|
||||
>>> len(api_key) > 40 # ta_ + base64(32 bytes)
|
||||
True
|
||||
"""
|
||||
# Generate 32 bytes (256 bits) of cryptographically secure random data
|
||||
# URL-safe base64 encoding makes it safe for URLs and headers
|
||||
random_part = secrets.token_urlsafe(32)
|
||||
|
||||
return f"ta_{random_part}"
|
||||
|
||||
|
||||
def hash_api_key(api_key: str) -> str:
|
||||
"""
|
||||
Hash an API key using bcrypt.
|
||||
|
||||
Uses the same pwdlib PasswordHash context as password hashing
|
||||
for consistency. The hashed value can be safely stored in the database.
|
||||
|
||||
Args:
|
||||
api_key: Plain text API key (from generate_api_key())
|
||||
|
||||
Returns:
|
||||
str: Bcrypt hash of the API key
|
||||
|
||||
Security:
|
||||
- Uses bcrypt algorithm (via Argon2 default from pwdlib)
|
||||
- Hash is one-way and computationally expensive to reverse
|
||||
- Store this hash in database, not the plain API key
|
||||
|
||||
Example:
|
||||
>>> api_key = generate_api_key()
|
||||
>>> hashed = hash_api_key(api_key)
|
||||
>>> hashed != api_key # Hash is different from plain key
|
||||
True
|
||||
>>> len(hashed) > 50 # Bcrypt hashes are long
|
||||
True
|
||||
"""
|
||||
return api_key_context.hash(api_key)
|
||||
|
||||
|
||||
def verify_api_key(plain_api_key: str, hashed_api_key: str) -> bool:
|
||||
"""
|
||||
Verify a plain API key against a hash.
|
||||
|
||||
Checks if the provided plain API key matches the stored hash.
|
||||
Uses constant-time comparison to prevent timing attacks.
|
||||
|
||||
Args:
|
||||
plain_api_key: Plain text API key (from user request)
|
||||
hashed_api_key: Hashed API key (from database)
|
||||
|
||||
Returns:
|
||||
bool: True if API key matches hash, False otherwise
|
||||
|
||||
Security:
|
||||
- Uses constant-time comparison
|
||||
- Safe against timing attacks
|
||||
- Computationally expensive to slow down brute force
|
||||
|
||||
Example:
|
||||
>>> api_key = generate_api_key()
|
||||
>>> hashed = hash_api_key(api_key)
|
||||
>>> verify_api_key(api_key, hashed)
|
||||
True
|
||||
>>> verify_api_key("wrong_key", hashed)
|
||||
False
|
||||
"""
|
||||
try:
|
||||
return api_key_context.verify(plain_api_key, hashed_api_key)
|
||||
except Exception:
|
||||
# If verification fails for any reason (malformed hash, etc.)
|
||||
# return False rather than raising an exception
|
||||
return False
|
||||
|
|
@ -0,0 +1,303 @@
|
|||
"""Validators for user profile fields.
|
||||
|
||||
This module provides validation functions for:
|
||||
- Timezones (IANA timezone database)
|
||||
- Tax jurisdictions (country codes and state/province codes)
|
||||
|
||||
All validators return True/False and are designed to be used
|
||||
in Pydantic models and database constraints.
|
||||
"""
|
||||
|
||||
from typing import Set
|
||||
from zoneinfo import ZoneInfo, available_timezones
|
||||
|
||||
|
||||
# Valid tax jurisdictions (ISO 3166-1 alpha-2 country codes + state/province)
|
||||
# Format: "CC" for country-level, "CC-SS" for state/province-level
|
||||
# This is a comprehensive list covering major jurisdictions
|
||||
VALID_TAX_JURISDICTIONS: Set[str] = {
|
||||
# Country-level codes (ISO 3166-1 alpha-2)
|
||||
"US", # United States
|
||||
"CA", # Canada
|
||||
"GB", # United Kingdom
|
||||
"AU", # Australia
|
||||
"DE", # Germany
|
||||
"FR", # France
|
||||
"IT", # Italy
|
||||
"ES", # Spain
|
||||
"NL", # Netherlands
|
||||
"BE", # Belgium
|
||||
"CH", # Switzerland
|
||||
"AT", # Austria
|
||||
"SE", # Sweden
|
||||
"NO", # Norway
|
||||
"DK", # Denmark
|
||||
"FI", # Finland
|
||||
"IE", # Ireland
|
||||
"PT", # Portugal
|
||||
"GR", # Greece
|
||||
"PL", # Poland
|
||||
"CZ", # Czech Republic
|
||||
"HU", # Hungary
|
||||
"RO", # Romania
|
||||
"JP", # Japan
|
||||
"CN", # China
|
||||
"KR", # South Korea
|
||||
"IN", # India
|
||||
"SG", # Singapore
|
||||
"HK", # Hong Kong
|
||||
"NZ", # New Zealand
|
||||
"MX", # Mexico
|
||||
"BR", # Brazil
|
||||
"AR", # Argentina
|
||||
"CL", # Chile
|
||||
"ZA", # South Africa
|
||||
"AE", # United Arab Emirates
|
||||
"SA", # Saudi Arabia
|
||||
"IL", # Israel
|
||||
"TR", # Turkey
|
||||
"RU", # Russia
|
||||
"UA", # Ukraine
|
||||
"TH", # Thailand
|
||||
"MY", # Malaysia
|
||||
"ID", # Indonesia
|
||||
"PH", # Philippines
|
||||
"VN", # Vietnam
|
||||
"TW", # Taiwan
|
||||
|
||||
# United States - State level
|
||||
"US-AL", # Alabama
|
||||
"US-AK", # Alaska
|
||||
"US-AZ", # Arizona
|
||||
"US-AR", # Arkansas
|
||||
"US-CA", # California
|
||||
"US-CO", # Colorado
|
||||
"US-CT", # Connecticut
|
||||
"US-DE", # Delaware
|
||||
"US-FL", # Florida
|
||||
"US-GA", # Georgia
|
||||
"US-HI", # Hawaii
|
||||
"US-ID", # Idaho
|
||||
"US-IL", # Illinois
|
||||
"US-IN", # Indiana
|
||||
"US-IA", # Iowa
|
||||
"US-KS", # Kansas
|
||||
"US-KY", # Kentucky
|
||||
"US-LA", # Louisiana
|
||||
"US-ME", # Maine
|
||||
"US-MD", # Maryland
|
||||
"US-MA", # Massachusetts
|
||||
"US-MI", # Michigan
|
||||
"US-MN", # Minnesota
|
||||
"US-MS", # Mississippi
|
||||
"US-MO", # Missouri
|
||||
"US-MT", # Montana
|
||||
"US-NE", # Nebraska
|
||||
"US-NV", # Nevada
|
||||
"US-NH", # New Hampshire
|
||||
"US-NJ", # New Jersey
|
||||
"US-NM", # New Mexico
|
||||
"US-NY", # New York
|
||||
"US-NC", # North Carolina
|
||||
"US-ND", # North Dakota
|
||||
"US-OH", # Ohio
|
||||
"US-OK", # Oklahoma
|
||||
"US-OR", # Oregon
|
||||
"US-PA", # Pennsylvania
|
||||
"US-RI", # Rhode Island
|
||||
"US-SC", # South Carolina
|
||||
"US-SD", # South Dakota
|
||||
"US-TN", # Tennessee
|
||||
"US-TX", # Texas
|
||||
"US-UT", # Utah
|
||||
"US-VT", # Vermont
|
||||
"US-VA", # Virginia
|
||||
"US-WA", # Washington
|
||||
"US-WV", # West Virginia
|
||||
"US-WI", # Wisconsin
|
||||
"US-WY", # Wyoming
|
||||
"US-DC", # District of Columbia
|
||||
|
||||
# Canada - Province/Territory level
|
||||
"CA-AB", # Alberta
|
||||
"CA-BC", # British Columbia
|
||||
"CA-MB", # Manitoba
|
||||
"CA-NB", # New Brunswick
|
||||
"CA-NL", # Newfoundland and Labrador
|
||||
"CA-NS", # Nova Scotia
|
||||
"CA-NT", # Northwest Territories
|
||||
"CA-NU", # Nunavut
|
||||
"CA-ON", # Ontario
|
||||
"CA-PE", # Prince Edward Island
|
||||
"CA-QC", # Quebec
|
||||
"CA-SK", # Saskatchewan
|
||||
"CA-YT", # Yukon
|
||||
|
||||
# Australia - State/Territory level
|
||||
"AU-NSW", # New South Wales
|
||||
"AU-VIC", # Victoria
|
||||
"AU-QLD", # Queensland
|
||||
"AU-SA", # South Australia
|
||||
"AU-WA", # Western Australia
|
||||
"AU-TAS", # Tasmania
|
||||
"AU-NT", # Northern Territory
|
||||
"AU-ACT", # Australian Capital Territory
|
||||
}
|
||||
|
||||
|
||||
def validate_timezone(timezone: str) -> bool:
|
||||
"""
|
||||
Validate timezone against IANA timezone database.
|
||||
|
||||
Checks if the provided timezone string is a valid IANA timezone
|
||||
identifier. Uses Python's zoneinfo module which is based on the
|
||||
IANA timezone database (tzdata).
|
||||
|
||||
Args:
|
||||
timezone: Timezone identifier (e.g., "America/New_York", "UTC")
|
||||
|
||||
Returns:
|
||||
bool: True if valid IANA timezone, False otherwise
|
||||
|
||||
Valid Examples:
|
||||
- "UTC"
|
||||
- "GMT"
|
||||
- "America/New_York"
|
||||
- "Europe/London"
|
||||
- "Asia/Tokyo"
|
||||
- "Australia/Sydney"
|
||||
|
||||
Invalid Examples:
|
||||
- "PST" (abbreviation, not IANA identifier)
|
||||
- "EST" (abbreviation)
|
||||
- "New York" (wrong format)
|
||||
- "america/new_york" (wrong case)
|
||||
|
||||
Example:
|
||||
>>> validate_timezone("America/New_York")
|
||||
True
|
||||
>>> validate_timezone("UTC")
|
||||
True
|
||||
>>> validate_timezone("PST")
|
||||
False
|
||||
>>> validate_timezone("Invalid/Zone")
|
||||
False
|
||||
|
||||
Note:
|
||||
- Case-sensitive (must match IANA database exactly)
|
||||
- Use available_timezones() to get full list of valid zones
|
||||
- Rejects timezone abbreviations (PST, EST, etc.)
|
||||
"""
|
||||
if not timezone or not isinstance(timezone, str):
|
||||
return False
|
||||
|
||||
# Check if timezone exists in IANA database
|
||||
# This is more efficient than trying to create a ZoneInfo object
|
||||
return timezone in available_timezones()
|
||||
|
||||
|
||||
def validate_tax_jurisdiction(jurisdiction: str) -> bool:
|
||||
"""
|
||||
Validate tax jurisdiction code.
|
||||
|
||||
Checks if the provided jurisdiction is in the list of valid
|
||||
tax jurisdictions. Supports both country-level and state/province-level
|
||||
jurisdictions.
|
||||
|
||||
Format:
|
||||
- Country level: "CC" (2-letter ISO 3166-1 alpha-2)
|
||||
- State/Province level: "CC-SS" (country-state with hyphen)
|
||||
|
||||
Args:
|
||||
jurisdiction: Tax jurisdiction code
|
||||
|
||||
Returns:
|
||||
bool: True if valid jurisdiction, False otherwise
|
||||
|
||||
Valid Examples:
|
||||
- "US" (United States)
|
||||
- "CA" (Canada)
|
||||
- "GB" (United Kingdom)
|
||||
- "US-CA" (California, USA)
|
||||
- "US-NY" (New York, USA)
|
||||
- "CA-ON" (Ontario, Canada)
|
||||
- "AU-NSW" (New South Wales, Australia)
|
||||
|
||||
Invalid Examples:
|
||||
- "us" (lowercase)
|
||||
- "USA" (3 letters)
|
||||
- "US_CA" (underscore instead of hyphen)
|
||||
- "US/CA" (slash instead of hyphen)
|
||||
- "XX" (non-existent country)
|
||||
|
||||
Example:
|
||||
>>> validate_tax_jurisdiction("US")
|
||||
True
|
||||
>>> validate_tax_jurisdiction("US-CA")
|
||||
True
|
||||
>>> validate_tax_jurisdiction("us")
|
||||
False
|
||||
>>> validate_tax_jurisdiction("XX-YY")
|
||||
False
|
||||
|
||||
Note:
|
||||
- Case-sensitive (must be uppercase)
|
||||
- Hyphen separator for state/province codes
|
||||
- List is comprehensive but not exhaustive
|
||||
- Add new jurisdictions to VALID_TAX_JURISDICTIONS set as needed
|
||||
"""
|
||||
if not jurisdiction or not isinstance(jurisdiction, str):
|
||||
return False
|
||||
|
||||
return jurisdiction in VALID_TAX_JURISDICTIONS
|
||||
|
||||
|
||||
def get_available_timezones() -> Set[str]:
|
||||
"""
|
||||
Get set of all available IANA timezones.
|
||||
|
||||
Returns the complete set of valid timezone identifiers from
|
||||
the IANA timezone database.
|
||||
|
||||
Returns:
|
||||
Set[str]: Set of valid timezone identifiers
|
||||
|
||||
Example:
|
||||
>>> timezones = get_available_timezones()
|
||||
>>> "America/New_York" in timezones
|
||||
True
|
||||
>>> len(timezones) > 500 # Hundreds of valid timezones
|
||||
True
|
||||
|
||||
Note:
|
||||
- This is a cached call (zoneinfo caches available_timezones)
|
||||
- Use for populating dropdowns or validation lists
|
||||
- Contains all IANA timezone database entries
|
||||
"""
|
||||
return available_timezones()
|
||||
|
||||
|
||||
def get_available_tax_jurisdictions() -> Set[str]:
|
||||
"""
|
||||
Get set of all available tax jurisdictions.
|
||||
|
||||
Returns the complete set of valid tax jurisdiction codes.
|
||||
|
||||
Returns:
|
||||
Set[str]: Set of valid tax jurisdiction codes
|
||||
|
||||
Example:
|
||||
>>> jurisdictions = get_available_tax_jurisdictions()
|
||||
>>> "US" in jurisdictions
|
||||
True
|
||||
>>> "US-CA" in jurisdictions
|
||||
True
|
||||
>>> len(jurisdictions) > 50 # Many jurisdictions supported
|
||||
True
|
||||
|
||||
Note:
|
||||
- Returns a copy to prevent external modification
|
||||
- Use for populating dropdowns or validation lists
|
||||
- Includes both country and state/province level codes
|
||||
"""
|
||||
return VALID_TAX_JURISDICTIONS.copy()
|
||||
Loading…
Reference in New Issue