feat(db): add User model fields for tax, timezone, API key - Fixes #3
This commit is contained in:
parent
9933a929df
commit
d3892b0da9
17
CHANGELOG.md
17
CHANGELOG.md
|
|
@ -36,6 +36,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
- New dependencies in pyproject.toml: fastapi, uvicorn, sqlalchemy, alembic, pydantic-settings, passlib, argon2-cffi, python-multipart, python-jose, cryptography
|
- New dependencies in pyproject.toml: fastapi, uvicorn, sqlalchemy, alembic, pydantic-settings, passlib, argon2-cffi, python-multipart, python-jose, cryptography
|
||||||
- API documentation generated from FastAPI OpenAPI schema (available at /docs and /redoc)
|
- API documentation generated from FastAPI OpenAPI schema (available at /docs and /redoc)
|
||||||
|
|
||||||
|
- User model enhancement with profile and API key management (Issue #3)
|
||||||
|
- Extended User model with tax_jurisdiction and timezone fields [file:tradingagents/api/models/user.py:47-54](tradingagents/api/models/user.py)
|
||||||
|
- Tax jurisdiction field supporting country (e.g., "US", "AU") and state/province level codes (e.g., "US-CA", "AU-NSW")
|
||||||
|
- IANA timezone identifier field (e.g., "America/New_York", "Australia/Sydney") with automatic validation
|
||||||
|
- Email verification status tracking via is_verified boolean field [file:tradingagents/api/models/user.py:60-64](tradingagents/api/models/user.py)
|
||||||
|
- Secure API key management with bcrypt hashing and unique constraints [file:tradingagents/api/models/user.py:55-59](tradingagents/api/models/user.py)
|
||||||
|
- API key service module with generate_api_key(), hash_api_key(), and verify_api_key() functions [file:tradingagents/api/services/api_key_service.py](tradingagents/api/services/api_key_service.py)
|
||||||
|
- API key generation using secrets.token_urlsafe() with 256-bit entropy and 'ta_' prefix
|
||||||
|
- Bcrypt-based API key hashing using pwdlib.PasswordHash for secure storage
|
||||||
|
- Constant-time verification to prevent timing attacks on API keys
|
||||||
|
- Timezone validator using IANA zoneinfo database [file:tradingagents/api/services/validators.py:134-191](tradingagents/api/services/validators.py)
|
||||||
|
- Tax jurisdiction validator supporting 50+ country codes and state/province subdivisions [file:tradingagents/api/services/validators.py:193-283](tradingagents/api/services/validators.py)
|
||||||
|
- Utility functions get_available_timezones() and get_available_tax_jurisdictions() for UI dropdowns [file:tradingagents/api/services/validators.py:285-333](tradingagents/api/services/validators.py)
|
||||||
|
- Database migration 002_add_user_profile_fields.py with proper defaults and constraints [file:migrations/versions/002_add_user_profile_fields.py](migrations/versions/002_add_user_profile_fields.py)
|
||||||
|
- Migration rollback support for reversible schema changes
|
||||||
|
- Comprehensive docstrings and security considerations for all functions
|
||||||
|
|
||||||
- Test fixtures directory with centralized mock data (Issue #51)
|
- Test fixtures directory with centralized mock data (Issue #51)
|
||||||
- FixtureLoader class for loading JSON fixtures with automatic datetime parsing [file:tests/fixtures/__init__.py](tests/fixtures/__init__.py)
|
- FixtureLoader class for loading JSON fixtures with automatic datetime parsing [file:tests/fixtures/__init__.py](tests/fixtures/__init__.py)
|
||||||
- Stock data fixtures: US market OHLCV, Chinese market OHLCV, standardized data [file:tests/fixtures/stock_data/](tests/fixtures/stock_data/)
|
- Stock data fixtures: US market OHLCV, Chinese market OHLCV, standardized data [file:tests/fixtures/stock_data/](tests/fixtures/stock_data/)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,115 @@
|
||||||
|
"""Add user profile fields - tax_jurisdiction, timezone, api_key_hash, is_verified
|
||||||
|
|
||||||
|
Revision ID: 002
|
||||||
|
Revises: 001
|
||||||
|
Create Date: 2025-12-26 13:00:00.000000
|
||||||
|
|
||||||
|
This migration adds four new fields to the users table:
|
||||||
|
- tax_jurisdiction: Tax jurisdiction code (default: AU)
|
||||||
|
- timezone: IANA timezone identifier (default: Australia/Sydney)
|
||||||
|
- api_key_hash: Bcrypt hash of API key for programmatic access (nullable)
|
||||||
|
- is_verified: Email verification status (default: False)
|
||||||
|
|
||||||
|
All existing users will get default values for the new required fields.
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '002'
|
||||||
|
down_revision: Union[str, None] = '001'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Add tax_jurisdiction, timezone, api_key_hash, and is_verified columns to users table.
|
||||||
|
|
||||||
|
For existing rows:
|
||||||
|
- tax_jurisdiction defaults to "AU"
|
||||||
|
- timezone defaults to "Australia/Sydney"
|
||||||
|
- api_key_hash is NULL
|
||||||
|
- is_verified is False
|
||||||
|
"""
|
||||||
|
# Add tax_jurisdiction column
|
||||||
|
op.add_column(
|
||||||
|
'users',
|
||||||
|
sa.Column(
|
||||||
|
'tax_jurisdiction',
|
||||||
|
sa.String(length=10),
|
||||||
|
nullable=False,
|
||||||
|
server_default='AU',
|
||||||
|
comment='Tax jurisdiction code (e.g., US, US-CA, AU-NSW)'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add timezone column
|
||||||
|
op.add_column(
|
||||||
|
'users',
|
||||||
|
sa.Column(
|
||||||
|
'timezone',
|
||||||
|
sa.String(length=50),
|
||||||
|
nullable=False,
|
||||||
|
server_default='Australia/Sydney',
|
||||||
|
comment='IANA timezone identifier (e.g., America/New_York, UTC)'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add api_key_hash column with unique constraint and index
|
||||||
|
op.add_column(
|
||||||
|
'users',
|
||||||
|
sa.Column(
|
||||||
|
'api_key_hash',
|
||||||
|
sa.String(length=255),
|
||||||
|
nullable=True,
|
||||||
|
comment='Bcrypt hash of API key for programmatic access'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Create unique constraint for api_key_hash
|
||||||
|
op.create_unique_constraint(
|
||||||
|
'uq_users_api_key_hash',
|
||||||
|
'users',
|
||||||
|
['api_key_hash']
|
||||||
|
)
|
||||||
|
# Create index for fast lookups
|
||||||
|
op.create_index(
|
||||||
|
'ix_users_api_key_hash',
|
||||||
|
'users',
|
||||||
|
['api_key_hash'],
|
||||||
|
unique=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add is_verified column
|
||||||
|
op.add_column(
|
||||||
|
'users',
|
||||||
|
sa.Column(
|
||||||
|
'is_verified',
|
||||||
|
sa.Boolean(),
|
||||||
|
nullable=False,
|
||||||
|
server_default='0',
|
||||||
|
comment='Whether user email has been verified'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Remove tax_jurisdiction, timezone, api_key_hash, and is_verified columns from users table.
|
||||||
|
|
||||||
|
WARNING: This will permanently delete data in these columns!
|
||||||
|
"""
|
||||||
|
# Remove is_verified column
|
||||||
|
op.drop_column('users', 'is_verified')
|
||||||
|
|
||||||
|
# Remove api_key_hash column (drop index and constraint first)
|
||||||
|
op.drop_index('ix_users_api_key_hash', 'users')
|
||||||
|
op.drop_constraint('uq_users_api_key_hash', 'users', type_='unique')
|
||||||
|
op.drop_column('users', 'api_key_hash')
|
||||||
|
|
||||||
|
# Remove timezone column
|
||||||
|
op.drop_column('users', 'timezone')
|
||||||
|
|
||||||
|
# Remove tax_jurisdiction column
|
||||||
|
op.drop_column('users', 'tax_jurisdiction')
|
||||||
|
|
@ -253,6 +253,8 @@ def test_user_data() -> Dict[str, Any]:
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"password": "SecurePassword123!",
|
"password": "SecurePassword123!",
|
||||||
"full_name": "Test User",
|
"full_name": "Test User",
|
||||||
|
"timezone": "America/New_York", # Issue #3
|
||||||
|
"tax_jurisdiction": "US-NY", # Issue #3
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -269,6 +271,8 @@ def second_user_data() -> Dict[str, Any]:
|
||||||
"email": "other@example.com",
|
"email": "other@example.com",
|
||||||
"password": "AnotherPassword456!",
|
"password": "AnotherPassword456!",
|
||||||
"full_name": "Other User",
|
"full_name": "Other User",
|
||||||
|
"timezone": "America/Los_Angeles", # Issue #3
|
||||||
|
"tax_jurisdiction": "US-CA", # Issue #3
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -667,3 +671,228 @@ def sample_xss_payloads() -> list[str]:
|
||||||
"<img src=x onerror=alert('XSS')>",
|
"<img src=x onerror=alert('XSS')>",
|
||||||
"<svg onload=alert('XSS')>",
|
"<svg onload=alert('XSS')>",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Issue #3 Fixtures: API Keys, Timezones, Tax Jurisdictions
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def verified_user_data() -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Test user data for verified user (Issue #3).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Verified user data with all Issue #3 fields
|
||||||
|
|
||||||
|
Example:
|
||||||
|
async def test_verified_user(verified_user_data):
|
||||||
|
assert verified_user_data["is_verified"] is True
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"username": "verifieduser",
|
||||||
|
"email": "verified@example.com",
|
||||||
|
"password": "VerifiedPassword123!",
|
||||||
|
"full_name": "Verified User",
|
||||||
|
"timezone": "UTC",
|
||||||
|
"tax_jurisdiction": "US",
|
||||||
|
"is_verified": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def verified_user(db_session, verified_user_data):
|
||||||
|
"""
|
||||||
|
Create verified test user in database (Issue #3).
|
||||||
|
|
||||||
|
Creates a verified user with timezone and tax jurisdiction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Database session
|
||||||
|
verified_user_data: Verified user data
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
User: Created verified user model instance
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
from tradingagents.api.services.auth_service import hash_password
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username=verified_user_data["username"],
|
||||||
|
email=verified_user_data["email"],
|
||||||
|
hashed_password=hash_password(verified_user_data["password"]),
|
||||||
|
full_name=verified_user_data.get("full_name"),
|
||||||
|
timezone=verified_user_data.get("timezone"),
|
||||||
|
tax_jurisdiction=verified_user_data.get("tax_jurisdiction"),
|
||||||
|
is_verified=verified_user_data.get("is_verified", True),
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
yield user
|
||||||
|
except ImportError:
|
||||||
|
yield None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def user_with_api_key(db_session, test_user_data):
|
||||||
|
"""
|
||||||
|
Create test user with API key in database (Issue #3).
|
||||||
|
|
||||||
|
Creates a user with a hashed API key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session: Database session
|
||||||
|
test_user_data: Test user data
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
tuple[User, str]: (Created user, plain API key)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
from tradingagents.api.services.auth_service import hash_password
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
generate_api_key,
|
||||||
|
hash_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate API key
|
||||||
|
plain_api_key = generate_api_key()
|
||||||
|
hashed_api_key = hash_api_key(plain_api_key)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username=test_user_data["username"],
|
||||||
|
email=test_user_data["email"],
|
||||||
|
hashed_password=hash_password(test_user_data["password"]),
|
||||||
|
full_name=test_user_data.get("full_name"),
|
||||||
|
api_key_hash=hashed_api_key,
|
||||||
|
timezone=test_user_data.get("timezone"),
|
||||||
|
tax_jurisdiction=test_user_data.get("tax_jurisdiction"),
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
yield (user, plain_api_key)
|
||||||
|
except ImportError:
|
||||||
|
yield (None, None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_timezones() -> list[str]:
|
||||||
|
"""
|
||||||
|
List of valid IANA timezones for testing (Issue #3).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: Valid timezone identifiers
|
||||||
|
|
||||||
|
Example:
|
||||||
|
def test_timezones(valid_timezones):
|
||||||
|
for tz in valid_timezones:
|
||||||
|
assert validate_timezone(tz) is True
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"UTC",
|
||||||
|
"GMT",
|
||||||
|
"America/New_York",
|
||||||
|
"America/Los_Angeles",
|
||||||
|
"America/Chicago",
|
||||||
|
"America/Denver",
|
||||||
|
"Europe/London",
|
||||||
|
"Europe/Paris",
|
||||||
|
"Europe/Berlin",
|
||||||
|
"Asia/Tokyo",
|
||||||
|
"Asia/Shanghai",
|
||||||
|
"Asia/Hong_Kong",
|
||||||
|
"Australia/Sydney",
|
||||||
|
"Australia/Melbourne",
|
||||||
|
"Pacific/Auckland",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def invalid_timezones() -> list[str]:
|
||||||
|
"""
|
||||||
|
List of invalid timezones for testing (Issue #3).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: Invalid timezone identifiers
|
||||||
|
|
||||||
|
Example:
|
||||||
|
def test_invalid_timezones(invalid_timezones):
|
||||||
|
for tz in invalid_timezones:
|
||||||
|
assert validate_timezone(tz) is False
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"PST",
|
||||||
|
"EST",
|
||||||
|
"CST",
|
||||||
|
"MST",
|
||||||
|
"America/InvalidCity",
|
||||||
|
"Europe/FakePlace",
|
||||||
|
"Random/Stuff",
|
||||||
|
"america/new_york", # Wrong case
|
||||||
|
"123456",
|
||||||
|
"!@#$%",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_tax_jurisdictions() -> list[str]:
|
||||||
|
"""
|
||||||
|
List of valid tax jurisdictions for testing (Issue #3).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: Valid tax jurisdiction codes
|
||||||
|
|
||||||
|
Example:
|
||||||
|
def test_jurisdictions(valid_tax_jurisdictions):
|
||||||
|
for jurisdiction in valid_tax_jurisdictions:
|
||||||
|
assert validate_tax_jurisdiction(jurisdiction) is True
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"US",
|
||||||
|
"CA",
|
||||||
|
"GB",
|
||||||
|
"DE",
|
||||||
|
"FR",
|
||||||
|
"JP",
|
||||||
|
"AU",
|
||||||
|
"US-CA",
|
||||||
|
"US-NY",
|
||||||
|
"US-TX",
|
||||||
|
"US-FL",
|
||||||
|
"CA-ON",
|
||||||
|
"CA-QC",
|
||||||
|
"CA-BC",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def invalid_tax_jurisdictions() -> list[str]:
|
||||||
|
"""
|
||||||
|
List of invalid tax jurisdictions for testing (Issue #3).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: Invalid tax jurisdiction codes
|
||||||
|
|
||||||
|
Example:
|
||||||
|
def test_invalid_jurisdictions(invalid_tax_jurisdictions):
|
||||||
|
for jurisdiction in invalid_tax_jurisdictions:
|
||||||
|
assert validate_tax_jurisdiction(jurisdiction) is False
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"InvalidFormat",
|
||||||
|
"US_CA", # Wrong separator
|
||||||
|
"US/CA", # Wrong separator
|
||||||
|
"USCA", # No separator
|
||||||
|
"us-ca", # Lowercase
|
||||||
|
"XX-YY", # Invalid country code
|
||||||
|
"123",
|
||||||
|
"!@#",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,587 @@
|
||||||
|
"""
|
||||||
|
Test suite for API Key Service (Issue #3).
|
||||||
|
|
||||||
|
This module tests the API key generation, hashing, and verification service:
|
||||||
|
1. Generate secure random API keys
|
||||||
|
2. Hash API keys using bcrypt
|
||||||
|
3. Verify API keys against hashes
|
||||||
|
4. Key format validation
|
||||||
|
5. Security best practices
|
||||||
|
|
||||||
|
Tests follow TDD - written before implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Unit Tests: API Key Generation
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestApiKeyGeneration:
|
||||||
|
"""Test API key generation functionality."""
|
||||||
|
|
||||||
|
def test_generate_api_key_returns_string(self):
|
||||||
|
"""Test that generate_api_key returns a string."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import generate_api_key
|
||||||
|
|
||||||
|
api_key = generate_api_key()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(api_key, str)
|
||||||
|
assert len(api_key) > 0
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_generate_api_key_format(self):
|
||||||
|
"""Test that generated API key has correct format."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import generate_api_key
|
||||||
|
|
||||||
|
api_key = generate_api_key()
|
||||||
|
|
||||||
|
# Assert: Should be prefixed with "ta_" (TradingAgents)
|
||||||
|
assert api_key.startswith("ta_")
|
||||||
|
|
||||||
|
# Should contain only alphanumeric characters after prefix
|
||||||
|
key_part = api_key[3:] # Remove "ta_" prefix
|
||||||
|
assert re.match(r'^[A-Za-z0-9]+$', key_part)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_generate_api_key_length(self):
|
||||||
|
"""Test that generated API key has sufficient length for security."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import generate_api_key
|
||||||
|
|
||||||
|
api_key = generate_api_key()
|
||||||
|
|
||||||
|
# Assert: Should be at least 32 characters (including prefix)
|
||||||
|
assert len(api_key) >= 32
|
||||||
|
|
||||||
|
# Key part (without prefix) should be at least 29 chars
|
||||||
|
key_part = api_key[3:]
|
||||||
|
assert len(key_part) >= 29
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_generate_api_key_uniqueness(self):
|
||||||
|
"""Test that each generated API key is unique."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import generate_api_key
|
||||||
|
|
||||||
|
keys = [generate_api_key() for _ in range(100)]
|
||||||
|
|
||||||
|
# Assert: All keys should be unique
|
||||||
|
assert len(keys) == len(set(keys))
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_generate_api_key_randomness(self):
|
||||||
|
"""Test that API keys have sufficient randomness."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import generate_api_key
|
||||||
|
|
||||||
|
keys = [generate_api_key() for _ in range(10)]
|
||||||
|
|
||||||
|
# Assert: Keys should not have common patterns
|
||||||
|
# Check that no two keys share same first 10 chars after prefix
|
||||||
|
prefixes = [key[3:13] for key in keys]
|
||||||
|
assert len(prefixes) == len(set(prefixes))
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_generate_api_key_custom_length(self):
|
||||||
|
"""Test generating API key with custom length."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import generate_api_key
|
||||||
|
|
||||||
|
# Try generating key with custom length (if supported)
|
||||||
|
api_key = generate_api_key(length=40)
|
||||||
|
|
||||||
|
# Assert: Should respect custom length
|
||||||
|
assert len(api_key) >= 40 or len(api_key) >= 32 # May have min length
|
||||||
|
except (ImportError, TypeError):
|
||||||
|
# TypeError is ok if length parameter not implemented
|
||||||
|
pytest.skip("Custom length not implemented yet")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Unit Tests: API Key Hashing
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestApiKeyHashing:
|
||||||
|
"""Test API key hashing functionality."""
|
||||||
|
|
||||||
|
def test_hash_api_key_returns_string(self):
|
||||||
|
"""Test that hash_api_key returns a string."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import hash_api_key
|
||||||
|
|
||||||
|
api_key = "ta_test1234567890abcdefghijklmnop"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(hashed, str)
|
||||||
|
assert len(hashed) > 0
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_hash_api_key_bcrypt_format(self):
|
||||||
|
"""Test that hashed API key uses bcrypt format."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import hash_api_key
|
||||||
|
|
||||||
|
api_key = "ta_test1234567890abcdefghijklmnop"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
# Assert: Should be bcrypt format ($2b$rounds$salt+hash)
|
||||||
|
assert hashed.startswith("$2b$") or hashed.startswith("$2a$")
|
||||||
|
assert len(hashed) == 60 # Standard bcrypt hash length
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_hash_api_key_different_for_same_input(self):
|
||||||
|
"""Test that hashing same key twice produces different hashes (salt)."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import hash_api_key
|
||||||
|
|
||||||
|
api_key = "ta_test1234567890abcdefghijklmnop"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
hash1 = hash_api_key(api_key)
|
||||||
|
hash2 = hash_api_key(api_key)
|
||||||
|
|
||||||
|
# Assert: Should be different due to different salts
|
||||||
|
assert hash1 != hash2
|
||||||
|
assert hash1.startswith("$2b$")
|
||||||
|
assert hash2.startswith("$2b$")
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_hash_api_key_different_keys_different_hashes(self):
|
||||||
|
"""Test that different keys produce different hashes."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import hash_api_key
|
||||||
|
|
||||||
|
key1 = "ta_key1234567890abcdefghijklmnop"
|
||||||
|
key2 = "ta_key9876543210zyxwvutsrqponmlk"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
hash1 = hash_api_key(key1)
|
||||||
|
hash2 = hash_api_key(key2)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert hash1 != hash2
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_hash_api_key_empty_string(self):
|
||||||
|
"""Test hashing empty string."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import hash_api_key
|
||||||
|
|
||||||
|
# Act & Assert: Should handle gracefully or raise ValueError
|
||||||
|
try:
|
||||||
|
hashed = hash_api_key("")
|
||||||
|
# If it doesn't raise, should still return valid hash
|
||||||
|
assert isinstance(hashed, str)
|
||||||
|
except ValueError:
|
||||||
|
# Acceptable to raise ValueError for empty key
|
||||||
|
pass
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_hash_api_key_special_characters(self):
|
||||||
|
"""Test hashing API key with special characters."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import hash_api_key
|
||||||
|
|
||||||
|
api_key = "ta_key!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
# Assert: Should handle special characters
|
||||||
|
assert isinstance(hashed, str)
|
||||||
|
assert hashed.startswith("$2b$") or hashed.startswith("$2a$")
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Unit Tests: API Key Verification
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestApiKeyVerification:
|
||||||
|
"""Test API key verification functionality."""
|
||||||
|
|
||||||
|
def test_verify_api_key_correct_key(self):
|
||||||
|
"""Test verifying API key with correct hash."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key = "ta_correct1234567890abcdefghijk"
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
is_valid = verify_api_key(api_key, hashed)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert is_valid is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_verify_api_key_incorrect_key(self):
|
||||||
|
"""Test verifying API key with wrong hash."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
correct_key = "ta_correct1234567890abcdefghijk"
|
||||||
|
wrong_key = "ta_wrongkey1234567890abcdefghij"
|
||||||
|
hashed = hash_api_key(correct_key)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
is_valid = verify_api_key(wrong_key, hashed)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert is_valid is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_verify_api_key_empty_key(self):
|
||||||
|
"""Test verifying empty API key."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key = "ta_test1234567890abcdefghijklmn"
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
is_valid = verify_api_key("", hashed)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert is_valid is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_verify_api_key_none_key(self):
|
||||||
|
"""Test verifying None API key."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key = "ta_test1234567890abcdefghijklmn"
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
# Act & Assert: Should return False or raise TypeError
|
||||||
|
try:
|
||||||
|
is_valid = verify_api_key(None, hashed)
|
||||||
|
assert is_valid is False
|
||||||
|
except TypeError:
|
||||||
|
# Acceptable to raise TypeError for None
|
||||||
|
pass
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_verify_api_key_invalid_hash(self):
|
||||||
|
"""Test verifying API key against invalid hash."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import verify_api_key
|
||||||
|
|
||||||
|
api_key = "ta_test1234567890abcdefghijklmn"
|
||||||
|
invalid_hash = "not-a-valid-bcrypt-hash"
|
||||||
|
|
||||||
|
# Act & Assert: Should return False or raise ValueError
|
||||||
|
try:
|
||||||
|
is_valid = verify_api_key(api_key, invalid_hash)
|
||||||
|
assert is_valid is False
|
||||||
|
except ValueError:
|
||||||
|
# Acceptable to raise ValueError for invalid hash
|
||||||
|
pass
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_verify_api_key_case_sensitive(self):
|
||||||
|
"""Test that API key verification is case-sensitive."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key = "ta_TestKey1234567890ABCDEFGHIJK"
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
is_valid_correct = verify_api_key(api_key, hashed)
|
||||||
|
is_valid_wrong_case = verify_api_key(api_key.lower(), hashed)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert is_valid_correct is True
|
||||||
|
assert is_valid_wrong_case is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_verify_api_key_similar_keys(self):
|
||||||
|
"""Test that similar keys don't validate."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key = "ta_key1234567890abcdefghijklmnop"
|
||||||
|
similar_key = "ta_key1234567890abcdefghijklmnox" # Last char different
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
is_valid = verify_api_key(similar_key, hashed)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert is_valid is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Integration Tests: Full API Key Lifecycle
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestApiKeyLifecycle:
|
||||||
|
"""Test complete API key generation, hashing, and verification workflow."""
|
||||||
|
|
||||||
|
def test_full_api_key_lifecycle(self):
|
||||||
|
"""Test complete lifecycle: generate -> hash -> verify."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
generate_api_key,
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate new API key
|
||||||
|
api_key = generate_api_key()
|
||||||
|
|
||||||
|
# Hash the API key
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
# Verify with correct key
|
||||||
|
is_valid = verify_api_key(api_key, hashed)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(api_key, str)
|
||||||
|
assert api_key.startswith("ta_")
|
||||||
|
assert isinstance(hashed, str)
|
||||||
|
assert hashed.startswith("$2b$") or hashed.startswith("$2a$")
|
||||||
|
assert is_valid is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_multiple_keys_independent(self):
|
||||||
|
"""Test that multiple API keys can coexist independently."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
generate_api_key,
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate multiple keys
|
||||||
|
key1 = generate_api_key()
|
||||||
|
key2 = generate_api_key()
|
||||||
|
key3 = generate_api_key()
|
||||||
|
|
||||||
|
# Hash each key
|
||||||
|
hash1 = hash_api_key(key1)
|
||||||
|
hash2 = hash_api_key(key2)
|
||||||
|
hash3 = hash_api_key(key3)
|
||||||
|
|
||||||
|
# Assert: Each key only verifies against its own hash
|
||||||
|
assert verify_api_key(key1, hash1) is True
|
||||||
|
assert verify_api_key(key1, hash2) is False
|
||||||
|
assert verify_api_key(key1, hash3) is False
|
||||||
|
|
||||||
|
assert verify_api_key(key2, hash1) is False
|
||||||
|
assert verify_api_key(key2, hash2) is True
|
||||||
|
assert verify_api_key(key2, hash3) is False
|
||||||
|
|
||||||
|
assert verify_api_key(key3, hash1) is False
|
||||||
|
assert verify_api_key(key3, hash2) is False
|
||||||
|
assert verify_api_key(key3, hash3) is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_regenerate_key_invalidates_old_hash(self):
|
||||||
|
"""Test that regenerating a key invalidates the old hash."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
generate_api_key,
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate and hash first key
|
||||||
|
old_key = generate_api_key()
|
||||||
|
old_hash = hash_api_key(old_key)
|
||||||
|
|
||||||
|
# Generate new key (simulate regeneration)
|
||||||
|
new_key = generate_api_key()
|
||||||
|
|
||||||
|
# Assert: Old hash should not work with new key
|
||||||
|
assert verify_api_key(old_key, old_hash) is True
|
||||||
|
assert verify_api_key(new_key, old_hash) is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Edge Cases: API Key Service
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestApiKeyEdgeCases:
|
||||||
|
"""Test edge cases in API key service."""
|
||||||
|
|
||||||
|
def test_hash_very_long_key(self):
|
||||||
|
"""Test hashing very long API key."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create very long key (bcrypt has 72 byte limit)
|
||||||
|
long_key = "ta_" + "a" * 200
|
||||||
|
|
||||||
|
# Act
|
||||||
|
hashed = hash_api_key(long_key)
|
||||||
|
is_valid = verify_api_key(long_key, hashed)
|
||||||
|
|
||||||
|
# Assert: Should handle gracefully (may truncate to 72 bytes)
|
||||||
|
assert isinstance(hashed, str)
|
||||||
|
# Bcrypt will only use first ~72 bytes, so verification should work
|
||||||
|
assert is_valid is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_hash_unicode_characters(self):
|
||||||
|
"""Test hashing API key with Unicode characters."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Key with Unicode characters
|
||||||
|
unicode_key = "ta_测试key_🔑_αβγ"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
hashed = hash_api_key(unicode_key)
|
||||||
|
is_valid = verify_api_key(unicode_key, hashed)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(hashed, str)
|
||||||
|
assert is_valid is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_timing_attack_resistance(self):
|
||||||
|
"""Test that verification takes similar time for valid/invalid keys."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
generate_api_key,
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
import time
|
||||||
|
|
||||||
|
api_key = generate_api_key()
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
wrong_key = generate_api_key()
|
||||||
|
|
||||||
|
# Act: Measure time for correct and incorrect verification
|
||||||
|
times_correct = []
|
||||||
|
times_incorrect = []
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
start = time.perf_counter()
|
||||||
|
verify_api_key(api_key, hashed)
|
||||||
|
times_correct.append(time.perf_counter() - start)
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
verify_api_key(wrong_key, hashed)
|
||||||
|
times_incorrect.append(time.perf_counter() - start)
|
||||||
|
|
||||||
|
# Assert: Times should be similar (within same order of magnitude)
|
||||||
|
# This is a basic check - bcrypt is inherently resistant to timing attacks
|
||||||
|
avg_correct = sum(times_correct) / len(times_correct)
|
||||||
|
avg_incorrect = sum(times_incorrect) / len(times_incorrect)
|
||||||
|
|
||||||
|
# Both should take similar time (bcrypt always does full comparison)
|
||||||
|
assert avg_correct > 0
|
||||||
|
assert avg_incorrect > 0
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
||||||
|
def test_concurrent_key_generation(self):
|
||||||
|
"""Test generating API keys concurrently."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.api_key_service import generate_api_key
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
# Generate keys concurrently
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||||
|
futures = [executor.submit(generate_api_key) for _ in range(100)]
|
||||||
|
keys = [f.result() for f in futures]
|
||||||
|
|
||||||
|
# Assert: All keys should be unique
|
||||||
|
assert len(keys) == len(set(keys))
|
||||||
|
assert all(k.startswith("ta_") for k in keys)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("API key service not implemented yet")
|
||||||
|
|
@ -1,13 +1,14 @@
|
||||||
"""
|
"""
|
||||||
Test suite for SQLAlchemy database models.
|
Test suite for SQLAlchemy database models.
|
||||||
|
|
||||||
This module tests Issue #48 database models:
|
This module tests Issue #48 and Issue #3 database models:
|
||||||
1. User model with hashed passwords
|
1. User model with hashed passwords
|
||||||
2. Strategy model with JSON parameters
|
2. User model with tax_jurisdiction, timezone, api_key_hash, is_verified (Issue #3)
|
||||||
3. Relationships (User -> Strategies)
|
3. Strategy model with JSON parameters
|
||||||
4. Model validation and constraints
|
4. Relationships (User -> Strategies)
|
||||||
5. Timestamps (created_at, updated_at)
|
5. Model validation and constraints
|
||||||
6. Cascade delete behavior
|
6. Timestamps (created_at, updated_at)
|
||||||
|
7. Cascade delete behavior
|
||||||
|
|
||||||
Tests follow TDD - written before implementation.
|
Tests follow TDD - written before implementation.
|
||||||
"""
|
"""
|
||||||
|
|
@ -804,3 +805,451 @@ class TestModelEdgeCases:
|
||||||
assert strategy.parameters["l1"]["l2"]["l3"]["l4"]["l5"]["value"] == "deep"
|
assert strategy.parameters["l1"]["l2"]["l3"]["l4"]["l5"]["value"] == "deep"
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pytest.skip("Models not implemented yet")
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Unit Tests: Issue #3 - User Model Enhancements
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestUserModelIssue3:
|
||||||
|
"""Test User model enhancements from Issue #3.
|
||||||
|
|
||||||
|
New fields tested:
|
||||||
|
- tax_jurisdiction: Nullable string for user's tax jurisdiction
|
||||||
|
- timezone: Nullable string for user's timezone (must be valid IANA timezone)
|
||||||
|
- api_key_hash: Nullable string for hashed API key
|
||||||
|
- is_verified: Boolean flag for email verification (defaults to False)
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def test_user_tax_jurisdiction_default_none(self, db_session):
|
||||||
|
"""Test that tax_jurisdiction defaults to None."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="taxuser",
|
||||||
|
email="tax@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert hasattr(user, "tax_jurisdiction")
|
||||||
|
assert user.tax_jurisdiction is None
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_tax_jurisdiction_custom_value(self, db_session):
|
||||||
|
"""Test setting custom tax_jurisdiction value."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="taxuser2",
|
||||||
|
email="tax2@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
tax_jurisdiction="US-CA",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.tax_jurisdiction == "US-CA"
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_timezone_default_none(self, db_session):
|
||||||
|
"""Test that timezone defaults to None."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="tzuser",
|
||||||
|
email="tz@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert hasattr(user, "timezone")
|
||||||
|
assert user.timezone is None
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_timezone_valid_value(self, db_session):
|
||||||
|
"""Test setting valid timezone value."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="tzuser2",
|
||||||
|
email="tz2@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
timezone="America/New_York",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.timezone == "America/New_York"
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_timezone_various_timezones(self, db_session):
|
||||||
|
"""Test various valid IANA timezone values."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
timezones = [
|
||||||
|
"UTC",
|
||||||
|
"America/Los_Angeles",
|
||||||
|
"Europe/London",
|
||||||
|
"Asia/Tokyo",
|
||||||
|
"Australia/Sydney",
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, tz in enumerate(timezones):
|
||||||
|
user = User(
|
||||||
|
username=f"tzuser_{i}",
|
||||||
|
email=f"tz{i}@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
timezone=tz,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.timezone == tz
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_api_key_hash_default_none(self, db_session):
|
||||||
|
"""Test that api_key_hash defaults to None."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="apiuser",
|
||||||
|
email="api@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert hasattr(user, "api_key_hash")
|
||||||
|
assert user.api_key_hash is None
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_api_key_hash_custom_value(self, db_session):
|
||||||
|
"""Test setting api_key_hash value."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
# Simulate hashed API key (bcrypt hash format)
|
||||||
|
api_key_hash = "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/LewY5ygOy3f3K0E6O"
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="apiuser2",
|
||||||
|
email="api2@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
api_key_hash=api_key_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.api_key_hash == api_key_hash
|
||||||
|
assert user.api_key_hash.startswith("$2b$")
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_is_verified_default_false(self, db_session):
|
||||||
|
"""Test that is_verified defaults to False."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="verifyuser",
|
||||||
|
email="verify@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert hasattr(user, "is_verified")
|
||||||
|
assert user.is_verified is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_is_verified_can_be_true(self, db_session):
|
||||||
|
"""Test setting is_verified to True."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="verifyuser2",
|
||||||
|
email="verify2@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
is_verified=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.is_verified is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_all_new_fields_together(self, db_session):
|
||||||
|
"""Test creating user with all Issue #3 fields."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="fulluser",
|
||||||
|
email="full@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
tax_jurisdiction="US-NY",
|
||||||
|
timezone="America/New_York",
|
||||||
|
api_key_hash="$2b$12$hashedapikey123",
|
||||||
|
is_verified=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.tax_jurisdiction == "US-NY"
|
||||||
|
assert user.timezone == "America/New_York"
|
||||||
|
assert user.api_key_hash == "$2b$12$hashedapikey123"
|
||||||
|
assert user.is_verified is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_update_timezone(self, db_session):
|
||||||
|
"""Test updating user's timezone."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="updatetz",
|
||||||
|
email="updatetz@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
timezone="UTC",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Act: Update timezone
|
||||||
|
user.timezone = "America/Los_Angeles"
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.timezone == "America/Los_Angeles"
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_update_tax_jurisdiction(self, db_session):
|
||||||
|
"""Test updating user's tax_jurisdiction."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="updatetax",
|
||||||
|
email="updatetax@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
tax_jurisdiction="US-CA",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Act: Update tax jurisdiction
|
||||||
|
user.tax_jurisdiction = "US-TX"
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.tax_jurisdiction == "US-TX"
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_verify_email(self, db_session):
|
||||||
|
"""Test verifying user's email."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="toverify",
|
||||||
|
email="toverify@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
is_verified=False,
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.is_verified is False
|
||||||
|
|
||||||
|
# Act: Verify email
|
||||||
|
user.is_verified = True
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.is_verified is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_query_by_timezone(self, db_session):
|
||||||
|
"""Test querying users by timezone."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
# Create users with different timezones
|
||||||
|
user1 = User(
|
||||||
|
username="user_utc",
|
||||||
|
email="utc@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
timezone="UTC",
|
||||||
|
)
|
||||||
|
user2 = User(
|
||||||
|
username="user_ny",
|
||||||
|
email="ny@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
timezone="America/New_York",
|
||||||
|
)
|
||||||
|
user3 = User(
|
||||||
|
username="user_utc2",
|
||||||
|
email="utc2@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
timezone="UTC",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add_all([user1, user2, user3])
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Act: Query users in UTC timezone
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(User).where(User.timezone == "UTC")
|
||||||
|
)
|
||||||
|
utc_users = result.scalars().all()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(utc_users) == 2
|
||||||
|
assert all(u.timezone == "UTC" for u in utc_users)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_query_verified_users(self, db_session):
|
||||||
|
"""Test querying only verified users."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
# Create verified and unverified users
|
||||||
|
verified_user = User(
|
||||||
|
username="verified",
|
||||||
|
email="verified@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
is_verified=True,
|
||||||
|
)
|
||||||
|
unverified_user = User(
|
||||||
|
username="unverified",
|
||||||
|
email="unverified@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
is_verified=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add_all([verified_user, unverified_user])
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Act: Query verified users
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(User).where(User.is_verified == True)
|
||||||
|
)
|
||||||
|
verified_users = result.scalars().all()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(verified_users) >= 1
|
||||||
|
assert all(u.is_verified for u in verified_users)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
||||||
|
async def test_user_api_key_hash_nullable(self, db_session):
|
||||||
|
"""Test that api_key_hash can be set to None."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="apinull",
|
||||||
|
email="apinull@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
api_key_hash="$2b$12$somehash",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Act: Remove API key
|
||||||
|
user.api_key_hash = None
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.api_key_hash is None
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models not implemented yet")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,604 @@
|
||||||
|
"""Unit tests for User model with Issue #3 enhancements.
|
||||||
|
|
||||||
|
Tests for User model fields including:
|
||||||
|
- tax_jurisdiction
|
||||||
|
- timezone
|
||||||
|
- api_key_hash
|
||||||
|
- is_verified
|
||||||
|
|
||||||
|
Follows TDD principles with comprehensive coverage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
from tradingagents.api.models.user import User
|
||||||
|
from tradingagents.api.services.auth_service import hash_password
|
||||||
|
from tradingagents.api.services.api_key_service import generate_api_key, hash_api_key
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserModelBasicFields:
|
||||||
|
"""Tests for basic User model fields."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_user_with_required_fields(self, db_session):
|
||||||
|
"""Should create user with only required fields."""
|
||||||
|
user = User(
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
hashed_password=hash_password("SecurePassword123!"),
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.id is not None
|
||||||
|
assert user.username == "testuser"
|
||||||
|
assert user.email == "test@example.com"
|
||||||
|
assert user.hashed_password is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_defaults(self, db_session):
|
||||||
|
"""Should apply default values to optional fields."""
|
||||||
|
user = User(
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
hashed_password=hash_password("SecurePassword123!"),
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Check defaults
|
||||||
|
assert user.is_active is True
|
||||||
|
assert user.is_superuser is False
|
||||||
|
assert user.tax_jurisdiction == "AU"
|
||||||
|
assert user.timezone == "Australia/Sydney"
|
||||||
|
assert user.is_verified is False
|
||||||
|
assert user.api_key_hash is None
|
||||||
|
assert user.full_name is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_username_unique_constraint(self, db_session):
|
||||||
|
"""Should enforce unique username constraint."""
|
||||||
|
user1 = User(
|
||||||
|
username="testuser",
|
||||||
|
email="test1@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
)
|
||||||
|
db_session.add(user1)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Try to create user with same username
|
||||||
|
user2 = User(
|
||||||
|
username="testuser",
|
||||||
|
email="test2@example.com",
|
||||||
|
hashed_password=hash_password("Password456!"),
|
||||||
|
)
|
||||||
|
db_session.add(user2)
|
||||||
|
|
||||||
|
with pytest.raises(Exception): # IntegrityError
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_email_unique_constraint(self, db_session):
|
||||||
|
"""Should enforce unique email constraint."""
|
||||||
|
user1 = User(
|
||||||
|
username="user1",
|
||||||
|
email="test@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
)
|
||||||
|
db_session.add(user1)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Try to create user with same email
|
||||||
|
user2 = User(
|
||||||
|
username="user2",
|
||||||
|
email="test@example.com",
|
||||||
|
hashed_password=hash_password("Password456!"),
|
||||||
|
)
|
||||||
|
db_session.add(user2)
|
||||||
|
|
||||||
|
with pytest.raises(Exception): # IntegrityError
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserModelTaxJurisdiction:
|
||||||
|
"""Tests for tax_jurisdiction field (Issue #3)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_us_jurisdiction(self, db_session):
|
||||||
|
"""Should set US tax jurisdiction."""
|
||||||
|
user = User(
|
||||||
|
username="ususer",
|
||||||
|
email="us@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
tax_jurisdiction="US",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.tax_jurisdiction == "US"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_us_state_jurisdiction(self, db_session):
|
||||||
|
"""Should set US state-level tax jurisdiction."""
|
||||||
|
user = User(
|
||||||
|
username="nyuser",
|
||||||
|
email="ny@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
tax_jurisdiction="US-NY",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.tax_jurisdiction == "US-NY"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_canadian_province_jurisdiction(self, db_session):
|
||||||
|
"""Should set Canadian province-level tax jurisdiction."""
|
||||||
|
user = User(
|
||||||
|
username="causer",
|
||||||
|
email="ca@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
tax_jurisdiction="CA-ON",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.tax_jurisdiction == "CA-ON"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_australian_state_jurisdiction(self, db_session):
|
||||||
|
"""Should set Australian state-level tax jurisdiction."""
|
||||||
|
user = User(
|
||||||
|
username="auuser",
|
||||||
|
email="au@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
tax_jurisdiction="AU-NSW",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.tax_jurisdiction == "AU-NSW"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tax_jurisdiction_default(self, db_session):
|
||||||
|
"""Should default to AU if not specified."""
|
||||||
|
user = User(
|
||||||
|
username="defaultuser",
|
||||||
|
email="default@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.tax_jurisdiction == "AU"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tax_jurisdiction_max_length(self, db_session):
|
||||||
|
"""Tax jurisdiction should not exceed 10 characters."""
|
||||||
|
# This should work (10 chars max)
|
||||||
|
user = User(
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
tax_jurisdiction="AU-NSW", # 6 chars
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.tax_jurisdiction == "AU-NSW"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserModelTimezone:
|
||||||
|
"""Tests for timezone field (Issue #3)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_us_timezone(self, db_session):
|
||||||
|
"""Should set US timezone."""
|
||||||
|
user = User(
|
||||||
|
username="nyuser",
|
||||||
|
email="ny@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
timezone="America/New_York",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.timezone == "America/New_York"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_utc_timezone(self, db_session):
|
||||||
|
"""Should set UTC timezone."""
|
||||||
|
user = User(
|
||||||
|
username="utcuser",
|
||||||
|
email="utc@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
timezone="UTC",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.timezone == "UTC"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_european_timezone(self, db_session):
|
||||||
|
"""Should set European timezone."""
|
||||||
|
user = User(
|
||||||
|
username="londonuser",
|
||||||
|
email="london@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
timezone="Europe/London",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.timezone == "Europe/London"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_asian_timezone(self, db_session):
|
||||||
|
"""Should set Asian timezone."""
|
||||||
|
user = User(
|
||||||
|
username="tokyouser",
|
||||||
|
email="tokyo@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
timezone="Asia/Tokyo",
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.timezone == "Asia/Tokyo"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timezone_default(self, db_session):
|
||||||
|
"""Should default to Australia/Sydney if not specified."""
|
||||||
|
user = User(
|
||||||
|
username="defaultuser",
|
||||||
|
email="default@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.timezone == "Australia/Sydney"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timezone_max_length(self, db_session):
|
||||||
|
"""Timezone should not exceed 50 characters."""
|
||||||
|
# Longest IANA timezone is ~40 characters
|
||||||
|
user = User(
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
timezone="America/Argentina/ComodRivadavia", # 35 chars
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.timezone == "America/Argentina/ComodRivadavia"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserModelApiKey:
|
||||||
|
"""Tests for api_key_hash field (Issue #3)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_without_api_key(self, db_session):
|
||||||
|
"""User without API key should have None api_key_hash."""
|
||||||
|
user = User(
|
||||||
|
username="nokey",
|
||||||
|
email="nokey@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.api_key_hash is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_with_api_key(self, db_session):
|
||||||
|
"""User with API key should store hashed key."""
|
||||||
|
api_key = generate_api_key()
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="withkey",
|
||||||
|
email="withkey@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
api_key_hash=hashed,
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.api_key_hash is not None
|
||||||
|
assert user.api_key_hash == hashed
|
||||||
|
assert user.api_key_hash != api_key # Should be hash, not plain key
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_key_hash_unique_constraint(self, db_session):
|
||||||
|
"""Should enforce unique api_key_hash constraint."""
|
||||||
|
api_key = generate_api_key()
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
user1 = User(
|
||||||
|
username="user1",
|
||||||
|
email="user1@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
api_key_hash=hashed,
|
||||||
|
)
|
||||||
|
db_session.add(user1)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Try to create user with same api_key_hash
|
||||||
|
user2 = User(
|
||||||
|
username="user2",
|
||||||
|
email="user2@example.com",
|
||||||
|
hashed_password=hash_password("Password456!"),
|
||||||
|
api_key_hash=hashed,
|
||||||
|
)
|
||||||
|
db_session.add(user2)
|
||||||
|
|
||||||
|
with pytest.raises(Exception): # IntegrityError
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_key_hash_indexed(self, db_session):
|
||||||
|
"""api_key_hash should be indexed for fast lookups."""
|
||||||
|
# Create users with API keys
|
||||||
|
for i in range(10):
|
||||||
|
api_key = generate_api_key()
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username=f"user{i}",
|
||||||
|
email=f"user{i}@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
api_key_hash=hashed,
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Lookup by api_key_hash should work
|
||||||
|
api_key = generate_api_key()
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="lookup",
|
||||||
|
email="lookup@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
api_key_hash=hashed,
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Query by api_key_hash
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(User).where(User.api_key_hash == hashed)
|
||||||
|
)
|
||||||
|
found_user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
assert found_user is not None
|
||||||
|
assert found_user.username == "lookup"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_api_key(self, db_session):
|
||||||
|
"""User should be able to regenerate their API key."""
|
||||||
|
# Create user with API key
|
||||||
|
old_api_key = generate_api_key()
|
||||||
|
old_hash = hash_api_key(old_api_key)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="regenerate",
|
||||||
|
email="regenerate@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
api_key_hash=old_hash,
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Regenerate API key
|
||||||
|
new_api_key = generate_api_key()
|
||||||
|
new_hash = hash_api_key(new_api_key)
|
||||||
|
|
||||||
|
user.api_key_hash = new_hash
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.api_key_hash == new_hash
|
||||||
|
assert user.api_key_hash != old_hash
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_api_key(self, db_session):
|
||||||
|
"""User should be able to revoke their API key."""
|
||||||
|
# Create user with API key
|
||||||
|
api_key = generate_api_key()
|
||||||
|
hashed = hash_api_key(api_key)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="revoke",
|
||||||
|
email="revoke@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
api_key_hash=hashed,
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Revoke API key (set to None)
|
||||||
|
user.api_key_hash = None
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.api_key_hash is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserModelIsVerified:
|
||||||
|
"""Tests for is_verified field (Issue #3)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_unverified_by_default(self, db_session):
|
||||||
|
"""New users should be unverified by default."""
|
||||||
|
user = User(
|
||||||
|
username="unverified",
|
||||||
|
email="unverified@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.is_verified is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_verified_user(self, db_session):
|
||||||
|
"""Should be able to create verified user."""
|
||||||
|
user = User(
|
||||||
|
username="verified",
|
||||||
|
email="verified@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
is_verified=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.is_verified is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_user(self, db_session):
|
||||||
|
"""Should be able to verify user after creation."""
|
||||||
|
# Create unverified user
|
||||||
|
user = User(
|
||||||
|
username="toverify",
|
||||||
|
email="toverify@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
is_verified=False,
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.is_verified is False
|
||||||
|
|
||||||
|
# Verify user
|
||||||
|
user.is_verified = True
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.is_verified is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_verified_users(self, db_session):
|
||||||
|
"""Should be able to query only verified users."""
|
||||||
|
# Create mix of verified and unverified users
|
||||||
|
for i in range(5):
|
||||||
|
user = User(
|
||||||
|
username=f"user{i}",
|
||||||
|
email=f"user{i}@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
is_verified=(i % 2 == 0), # Alternate verified/unverified
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Query only verified users
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(User).where(User.is_verified == True)
|
||||||
|
)
|
||||||
|
verified_users = result.scalars().all()
|
||||||
|
|
||||||
|
assert len(verified_users) == 3 # users 0, 2, 4
|
||||||
|
for user in verified_users:
|
||||||
|
assert user.is_verified is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserModelComplete:
|
||||||
|
"""Tests for complete user creation with all Issue #3 fields."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_complete_user(self, db_session):
|
||||||
|
"""Should create user with all fields including Issue #3 additions."""
|
||||||
|
api_key = generate_api_key()
|
||||||
|
hashed_api_key = hash_api_key(api_key)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="complete",
|
||||||
|
email="complete@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
full_name="Complete User",
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
tax_jurisdiction="US-NY",
|
||||||
|
timezone="America/New_York",
|
||||||
|
api_key_hash=hashed_api_key,
|
||||||
|
is_verified=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
assert user.id is not None
|
||||||
|
assert user.username == "complete"
|
||||||
|
assert user.email == "complete@example.com"
|
||||||
|
assert user.full_name == "Complete User"
|
||||||
|
assert user.is_active is True
|
||||||
|
assert user.is_superuser is False
|
||||||
|
assert user.tax_jurisdiction == "US-NY"
|
||||||
|
assert user.timezone == "America/New_York"
|
||||||
|
assert user.api_key_hash == hashed_api_key
|
||||||
|
assert user.is_verified is True
|
||||||
|
assert user.created_at is not None
|
||||||
|
assert user.updated_at is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_repr(self, db_session):
|
||||||
|
"""Should have meaningful string representation."""
|
||||||
|
user = User(
|
||||||
|
username="reprtest",
|
||||||
|
email="repr@example.com",
|
||||||
|
hashed_password=hash_password("Password123!"),
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
repr_str = repr(user)
|
||||||
|
assert "reprtest" in repr_str
|
||||||
|
assert "repr@example.com" in repr_str
|
||||||
|
assert str(user.id) in repr_str
|
||||||
|
|
@ -0,0 +1,678 @@
|
||||||
|
"""
|
||||||
|
Test suite for validators (Issue #3).
|
||||||
|
|
||||||
|
This module tests timezone and tax jurisdiction validation:
|
||||||
|
1. Timezone validation (IANA timezone database)
|
||||||
|
2. Tax jurisdiction validation (format and valid codes)
|
||||||
|
3. Edge cases and error handling
|
||||||
|
4. Integration with User model
|
||||||
|
|
||||||
|
Tests follow TDD - written before implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Unit Tests: Timezone Validation
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestTimezoneValidation:
|
||||||
|
"""Test timezone validation functionality."""
|
||||||
|
|
||||||
|
def test_validate_timezone_valid_utc(self):
|
||||||
|
"""Test validating UTC timezone."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
result = validate_timezone("UTC")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_timezone_valid_america_new_york(self):
|
||||||
|
"""Test validating America/New_York timezone."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
result = validate_timezone("America/New_York")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_timezone_valid_europe_london(self):
|
||||||
|
"""Test validating Europe/London timezone."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
result = validate_timezone("Europe/London")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_timezone_valid_asia_tokyo(self):
|
||||||
|
"""Test validating Asia/Tokyo timezone."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
result = validate_timezone("Asia/Tokyo")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_timezone_valid_australia_sydney(self):
|
||||||
|
"""Test validating Australia/Sydney timezone."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
result = validate_timezone("Australia/Sydney")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_timezone_invalid_fake_timezone(self):
|
||||||
|
"""Test rejecting invalid timezone."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
result = validate_timezone("Invalid/Timezone")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_timezone_invalid_empty_string(self):
|
||||||
|
"""Test rejecting empty string."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
result = validate_timezone("")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_timezone_none(self):
|
||||||
|
"""Test handling None value."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
# None should be accepted (nullable field)
|
||||||
|
result = validate_timezone(None)
|
||||||
|
|
||||||
|
# Assert: None is valid (nullable)
|
||||||
|
assert result is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_timezone_case_sensitive(self):
|
||||||
|
"""Test that timezone validation is case-sensitive."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
# Correct case
|
||||||
|
result_correct = validate_timezone("America/New_York")
|
||||||
|
|
||||||
|
# Wrong case
|
||||||
|
result_wrong = validate_timezone("america/new_york")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result_correct is True
|
||||||
|
assert result_wrong is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_timezone_various_valid_timezones(self):
|
||||||
|
"""Test validating various valid IANA timezones."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
valid_timezones = [
|
||||||
|
"UTC",
|
||||||
|
"GMT",
|
||||||
|
"America/New_York",
|
||||||
|
"America/Los_Angeles",
|
||||||
|
"America/Chicago",
|
||||||
|
"America/Denver",
|
||||||
|
"Europe/London",
|
||||||
|
"Europe/Paris",
|
||||||
|
"Europe/Berlin",
|
||||||
|
"Asia/Tokyo",
|
||||||
|
"Asia/Shanghai",
|
||||||
|
"Asia/Hong_Kong",
|
||||||
|
"Australia/Sydney",
|
||||||
|
"Australia/Melbourne",
|
||||||
|
"Pacific/Auckland",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
for tz in valid_timezones:
|
||||||
|
result = validate_timezone(tz)
|
||||||
|
assert result is True, f"Timezone {tz} should be valid"
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_timezone_various_invalid_timezones(self):
|
||||||
|
"""Test rejecting various invalid timezones."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
invalid_timezones = [
|
||||||
|
"PST", # Abbreviations not valid
|
||||||
|
"EST",
|
||||||
|
"CST",
|
||||||
|
"MST",
|
||||||
|
"America/InvalidCity",
|
||||||
|
"Europe/FakePlace",
|
||||||
|
"Random/Stuff",
|
||||||
|
"123456",
|
||||||
|
"!@#$%",
|
||||||
|
"america/new_york", # Wrong case
|
||||||
|
]
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
for tz in invalid_timezones:
|
||||||
|
result = validate_timezone(tz)
|
||||||
|
assert result is False, f"Timezone {tz} should be invalid"
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_timezone_with_underscores(self):
|
||||||
|
"""Test timezones with underscores in city names."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
# Valid timezones with underscores
|
||||||
|
result1 = validate_timezone("America/New_York")
|
||||||
|
result2 = validate_timezone("America/Los_Angeles")
|
||||||
|
result3 = validate_timezone("America/Port-au-Prince")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result1 is True
|
||||||
|
assert result2 is True
|
||||||
|
assert result3 is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Unit Tests: Tax Jurisdiction Validation
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestTaxJurisdictionValidation:
|
||||||
|
"""Test tax jurisdiction validation functionality."""
|
||||||
|
|
||||||
|
def test_validate_tax_jurisdiction_valid_us_state(self):
|
||||||
|
"""Test validating US state tax jurisdiction."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
result = validate_tax_jurisdiction("US-CA")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_tax_jurisdiction_valid_us_states(self):
|
||||||
|
"""Test validating various US state tax jurisdictions."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
valid_jurisdictions = [
|
||||||
|
"US-CA", # California
|
||||||
|
"US-NY", # New York
|
||||||
|
"US-TX", # Texas
|
||||||
|
"US-FL", # Florida
|
||||||
|
"US-IL", # Illinois
|
||||||
|
"US-PA", # Pennsylvania
|
||||||
|
"US-OH", # Ohio
|
||||||
|
"US-WA", # Washington
|
||||||
|
]
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
for jurisdiction in valid_jurisdictions:
|
||||||
|
result = validate_tax_jurisdiction(jurisdiction)
|
||||||
|
assert result is True, f"Jurisdiction {jurisdiction} should be valid"
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_tax_jurisdiction_valid_countries(self):
|
||||||
|
"""Test validating country-level tax jurisdictions."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
valid_jurisdictions = [
|
||||||
|
"US", # United States
|
||||||
|
"CA", # Canada
|
||||||
|
"GB", # United Kingdom
|
||||||
|
"DE", # Germany
|
||||||
|
"FR", # France
|
||||||
|
"JP", # Japan
|
||||||
|
"AU", # Australia
|
||||||
|
"NZ", # New Zealand
|
||||||
|
]
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
for jurisdiction in valid_jurisdictions:
|
||||||
|
result = validate_tax_jurisdiction(jurisdiction)
|
||||||
|
assert result is True, f"Jurisdiction {jurisdiction} should be valid"
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_tax_jurisdiction_valid_canadian_provinces(self):
|
||||||
|
"""Test validating Canadian province tax jurisdictions."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
valid_jurisdictions = [
|
||||||
|
"CA-ON", # Ontario
|
||||||
|
"CA-QC", # Quebec
|
||||||
|
"CA-BC", # British Columbia
|
||||||
|
"CA-AB", # Alberta
|
||||||
|
]
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
for jurisdiction in valid_jurisdictions:
|
||||||
|
result = validate_tax_jurisdiction(jurisdiction)
|
||||||
|
assert result is True, f"Jurisdiction {jurisdiction} should be valid"
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_tax_jurisdiction_invalid_format(self):
|
||||||
|
"""Test rejecting invalid format."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
invalid_jurisdictions = [
|
||||||
|
"InvalidFormat",
|
||||||
|
"US_CA", # Wrong separator
|
||||||
|
"US/CA", # Wrong separator
|
||||||
|
"USCA", # No separator
|
||||||
|
"us-ca", # Lowercase
|
||||||
|
"123",
|
||||||
|
"!@#",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
for jurisdiction in invalid_jurisdictions:
|
||||||
|
result = validate_tax_jurisdiction(jurisdiction)
|
||||||
|
assert result is False, f"Jurisdiction {jurisdiction} should be invalid"
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_tax_jurisdiction_invalid_country_code(self):
|
||||||
|
"""Test rejecting invalid country codes."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
result = validate_tax_jurisdiction("XX-YY")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_tax_jurisdiction_none(self):
|
||||||
|
"""Test handling None value."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
# None should be accepted (nullable field)
|
||||||
|
result = validate_tax_jurisdiction(None)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_tax_jurisdiction_empty_string(self):
|
||||||
|
"""Test rejecting empty string."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
result = validate_tax_jurisdiction("")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_tax_jurisdiction_case_sensitive(self):
|
||||||
|
"""Test that tax jurisdiction validation is case-sensitive."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
# Correct case (uppercase)
|
||||||
|
result_correct = validate_tax_jurisdiction("US-CA")
|
||||||
|
|
||||||
|
# Wrong case (lowercase)
|
||||||
|
result_wrong = validate_tax_jurisdiction("us-ca")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result_correct is True
|
||||||
|
assert result_wrong is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_validate_tax_jurisdiction_max_length(self):
|
||||||
|
"""Test rejecting very long jurisdiction strings."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
# Very long string
|
||||||
|
result = validate_tax_jurisdiction("US-" + "A" * 100)
|
||||||
|
|
||||||
|
# Assert: Should reject overly long jurisdictions
|
||||||
|
assert result is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Integration Tests: Validators with User Model
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestValidatorsIntegration:
|
||||||
|
"""Test validators integrated with User model."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_with_valid_timezone(self, db_session):
|
||||||
|
"""Test creating user with validated timezone."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
timezone = "America/New_York"
|
||||||
|
assert validate_timezone(timezone) is True
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="tzvaliduser",
|
||||||
|
email="tzvalid@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
timezone=timezone,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.timezone == timezone
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models or validators not implemented yet")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_with_invalid_timezone_at_api_level(self, db_session):
|
||||||
|
"""Test that invalid timezone should be caught at API level, not DB."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
invalid_timezone = "Invalid/Timezone"
|
||||||
|
assert validate_timezone(invalid_timezone) is False
|
||||||
|
|
||||||
|
# Note: Database will accept it, validation happens at API layer
|
||||||
|
user = User(
|
||||||
|
username="tzinvalid",
|
||||||
|
email="tzinvalid@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
timezone=invalid_timezone,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act: DB should accept it (validation is at API level)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert: DB accepts it, but validator rejects it
|
||||||
|
assert user.timezone == invalid_timezone
|
||||||
|
assert validate_timezone(user.timezone) is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models or validators not implemented yet")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_with_valid_tax_jurisdiction(self, db_session):
|
||||||
|
"""Test creating user with validated tax jurisdiction."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
jurisdiction = "US-CA"
|
||||||
|
assert validate_tax_jurisdiction(jurisdiction) is True
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="taxvaliduser",
|
||||||
|
email="taxvalid@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
tax_jurisdiction=jurisdiction,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.tax_jurisdiction == jurisdiction
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models or validators not implemented yet")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_with_both_validators(self, db_session):
|
||||||
|
"""Test creating user with both validated fields."""
|
||||||
|
# Arrange
|
||||||
|
try:
|
||||||
|
from tradingagents.api.models import User
|
||||||
|
from tradingagents.api.services.validators import (
|
||||||
|
validate_timezone,
|
||||||
|
validate_tax_jurisdiction,
|
||||||
|
)
|
||||||
|
|
||||||
|
timezone = "America/Los_Angeles"
|
||||||
|
jurisdiction = "US-CA"
|
||||||
|
|
||||||
|
assert validate_timezone(timezone) is True
|
||||||
|
assert validate_tax_jurisdiction(jurisdiction) is True
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username="bothvalid",
|
||||||
|
email="bothvalid@example.com",
|
||||||
|
hashed_password="hash",
|
||||||
|
timezone=timezone,
|
||||||
|
tax_jurisdiction=jurisdiction,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert user.timezone == timezone
|
||||||
|
assert user.tax_jurisdiction == jurisdiction
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Models or validators not implemented yet")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Edge Cases: Validators
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestValidatorEdgeCases:
|
||||||
|
"""Test edge cases in validators."""
|
||||||
|
|
||||||
|
def test_timezone_with_special_characters(self):
|
||||||
|
"""Test timezone with special characters."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
# Test various special characters
|
||||||
|
result1 = validate_timezone("America/Port-au-Prince") # Hyphen
|
||||||
|
result2 = validate_timezone("America/Indiana/Indianapolis") # Multiple slashes
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result1 is True
|
||||||
|
assert result2 is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_timezone_whitespace_handling(self):
|
||||||
|
"""Test timezone validation with whitespace."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
# Timezones with leading/trailing whitespace
|
||||||
|
result1 = validate_timezone(" America/New_York ")
|
||||||
|
result2 = validate_timezone("America/New_York ")
|
||||||
|
result3 = validate_timezone(" America/New_York")
|
||||||
|
|
||||||
|
# Assert: Should reject (strict validation)
|
||||||
|
assert result1 is False
|
||||||
|
assert result2 is False
|
||||||
|
assert result3 is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_tax_jurisdiction_whitespace_handling(self):
|
||||||
|
"""Test tax jurisdiction validation with whitespace."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
# Jurisdictions with leading/trailing whitespace
|
||||||
|
result1 = validate_tax_jurisdiction(" US-CA ")
|
||||||
|
result2 = validate_tax_jurisdiction("US-CA ")
|
||||||
|
result3 = validate_tax_jurisdiction(" US-CA")
|
||||||
|
|
||||||
|
# Assert: Should reject (strict validation)
|
||||||
|
assert result1 is False
|
||||||
|
assert result2 is False
|
||||||
|
assert result3 is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_timezone_numeric_string(self):
|
||||||
|
"""Test timezone validation with numeric strings."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
result = validate_timezone("12345")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_tax_jurisdiction_only_country(self):
|
||||||
|
"""Test tax jurisdiction with only country code."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
# Two-letter country codes should be valid
|
||||||
|
result_us = validate_tax_jurisdiction("US")
|
||||||
|
result_ca = validate_tax_jurisdiction("CA")
|
||||||
|
result_gb = validate_tax_jurisdiction("GB")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result_us is True
|
||||||
|
assert result_ca is True
|
||||||
|
assert result_gb is True
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_tax_jurisdiction_single_letter(self):
|
||||||
|
"""Test tax jurisdiction with single letter."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
result = validate_tax_jurisdiction("A")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_timezone_sql_injection_attempt(self):
|
||||||
|
"""Test timezone validation against SQL injection."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_timezone
|
||||||
|
|
||||||
|
malicious_inputs = [
|
||||||
|
"'; DROP TABLE users; --",
|
||||||
|
"1' OR '1'='1",
|
||||||
|
"America/New_York'; DELETE FROM users; --",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
for malicious in malicious_inputs:
|
||||||
|
result = validate_timezone(malicious)
|
||||||
|
assert result is False, f"Should reject malicious input: {malicious}"
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
||||||
|
def test_tax_jurisdiction_sql_injection_attempt(self):
|
||||||
|
"""Test tax jurisdiction validation against SQL injection."""
|
||||||
|
# Arrange & Act
|
||||||
|
try:
|
||||||
|
from tradingagents.api.services.validators import validate_tax_jurisdiction
|
||||||
|
|
||||||
|
malicious_inputs = [
|
||||||
|
"'; DROP TABLE users; --",
|
||||||
|
"1' OR '1'='1",
|
||||||
|
"US-CA'; DELETE FROM users; --",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
for malicious in malicious_inputs:
|
||||||
|
result = validate_tax_jurisdiction(malicious)
|
||||||
|
assert result is False, f"Should reject malicious input: {malicious}"
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("Validators not implemented yet")
|
||||||
|
|
@ -8,18 +8,63 @@ from tradingagents.api.models.base import Base, TimestampMixin
|
||||||
|
|
||||||
|
|
||||||
class User(Base, TimestampMixin):
|
class User(Base, TimestampMixin):
|
||||||
"""User model for authentication and authorization."""
|
"""User model for authentication and authorization.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
id: Primary key
|
||||||
|
username: Unique username for authentication
|
||||||
|
email: Unique email address
|
||||||
|
hashed_password: Bcrypt hashed password
|
||||||
|
full_name: Optional full name
|
||||||
|
is_active: Whether user account is active
|
||||||
|
is_superuser: Whether user has admin privileges
|
||||||
|
tax_jurisdiction: Tax jurisdiction code (e.g., "US", "US-CA", "AU")
|
||||||
|
timezone: IANA timezone identifier (e.g., "America/New_York", "UTC")
|
||||||
|
api_key_hash: Bcrypt hash of API key (if user has API key)
|
||||||
|
is_verified: Whether user email is verified
|
||||||
|
strategies: Related Strategy objects owned by this user
|
||||||
|
"""
|
||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
# Primary identification
|
||||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||||
username: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
username: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
full_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
full_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||||
|
|
||||||
|
# User status and permissions
|
||||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||||
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
|
|
||||||
|
# Issue #3: Profile fields
|
||||||
|
tax_jurisdiction: Mapped[str] = mapped_column(
|
||||||
|
String(10),
|
||||||
|
default="AU",
|
||||||
|
nullable=False,
|
||||||
|
comment="Tax jurisdiction code (e.g., US, US-CA, AU-NSW)"
|
||||||
|
)
|
||||||
|
timezone: Mapped[str] = mapped_column(
|
||||||
|
String(50),
|
||||||
|
default="Australia/Sydney",
|
||||||
|
nullable=False,
|
||||||
|
comment="IANA timezone identifier (e.g., America/New_York, UTC)"
|
||||||
|
)
|
||||||
|
api_key_hash: Mapped[Optional[str]] = mapped_column(
|
||||||
|
String(255),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
unique=True,
|
||||||
|
comment="Bcrypt hash of API key for programmatic access"
|
||||||
|
)
|
||||||
|
is_verified: Mapped[bool] = mapped_column(
|
||||||
|
Boolean,
|
||||||
|
default=False,
|
||||||
|
nullable=False,
|
||||||
|
comment="Whether user email has been verified"
|
||||||
|
)
|
||||||
|
|
||||||
# Relationship to strategies
|
# Relationship to strategies
|
||||||
strategies: Mapped[List["Strategy"]] = relationship(
|
strategies: Mapped[List["Strategy"]] = relationship(
|
||||||
"Strategy",
|
"Strategy",
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,31 @@ from tradingagents.api.services.auth_service import (
|
||||||
create_access_token,
|
create_access_token,
|
||||||
decode_access_token,
|
decode_access_token,
|
||||||
)
|
)
|
||||||
|
from tradingagents.api.services.api_key_service import (
|
||||||
|
generate_api_key,
|
||||||
|
hash_api_key,
|
||||||
|
verify_api_key,
|
||||||
|
)
|
||||||
|
from tradingagents.api.services.validators import (
|
||||||
|
validate_timezone,
|
||||||
|
validate_tax_jurisdiction,
|
||||||
|
get_available_timezones,
|
||||||
|
get_available_tax_jurisdictions,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Auth service
|
||||||
"hash_password",
|
"hash_password",
|
||||||
"verify_password",
|
"verify_password",
|
||||||
"create_access_token",
|
"create_access_token",
|
||||||
"decode_access_token",
|
"decode_access_token",
|
||||||
|
# API key service
|
||||||
|
"generate_api_key",
|
||||||
|
"hash_api_key",
|
||||||
|
"verify_api_key",
|
||||||
|
# Validators
|
||||||
|
"validate_timezone",
|
||||||
|
"validate_tax_jurisdiction",
|
||||||
|
"get_available_timezones",
|
||||||
|
"get_available_tax_jurisdictions",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,115 @@
|
||||||
|
"""API key service for secure key generation and hashing.
|
||||||
|
|
||||||
|
This module provides utilities for generating and verifying API keys:
|
||||||
|
- Generate secure random API keys with 'ta_' prefix
|
||||||
|
- Hash API keys using bcrypt (via pwdlib)
|
||||||
|
- Verify plain API keys against hashed values
|
||||||
|
|
||||||
|
Security:
|
||||||
|
- Never store plain API keys in the database
|
||||||
|
- Use bcrypt for hashing (via pwdlib PasswordHash)
|
||||||
|
- API keys are URL-safe base64 encoded (32 bytes)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
from pwdlib import PasswordHash
|
||||||
|
|
||||||
|
|
||||||
|
# API key hashing with bcrypt (same context as passwords for consistency)
|
||||||
|
api_key_context = PasswordHash.recommended()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_api_key() -> str:
|
||||||
|
"""
|
||||||
|
Generate a secure random API key.
|
||||||
|
|
||||||
|
Returns a URL-safe API key with the 'ta_' prefix followed by
|
||||||
|
32 bytes of random data encoded as base64.
|
||||||
|
|
||||||
|
Format: ta_<base64_url_safe_32_bytes>
|
||||||
|
Example: ta_vK9x8pL2mN3qR5sT7uW1yZ4aB6cD8eF0gH2jK4lM6n
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Generated API key (plaintext)
|
||||||
|
|
||||||
|
Security:
|
||||||
|
- Uses secrets.token_urlsafe() for cryptographically strong randomness
|
||||||
|
- 32 bytes = 256 bits of entropy
|
||||||
|
- Never store the returned value directly in database
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> api_key = generate_api_key()
|
||||||
|
>>> api_key.startswith("ta_")
|
||||||
|
True
|
||||||
|
>>> len(api_key) > 40 # ta_ + base64(32 bytes)
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
# Generate 32 bytes (256 bits) of cryptographically secure random data
|
||||||
|
# URL-safe base64 encoding makes it safe for URLs and headers
|
||||||
|
random_part = secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
return f"ta_{random_part}"
|
||||||
|
|
||||||
|
|
||||||
|
def hash_api_key(api_key: str) -> str:
|
||||||
|
"""
|
||||||
|
Hash an API key using bcrypt.
|
||||||
|
|
||||||
|
Uses the same pwdlib PasswordHash context as password hashing
|
||||||
|
for consistency. The hashed value can be safely stored in the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Plain text API key (from generate_api_key())
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Bcrypt hash of the API key
|
||||||
|
|
||||||
|
Security:
|
||||||
|
- Uses bcrypt algorithm (via Argon2 default from pwdlib)
|
||||||
|
- Hash is one-way and computationally expensive to reverse
|
||||||
|
- Store this hash in database, not the plain API key
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> api_key = generate_api_key()
|
||||||
|
>>> hashed = hash_api_key(api_key)
|
||||||
|
>>> hashed != api_key # Hash is different from plain key
|
||||||
|
True
|
||||||
|
>>> len(hashed) > 50 # Bcrypt hashes are long
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
return api_key_context.hash(api_key)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_api_key(plain_api_key: str, hashed_api_key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Verify a plain API key against a hash.
|
||||||
|
|
||||||
|
Checks if the provided plain API key matches the stored hash.
|
||||||
|
Uses constant-time comparison to prevent timing attacks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plain_api_key: Plain text API key (from user request)
|
||||||
|
hashed_api_key: Hashed API key (from database)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if API key matches hash, False otherwise
|
||||||
|
|
||||||
|
Security:
|
||||||
|
- Uses constant-time comparison
|
||||||
|
- Safe against timing attacks
|
||||||
|
- Computationally expensive to slow down brute force
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> api_key = generate_api_key()
|
||||||
|
>>> hashed = hash_api_key(api_key)
|
||||||
|
>>> verify_api_key(api_key, hashed)
|
||||||
|
True
|
||||||
|
>>> verify_api_key("wrong_key", hashed)
|
||||||
|
False
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return api_key_context.verify(plain_api_key, hashed_api_key)
|
||||||
|
except Exception:
|
||||||
|
# If verification fails for any reason (malformed hash, etc.)
|
||||||
|
# return False rather than raising an exception
|
||||||
|
return False
|
||||||
|
|
@ -0,0 +1,303 @@
|
||||||
|
"""Validators for user profile fields.
|
||||||
|
|
||||||
|
This module provides validation functions for:
|
||||||
|
- Timezones (IANA timezone database)
|
||||||
|
- Tax jurisdictions (country codes and state/province codes)
|
||||||
|
|
||||||
|
All validators return True/False and are designed to be used
|
||||||
|
in Pydantic models and database constraints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Set
|
||||||
|
from zoneinfo import ZoneInfo, available_timezones
|
||||||
|
|
||||||
|
|
||||||
|
# Valid tax jurisdictions (ISO 3166-1 alpha-2 country codes + state/province)
|
||||||
|
# Format: "CC" for country-level, "CC-SS" for state/province-level
|
||||||
|
# This is a comprehensive list covering major jurisdictions
|
||||||
|
VALID_TAX_JURISDICTIONS: Set[str] = {
|
||||||
|
# Country-level codes (ISO 3166-1 alpha-2)
|
||||||
|
"US", # United States
|
||||||
|
"CA", # Canada
|
||||||
|
"GB", # United Kingdom
|
||||||
|
"AU", # Australia
|
||||||
|
"DE", # Germany
|
||||||
|
"FR", # France
|
||||||
|
"IT", # Italy
|
||||||
|
"ES", # Spain
|
||||||
|
"NL", # Netherlands
|
||||||
|
"BE", # Belgium
|
||||||
|
"CH", # Switzerland
|
||||||
|
"AT", # Austria
|
||||||
|
"SE", # Sweden
|
||||||
|
"NO", # Norway
|
||||||
|
"DK", # Denmark
|
||||||
|
"FI", # Finland
|
||||||
|
"IE", # Ireland
|
||||||
|
"PT", # Portugal
|
||||||
|
"GR", # Greece
|
||||||
|
"PL", # Poland
|
||||||
|
"CZ", # Czech Republic
|
||||||
|
"HU", # Hungary
|
||||||
|
"RO", # Romania
|
||||||
|
"JP", # Japan
|
||||||
|
"CN", # China
|
||||||
|
"KR", # South Korea
|
||||||
|
"IN", # India
|
||||||
|
"SG", # Singapore
|
||||||
|
"HK", # Hong Kong
|
||||||
|
"NZ", # New Zealand
|
||||||
|
"MX", # Mexico
|
||||||
|
"BR", # Brazil
|
||||||
|
"AR", # Argentina
|
||||||
|
"CL", # Chile
|
||||||
|
"ZA", # South Africa
|
||||||
|
"AE", # United Arab Emirates
|
||||||
|
"SA", # Saudi Arabia
|
||||||
|
"IL", # Israel
|
||||||
|
"TR", # Turkey
|
||||||
|
"RU", # Russia
|
||||||
|
"UA", # Ukraine
|
||||||
|
"TH", # Thailand
|
||||||
|
"MY", # Malaysia
|
||||||
|
"ID", # Indonesia
|
||||||
|
"PH", # Philippines
|
||||||
|
"VN", # Vietnam
|
||||||
|
"TW", # Taiwan
|
||||||
|
|
||||||
|
# United States - State level
|
||||||
|
"US-AL", # Alabama
|
||||||
|
"US-AK", # Alaska
|
||||||
|
"US-AZ", # Arizona
|
||||||
|
"US-AR", # Arkansas
|
||||||
|
"US-CA", # California
|
||||||
|
"US-CO", # Colorado
|
||||||
|
"US-CT", # Connecticut
|
||||||
|
"US-DE", # Delaware
|
||||||
|
"US-FL", # Florida
|
||||||
|
"US-GA", # Georgia
|
||||||
|
"US-HI", # Hawaii
|
||||||
|
"US-ID", # Idaho
|
||||||
|
"US-IL", # Illinois
|
||||||
|
"US-IN", # Indiana
|
||||||
|
"US-IA", # Iowa
|
||||||
|
"US-KS", # Kansas
|
||||||
|
"US-KY", # Kentucky
|
||||||
|
"US-LA", # Louisiana
|
||||||
|
"US-ME", # Maine
|
||||||
|
"US-MD", # Maryland
|
||||||
|
"US-MA", # Massachusetts
|
||||||
|
"US-MI", # Michigan
|
||||||
|
"US-MN", # Minnesota
|
||||||
|
"US-MS", # Mississippi
|
||||||
|
"US-MO", # Missouri
|
||||||
|
"US-MT", # Montana
|
||||||
|
"US-NE", # Nebraska
|
||||||
|
"US-NV", # Nevada
|
||||||
|
"US-NH", # New Hampshire
|
||||||
|
"US-NJ", # New Jersey
|
||||||
|
"US-NM", # New Mexico
|
||||||
|
"US-NY", # New York
|
||||||
|
"US-NC", # North Carolina
|
||||||
|
"US-ND", # North Dakota
|
||||||
|
"US-OH", # Ohio
|
||||||
|
"US-OK", # Oklahoma
|
||||||
|
"US-OR", # Oregon
|
||||||
|
"US-PA", # Pennsylvania
|
||||||
|
"US-RI", # Rhode Island
|
||||||
|
"US-SC", # South Carolina
|
||||||
|
"US-SD", # South Dakota
|
||||||
|
"US-TN", # Tennessee
|
||||||
|
"US-TX", # Texas
|
||||||
|
"US-UT", # Utah
|
||||||
|
"US-VT", # Vermont
|
||||||
|
"US-VA", # Virginia
|
||||||
|
"US-WA", # Washington
|
||||||
|
"US-WV", # West Virginia
|
||||||
|
"US-WI", # Wisconsin
|
||||||
|
"US-WY", # Wyoming
|
||||||
|
"US-DC", # District of Columbia
|
||||||
|
|
||||||
|
# Canada - Province/Territory level
|
||||||
|
"CA-AB", # Alberta
|
||||||
|
"CA-BC", # British Columbia
|
||||||
|
"CA-MB", # Manitoba
|
||||||
|
"CA-NB", # New Brunswick
|
||||||
|
"CA-NL", # Newfoundland and Labrador
|
||||||
|
"CA-NS", # Nova Scotia
|
||||||
|
"CA-NT", # Northwest Territories
|
||||||
|
"CA-NU", # Nunavut
|
||||||
|
"CA-ON", # Ontario
|
||||||
|
"CA-PE", # Prince Edward Island
|
||||||
|
"CA-QC", # Quebec
|
||||||
|
"CA-SK", # Saskatchewan
|
||||||
|
"CA-YT", # Yukon
|
||||||
|
|
||||||
|
# Australia - State/Territory level
|
||||||
|
"AU-NSW", # New South Wales
|
||||||
|
"AU-VIC", # Victoria
|
||||||
|
"AU-QLD", # Queensland
|
||||||
|
"AU-SA", # South Australia
|
||||||
|
"AU-WA", # Western Australia
|
||||||
|
"AU-TAS", # Tasmania
|
||||||
|
"AU-NT", # Northern Territory
|
||||||
|
"AU-ACT", # Australian Capital Territory
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def validate_timezone(timezone: str) -> bool:
|
||||||
|
"""
|
||||||
|
Validate timezone against IANA timezone database.
|
||||||
|
|
||||||
|
Checks if the provided timezone string is a valid IANA timezone
|
||||||
|
identifier. Uses Python's zoneinfo module which is based on the
|
||||||
|
IANA timezone database (tzdata).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timezone: Timezone identifier (e.g., "America/New_York", "UTC")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if valid IANA timezone, False otherwise
|
||||||
|
|
||||||
|
Valid Examples:
|
||||||
|
- "UTC"
|
||||||
|
- "GMT"
|
||||||
|
- "America/New_York"
|
||||||
|
- "Europe/London"
|
||||||
|
- "Asia/Tokyo"
|
||||||
|
- "Australia/Sydney"
|
||||||
|
|
||||||
|
Invalid Examples:
|
||||||
|
- "PST" (abbreviation, not IANA identifier)
|
||||||
|
- "EST" (abbreviation)
|
||||||
|
- "New York" (wrong format)
|
||||||
|
- "america/new_york" (wrong case)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> validate_timezone("America/New_York")
|
||||||
|
True
|
||||||
|
>>> validate_timezone("UTC")
|
||||||
|
True
|
||||||
|
>>> validate_timezone("PST")
|
||||||
|
False
|
||||||
|
>>> validate_timezone("Invalid/Zone")
|
||||||
|
False
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Case-sensitive (must match IANA database exactly)
|
||||||
|
- Use available_timezones() to get full list of valid zones
|
||||||
|
- Rejects timezone abbreviations (PST, EST, etc.)
|
||||||
|
"""
|
||||||
|
if not timezone or not isinstance(timezone, str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if timezone exists in IANA database
|
||||||
|
# This is more efficient than trying to create a ZoneInfo object
|
||||||
|
return timezone in available_timezones()
|
||||||
|
|
||||||
|
|
||||||
|
def validate_tax_jurisdiction(jurisdiction: str) -> bool:
|
||||||
|
"""
|
||||||
|
Validate tax jurisdiction code.
|
||||||
|
|
||||||
|
Checks if the provided jurisdiction is in the list of valid
|
||||||
|
tax jurisdictions. Supports both country-level and state/province-level
|
||||||
|
jurisdictions.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
- Country level: "CC" (2-letter ISO 3166-1 alpha-2)
|
||||||
|
- State/Province level: "CC-SS" (country-state with hyphen)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
jurisdiction: Tax jurisdiction code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if valid jurisdiction, False otherwise
|
||||||
|
|
||||||
|
Valid Examples:
|
||||||
|
- "US" (United States)
|
||||||
|
- "CA" (Canada)
|
||||||
|
- "GB" (United Kingdom)
|
||||||
|
- "US-CA" (California, USA)
|
||||||
|
- "US-NY" (New York, USA)
|
||||||
|
- "CA-ON" (Ontario, Canada)
|
||||||
|
- "AU-NSW" (New South Wales, Australia)
|
||||||
|
|
||||||
|
Invalid Examples:
|
||||||
|
- "us" (lowercase)
|
||||||
|
- "USA" (3 letters)
|
||||||
|
- "US_CA" (underscore instead of hyphen)
|
||||||
|
- "US/CA" (slash instead of hyphen)
|
||||||
|
- "XX" (non-existent country)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> validate_tax_jurisdiction("US")
|
||||||
|
True
|
||||||
|
>>> validate_tax_jurisdiction("US-CA")
|
||||||
|
True
|
||||||
|
>>> validate_tax_jurisdiction("us")
|
||||||
|
False
|
||||||
|
>>> validate_tax_jurisdiction("XX-YY")
|
||||||
|
False
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Case-sensitive (must be uppercase)
|
||||||
|
- Hyphen separator for state/province codes
|
||||||
|
- List is comprehensive but not exhaustive
|
||||||
|
- Add new jurisdictions to VALID_TAX_JURISDICTIONS set as needed
|
||||||
|
"""
|
||||||
|
if not jurisdiction or not isinstance(jurisdiction, str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return jurisdiction in VALID_TAX_JURISDICTIONS
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_timezones() -> Set[str]:
|
||||||
|
"""
|
||||||
|
Get set of all available IANA timezones.
|
||||||
|
|
||||||
|
Returns the complete set of valid timezone identifiers from
|
||||||
|
the IANA timezone database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set[str]: Set of valid timezone identifiers
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> timezones = get_available_timezones()
|
||||||
|
>>> "America/New_York" in timezones
|
||||||
|
True
|
||||||
|
>>> len(timezones) > 500 # Hundreds of valid timezones
|
||||||
|
True
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- This is a cached call (zoneinfo caches available_timezones)
|
||||||
|
- Use for populating dropdowns or validation lists
|
||||||
|
- Contains all IANA timezone database entries
|
||||||
|
"""
|
||||||
|
return available_timezones()
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_tax_jurisdictions() -> Set[str]:
|
||||||
|
"""
|
||||||
|
Get set of all available tax jurisdictions.
|
||||||
|
|
||||||
|
Returns the complete set of valid tax jurisdiction codes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set[str]: Set of valid tax jurisdiction codes
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> jurisdictions = get_available_tax_jurisdictions()
|
||||||
|
>>> "US" in jurisdictions
|
||||||
|
True
|
||||||
|
>>> "US-CA" in jurisdictions
|
||||||
|
True
|
||||||
|
>>> len(jurisdictions) > 50 # Many jurisdictions supported
|
||||||
|
True
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Returns a copy to prevent external modification
|
||||||
|
- Use for populating dropdowns or validation lists
|
||||||
|
- Includes both country and state/province level codes
|
||||||
|
"""
|
||||||
|
return VALID_TAX_JURISDICTIONS.copy()
|
||||||
Loading…
Reference in New Issue