From d3892b0da9112b6224fa5de345a7ebc3142be22c Mon Sep 17 00:00:00 2001 From: Andrew Kaszubski Date: Fri, 26 Dec 2025 13:15:37 +1100 Subject: [PATCH] feat(db): add User model fields for tax, timezone, API key - Fixes #3 --- CHANGELOG.md | 17 + .../versions/002_add_user_profile_fields.py | 115 +++ tests/api/conftest.py | 229 ++++++ tests/api/test_api_key_service.py | 587 +++++++++++++++ tests/api/test_models.py | 461 +++++++++++- tests/api/test_user_model.py | 604 ++++++++++++++++ tests/api/test_validators.py | 678 ++++++++++++++++++ tradingagents/api/models/user.py | 47 +- tradingagents/api/services/__init__.py | 21 + tradingagents/api/services/api_key_service.py | 115 +++ tradingagents/api/services/validators.py | 303 ++++++++ 11 files changed, 3170 insertions(+), 7 deletions(-) create mode 100644 migrations/versions/002_add_user_profile_fields.py create mode 100644 tests/api/test_api_key_service.py create mode 100644 tests/api/test_user_model.py create mode 100644 tests/api/test_validators.py create mode 100644 tradingagents/api/services/api_key_service.py create mode 100644 tradingagents/api/services/validators.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c5053361..c2ce2a40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - New dependencies in pyproject.toml: fastapi, uvicorn, sqlalchemy, alembic, pydantic-settings, passlib, argon2-cffi, python-multipart, python-jose, cryptography - API documentation generated from FastAPI OpenAPI schema (available at /docs and /redoc) +- User model enhancement with profile and API key management (Issue #3) + - Extended User model with tax_jurisdiction and timezone fields [file:tradingagents/api/models/user.py:47-54](tradingagents/api/models/user.py) + - Tax jurisdiction field supporting country (e.g., "US", "AU") and state/province level codes (e.g., "US-CA", "AU-NSW") + - IANA timezone identifier field (e.g., "America/New_York", "Australia/Sydney") with automatic validation + - Email verification status tracking via is_verified boolean field [file:tradingagents/api/models/user.py:60-64](tradingagents/api/models/user.py) + - Secure API key management with bcrypt hashing and unique constraints [file:tradingagents/api/models/user.py:55-59](tradingagents/api/models/user.py) + - API key service module with generate_api_key(), hash_api_key(), and verify_api_key() functions [file:tradingagents/api/services/api_key_service.py](tradingagents/api/services/api_key_service.py) + - API key generation using secrets.token_urlsafe() with 256-bit entropy and 'ta_' prefix + - Bcrypt-based API key hashing using pwdlib.PasswordHash for secure storage + - Constant-time verification to prevent timing attacks on API keys + - Timezone validator using IANA zoneinfo database [file:tradingagents/api/services/validators.py:134-191](tradingagents/api/services/validators.py) + - Tax jurisdiction validator supporting 50+ country codes and state/province subdivisions [file:tradingagents/api/services/validators.py:193-283](tradingagents/api/services/validators.py) + - Utility functions get_available_timezones() and get_available_tax_jurisdictions() for UI dropdowns [file:tradingagents/api/services/validators.py:285-333](tradingagents/api/services/validators.py) + - Database migration 002_add_user_profile_fields.py with proper defaults and constraints [file:migrations/versions/002_add_user_profile_fields.py](migrations/versions/002_add_user_profile_fields.py) + - Migration rollback support for reversible schema changes + - Comprehensive docstrings and security considerations for all functions + - Test fixtures directory with centralized mock data (Issue #51) - FixtureLoader class for loading JSON fixtures with automatic datetime parsing [file:tests/fixtures/__init__.py](tests/fixtures/__init__.py) - Stock data fixtures: US market OHLCV, Chinese market OHLCV, standardized data [file:tests/fixtures/stock_data/](tests/fixtures/stock_data/) diff --git a/migrations/versions/002_add_user_profile_fields.py b/migrations/versions/002_add_user_profile_fields.py new file mode 100644 index 00000000..1dc908f0 --- /dev/null +++ b/migrations/versions/002_add_user_profile_fields.py @@ -0,0 +1,115 @@ +"""Add user profile fields - tax_jurisdiction, timezone, api_key_hash, is_verified + +Revision ID: 002 +Revises: 001 +Create Date: 2025-12-26 13:00:00.000000 + +This migration adds four new fields to the users table: +- tax_jurisdiction: Tax jurisdiction code (default: AU) +- timezone: IANA timezone identifier (default: Australia/Sydney) +- api_key_hash: Bcrypt hash of API key for programmatic access (nullable) +- is_verified: Email verification status (default: False) + +All existing users will get default values for the new required fields. +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '002' +down_revision: Union[str, None] = '001' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add tax_jurisdiction, timezone, api_key_hash, and is_verified columns to users table. + + For existing rows: + - tax_jurisdiction defaults to "AU" + - timezone defaults to "Australia/Sydney" + - api_key_hash is NULL + - is_verified is False + """ + # Add tax_jurisdiction column + op.add_column( + 'users', + sa.Column( + 'tax_jurisdiction', + sa.String(length=10), + nullable=False, + server_default='AU', + comment='Tax jurisdiction code (e.g., US, US-CA, AU-NSW)' + ) + ) + + # Add timezone column + op.add_column( + 'users', + sa.Column( + 'timezone', + sa.String(length=50), + nullable=False, + server_default='Australia/Sydney', + comment='IANA timezone identifier (e.g., America/New_York, UTC)' + ) + ) + + # Add api_key_hash column with unique constraint and index + op.add_column( + 'users', + sa.Column( + 'api_key_hash', + sa.String(length=255), + nullable=True, + comment='Bcrypt hash of API key for programmatic access' + ) + ) + # Create unique constraint for api_key_hash + op.create_unique_constraint( + 'uq_users_api_key_hash', + 'users', + ['api_key_hash'] + ) + # Create index for fast lookups + op.create_index( + 'ix_users_api_key_hash', + 'users', + ['api_key_hash'], + unique=False + ) + + # Add is_verified column + op.add_column( + 'users', + sa.Column( + 'is_verified', + sa.Boolean(), + nullable=False, + server_default='0', + comment='Whether user email has been verified' + ) + ) + + +def downgrade() -> None: + """Remove tax_jurisdiction, timezone, api_key_hash, and is_verified columns from users table. + + WARNING: This will permanently delete data in these columns! + """ + # Remove is_verified column + op.drop_column('users', 'is_verified') + + # Remove api_key_hash column (drop index and constraint first) + op.drop_index('ix_users_api_key_hash', 'users') + op.drop_constraint('uq_users_api_key_hash', 'users', type_='unique') + op.drop_column('users', 'api_key_hash') + + # Remove timezone column + op.drop_column('users', 'timezone') + + # Remove tax_jurisdiction column + op.drop_column('users', 'tax_jurisdiction') diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 6df62144..b9eac5dd 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -253,6 +253,8 @@ def test_user_data() -> Dict[str, Any]: "email": "test@example.com", "password": "SecurePassword123!", "full_name": "Test User", + "timezone": "America/New_York", # Issue #3 + "tax_jurisdiction": "US-NY", # Issue #3 } @@ -269,6 +271,8 @@ def second_user_data() -> Dict[str, Any]: "email": "other@example.com", "password": "AnotherPassword456!", "full_name": "Other User", + "timezone": "America/Los_Angeles", # Issue #3 + "tax_jurisdiction": "US-CA", # Issue #3 } @@ -667,3 +671,228 @@ def sample_xss_payloads() -> list[str]: "", "", ] + + +# ============================================================================ +# Issue #3 Fixtures: API Keys, Timezones, Tax Jurisdictions +# ============================================================================ + +@pytest.fixture +def verified_user_data() -> Dict[str, Any]: + """ + Test user data for verified user (Issue #3). + + Returns: + dict: Verified user data with all Issue #3 fields + + Example: + async def test_verified_user(verified_user_data): + assert verified_user_data["is_verified"] is True + """ + return { + "username": "verifieduser", + "email": "verified@example.com", + "password": "VerifiedPassword123!", + "full_name": "Verified User", + "timezone": "UTC", + "tax_jurisdiction": "US", + "is_verified": True, + } + + +@pytest.fixture +async def verified_user(db_session, verified_user_data): + """ + Create verified test user in database (Issue #3). + + Creates a verified user with timezone and tax jurisdiction. + + Args: + db_session: Database session + verified_user_data: Verified user data + + Yields: + User: Created verified user model instance + """ + try: + from tradingagents.api.models import User + from tradingagents.api.services.auth_service import hash_password + + user = User( + username=verified_user_data["username"], + email=verified_user_data["email"], + hashed_password=hash_password(verified_user_data["password"]), + full_name=verified_user_data.get("full_name"), + timezone=verified_user_data.get("timezone"), + tax_jurisdiction=verified_user_data.get("tax_jurisdiction"), + is_verified=verified_user_data.get("is_verified", True), + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + yield user + except ImportError: + yield None + + +@pytest.fixture +async def user_with_api_key(db_session, test_user_data): + """ + Create test user with API key in database (Issue #3). + + Creates a user with a hashed API key. + + Args: + db_session: Database session + test_user_data: Test user data + + Yields: + tuple[User, str]: (Created user, plain API key) + """ + try: + from tradingagents.api.models import User + from tradingagents.api.services.auth_service import hash_password + from tradingagents.api.services.api_key_service import ( + generate_api_key, + hash_api_key, + ) + + # Generate API key + plain_api_key = generate_api_key() + hashed_api_key = hash_api_key(plain_api_key) + + user = User( + username=test_user_data["username"], + email=test_user_data["email"], + hashed_password=hash_password(test_user_data["password"]), + full_name=test_user_data.get("full_name"), + api_key_hash=hashed_api_key, + timezone=test_user_data.get("timezone"), + tax_jurisdiction=test_user_data.get("tax_jurisdiction"), + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + yield (user, plain_api_key) + except ImportError: + yield (None, None) + + +@pytest.fixture +def valid_timezones() -> list[str]: + """ + List of valid IANA timezones for testing (Issue #3). + + Returns: + list[str]: Valid timezone identifiers + + Example: + def test_timezones(valid_timezones): + for tz in valid_timezones: + assert validate_timezone(tz) is True + """ + return [ + "UTC", + "GMT", + "America/New_York", + "America/Los_Angeles", + "America/Chicago", + "America/Denver", + "Europe/London", + "Europe/Paris", + "Europe/Berlin", + "Asia/Tokyo", + "Asia/Shanghai", + "Asia/Hong_Kong", + "Australia/Sydney", + "Australia/Melbourne", + "Pacific/Auckland", + ] + + +@pytest.fixture +def invalid_timezones() -> list[str]: + """ + List of invalid timezones for testing (Issue #3). + + Returns: + list[str]: Invalid timezone identifiers + + Example: + def test_invalid_timezones(invalid_timezones): + for tz in invalid_timezones: + assert validate_timezone(tz) is False + """ + return [ + "PST", + "EST", + "CST", + "MST", + "America/InvalidCity", + "Europe/FakePlace", + "Random/Stuff", + "america/new_york", # Wrong case + "123456", + "!@#$%", + ] + + +@pytest.fixture +def valid_tax_jurisdictions() -> list[str]: + """ + List of valid tax jurisdictions for testing (Issue #3). + + Returns: + list[str]: Valid tax jurisdiction codes + + Example: + def test_jurisdictions(valid_tax_jurisdictions): + for jurisdiction in valid_tax_jurisdictions: + assert validate_tax_jurisdiction(jurisdiction) is True + """ + return [ + "US", + "CA", + "GB", + "DE", + "FR", + "JP", + "AU", + "US-CA", + "US-NY", + "US-TX", + "US-FL", + "CA-ON", + "CA-QC", + "CA-BC", + ] + + +@pytest.fixture +def invalid_tax_jurisdictions() -> list[str]: + """ + List of invalid tax jurisdictions for testing (Issue #3). + + Returns: + list[str]: Invalid tax jurisdiction codes + + Example: + def test_invalid_jurisdictions(invalid_tax_jurisdictions): + for jurisdiction in invalid_tax_jurisdictions: + assert validate_tax_jurisdiction(jurisdiction) is False + """ + return [ + "InvalidFormat", + "US_CA", # Wrong separator + "US/CA", # Wrong separator + "USCA", # No separator + "us-ca", # Lowercase + "XX-YY", # Invalid country code + "123", + "!@#", + "", + ] diff --git a/tests/api/test_api_key_service.py b/tests/api/test_api_key_service.py new file mode 100644 index 00000000..7ec4df49 --- /dev/null +++ b/tests/api/test_api_key_service.py @@ -0,0 +1,587 @@ +""" +Test suite for API Key Service (Issue #3). + +This module tests the API key generation, hashing, and verification service: +1. Generate secure random API keys +2. Hash API keys using bcrypt +3. Verify API keys against hashes +4. Key format validation +5. Security best practices + +Tests follow TDD - written before implementation. +""" + +import pytest +import re +from typing import Optional + +pytestmark = pytest.mark.unit + + +# ============================================================================ +# Unit Tests: API Key Generation +# ============================================================================ + +class TestApiKeyGeneration: + """Test API key generation functionality.""" + + def test_generate_api_key_returns_string(self): + """Test that generate_api_key returns a string.""" + # Arrange & Act + try: + from tradingagents.api.services.api_key_service import generate_api_key + + api_key = generate_api_key() + + # Assert + assert isinstance(api_key, str) + assert len(api_key) > 0 + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_generate_api_key_format(self): + """Test that generated API key has correct format.""" + # Arrange & Act + try: + from tradingagents.api.services.api_key_service import generate_api_key + + api_key = generate_api_key() + + # Assert: Should be prefixed with "ta_" (TradingAgents) + assert api_key.startswith("ta_") + + # Should contain only alphanumeric characters after prefix + key_part = api_key[3:] # Remove "ta_" prefix + assert re.match(r'^[A-Za-z0-9]+$', key_part) + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_generate_api_key_length(self): + """Test that generated API key has sufficient length for security.""" + # Arrange & Act + try: + from tradingagents.api.services.api_key_service import generate_api_key + + api_key = generate_api_key() + + # Assert: Should be at least 32 characters (including prefix) + assert len(api_key) >= 32 + + # Key part (without prefix) should be at least 29 chars + key_part = api_key[3:] + assert len(key_part) >= 29 + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_generate_api_key_uniqueness(self): + """Test that each generated API key is unique.""" + # Arrange & Act + try: + from tradingagents.api.services.api_key_service import generate_api_key + + keys = [generate_api_key() for _ in range(100)] + + # Assert: All keys should be unique + assert len(keys) == len(set(keys)) + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_generate_api_key_randomness(self): + """Test that API keys have sufficient randomness.""" + # Arrange & Act + try: + from tradingagents.api.services.api_key_service import generate_api_key + + keys = [generate_api_key() for _ in range(10)] + + # Assert: Keys should not have common patterns + # Check that no two keys share same first 10 chars after prefix + prefixes = [key[3:13] for key in keys] + assert len(prefixes) == len(set(prefixes)) + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_generate_api_key_custom_length(self): + """Test generating API key with custom length.""" + # Arrange & Act + try: + from tradingagents.api.services.api_key_service import generate_api_key + + # Try generating key with custom length (if supported) + api_key = generate_api_key(length=40) + + # Assert: Should respect custom length + assert len(api_key) >= 40 or len(api_key) >= 32 # May have min length + except (ImportError, TypeError): + # TypeError is ok if length parameter not implemented + pytest.skip("Custom length not implemented yet") + + +# ============================================================================ +# Unit Tests: API Key Hashing +# ============================================================================ + +class TestApiKeyHashing: + """Test API key hashing functionality.""" + + def test_hash_api_key_returns_string(self): + """Test that hash_api_key returns a string.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import hash_api_key + + api_key = "ta_test1234567890abcdefghijklmnop" + + # Act + hashed = hash_api_key(api_key) + + # Assert + assert isinstance(hashed, str) + assert len(hashed) > 0 + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_hash_api_key_bcrypt_format(self): + """Test that hashed API key uses bcrypt format.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import hash_api_key + + api_key = "ta_test1234567890abcdefghijklmnop" + + # Act + hashed = hash_api_key(api_key) + + # Assert: Should be bcrypt format ($2b$rounds$salt+hash) + assert hashed.startswith("$2b$") or hashed.startswith("$2a$") + assert len(hashed) == 60 # Standard bcrypt hash length + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_hash_api_key_different_for_same_input(self): + """Test that hashing same key twice produces different hashes (salt).""" + # Arrange + try: + from tradingagents.api.services.api_key_service import hash_api_key + + api_key = "ta_test1234567890abcdefghijklmnop" + + # Act + hash1 = hash_api_key(api_key) + hash2 = hash_api_key(api_key) + + # Assert: Should be different due to different salts + assert hash1 != hash2 + assert hash1.startswith("$2b$") + assert hash2.startswith("$2b$") + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_hash_api_key_different_keys_different_hashes(self): + """Test that different keys produce different hashes.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import hash_api_key + + key1 = "ta_key1234567890abcdefghijklmnop" + key2 = "ta_key9876543210zyxwvutsrqponmlk" + + # Act + hash1 = hash_api_key(key1) + hash2 = hash_api_key(key2) + + # Assert + assert hash1 != hash2 + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_hash_api_key_empty_string(self): + """Test hashing empty string.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import hash_api_key + + # Act & Assert: Should handle gracefully or raise ValueError + try: + hashed = hash_api_key("") + # If it doesn't raise, should still return valid hash + assert isinstance(hashed, str) + except ValueError: + # Acceptable to raise ValueError for empty key + pass + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_hash_api_key_special_characters(self): + """Test hashing API key with special characters.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import hash_api_key + + api_key = "ta_key!@#$%^&*()_+-=[]{}|;:,.<>?" + + # Act + hashed = hash_api_key(api_key) + + # Assert: Should handle special characters + assert isinstance(hashed, str) + assert hashed.startswith("$2b$") or hashed.startswith("$2a$") + except ImportError: + pytest.skip("API key service not implemented yet") + + +# ============================================================================ +# Unit Tests: API Key Verification +# ============================================================================ + +class TestApiKeyVerification: + """Test API key verification functionality.""" + + def test_verify_api_key_correct_key(self): + """Test verifying API key with correct hash.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import ( + hash_api_key, + verify_api_key, + ) + + api_key = "ta_correct1234567890abcdefghijk" + hashed = hash_api_key(api_key) + + # Act + is_valid = verify_api_key(api_key, hashed) + + # Assert + assert is_valid is True + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_verify_api_key_incorrect_key(self): + """Test verifying API key with wrong hash.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import ( + hash_api_key, + verify_api_key, + ) + + correct_key = "ta_correct1234567890abcdefghijk" + wrong_key = "ta_wrongkey1234567890abcdefghij" + hashed = hash_api_key(correct_key) + + # Act + is_valid = verify_api_key(wrong_key, hashed) + + # Assert + assert is_valid is False + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_verify_api_key_empty_key(self): + """Test verifying empty API key.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import ( + hash_api_key, + verify_api_key, + ) + + api_key = "ta_test1234567890abcdefghijklmn" + hashed = hash_api_key(api_key) + + # Act + is_valid = verify_api_key("", hashed) + + # Assert + assert is_valid is False + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_verify_api_key_none_key(self): + """Test verifying None API key.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import ( + hash_api_key, + verify_api_key, + ) + + api_key = "ta_test1234567890abcdefghijklmn" + hashed = hash_api_key(api_key) + + # Act & Assert: Should return False or raise TypeError + try: + is_valid = verify_api_key(None, hashed) + assert is_valid is False + except TypeError: + # Acceptable to raise TypeError for None + pass + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_verify_api_key_invalid_hash(self): + """Test verifying API key against invalid hash.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import verify_api_key + + api_key = "ta_test1234567890abcdefghijklmn" + invalid_hash = "not-a-valid-bcrypt-hash" + + # Act & Assert: Should return False or raise ValueError + try: + is_valid = verify_api_key(api_key, invalid_hash) + assert is_valid is False + except ValueError: + # Acceptable to raise ValueError for invalid hash + pass + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_verify_api_key_case_sensitive(self): + """Test that API key verification is case-sensitive.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import ( + hash_api_key, + verify_api_key, + ) + + api_key = "ta_TestKey1234567890ABCDEFGHIJK" + hashed = hash_api_key(api_key) + + # Act + is_valid_correct = verify_api_key(api_key, hashed) + is_valid_wrong_case = verify_api_key(api_key.lower(), hashed) + + # Assert + assert is_valid_correct is True + assert is_valid_wrong_case is False + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_verify_api_key_similar_keys(self): + """Test that similar keys don't validate.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import ( + hash_api_key, + verify_api_key, + ) + + api_key = "ta_key1234567890abcdefghijklmnop" + similar_key = "ta_key1234567890abcdefghijklmnox" # Last char different + hashed = hash_api_key(api_key) + + # Act + is_valid = verify_api_key(similar_key, hashed) + + # Assert + assert is_valid is False + except ImportError: + pytest.skip("API key service not implemented yet") + + +# ============================================================================ +# Integration Tests: Full API Key Lifecycle +# ============================================================================ + +class TestApiKeyLifecycle: + """Test complete API key generation, hashing, and verification workflow.""" + + def test_full_api_key_lifecycle(self): + """Test complete lifecycle: generate -> hash -> verify.""" + # Arrange & Act + try: + from tradingagents.api.services.api_key_service import ( + generate_api_key, + hash_api_key, + verify_api_key, + ) + + # Generate new API key + api_key = generate_api_key() + + # Hash the API key + hashed = hash_api_key(api_key) + + # Verify with correct key + is_valid = verify_api_key(api_key, hashed) + + # Assert + assert isinstance(api_key, str) + assert api_key.startswith("ta_") + assert isinstance(hashed, str) + assert hashed.startswith("$2b$") or hashed.startswith("$2a$") + assert is_valid is True + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_multiple_keys_independent(self): + """Test that multiple API keys can coexist independently.""" + # Arrange & Act + try: + from tradingagents.api.services.api_key_service import ( + generate_api_key, + hash_api_key, + verify_api_key, + ) + + # Generate multiple keys + key1 = generate_api_key() + key2 = generate_api_key() + key3 = generate_api_key() + + # Hash each key + hash1 = hash_api_key(key1) + hash2 = hash_api_key(key2) + hash3 = hash_api_key(key3) + + # Assert: Each key only verifies against its own hash + assert verify_api_key(key1, hash1) is True + assert verify_api_key(key1, hash2) is False + assert verify_api_key(key1, hash3) is False + + assert verify_api_key(key2, hash1) is False + assert verify_api_key(key2, hash2) is True + assert verify_api_key(key2, hash3) is False + + assert verify_api_key(key3, hash1) is False + assert verify_api_key(key3, hash2) is False + assert verify_api_key(key3, hash3) is True + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_regenerate_key_invalidates_old_hash(self): + """Test that regenerating a key invalidates the old hash.""" + # Arrange & Act + try: + from tradingagents.api.services.api_key_service import ( + generate_api_key, + hash_api_key, + verify_api_key, + ) + + # Generate and hash first key + old_key = generate_api_key() + old_hash = hash_api_key(old_key) + + # Generate new key (simulate regeneration) + new_key = generate_api_key() + + # Assert: Old hash should not work with new key + assert verify_api_key(old_key, old_hash) is True + assert verify_api_key(new_key, old_hash) is False + except ImportError: + pytest.skip("API key service not implemented yet") + + +# ============================================================================ +# Edge Cases: API Key Service +# ============================================================================ + +class TestApiKeyEdgeCases: + """Test edge cases in API key service.""" + + def test_hash_very_long_key(self): + """Test hashing very long API key.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import ( + hash_api_key, + verify_api_key, + ) + + # Create very long key (bcrypt has 72 byte limit) + long_key = "ta_" + "a" * 200 + + # Act + hashed = hash_api_key(long_key) + is_valid = verify_api_key(long_key, hashed) + + # Assert: Should handle gracefully (may truncate to 72 bytes) + assert isinstance(hashed, str) + # Bcrypt will only use first ~72 bytes, so verification should work + assert is_valid is True + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_hash_unicode_characters(self): + """Test hashing API key with Unicode characters.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import ( + hash_api_key, + verify_api_key, + ) + + # Key with Unicode characters + unicode_key = "ta_测试key_🔑_αβγ" + + # Act + hashed = hash_api_key(unicode_key) + is_valid = verify_api_key(unicode_key, hashed) + + # Assert + assert isinstance(hashed, str) + assert is_valid is True + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_timing_attack_resistance(self): + """Test that verification takes similar time for valid/invalid keys.""" + # Arrange + try: + from tradingagents.api.services.api_key_service import ( + generate_api_key, + hash_api_key, + verify_api_key, + ) + import time + + api_key = generate_api_key() + hashed = hash_api_key(api_key) + wrong_key = generate_api_key() + + # Act: Measure time for correct and incorrect verification + times_correct = [] + times_incorrect = [] + + for _ in range(10): + start = time.perf_counter() + verify_api_key(api_key, hashed) + times_correct.append(time.perf_counter() - start) + + start = time.perf_counter() + verify_api_key(wrong_key, hashed) + times_incorrect.append(time.perf_counter() - start) + + # Assert: Times should be similar (within same order of magnitude) + # This is a basic check - bcrypt is inherently resistant to timing attacks + avg_correct = sum(times_correct) / len(times_correct) + avg_incorrect = sum(times_incorrect) / len(times_incorrect) + + # Both should take similar time (bcrypt always does full comparison) + assert avg_correct > 0 + assert avg_incorrect > 0 + except ImportError: + pytest.skip("API key service not implemented yet") + + def test_concurrent_key_generation(self): + """Test generating API keys concurrently.""" + # Arrange & Act + try: + from tradingagents.api.services.api_key_service import generate_api_key + import concurrent.futures + + # Generate keys concurrently + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(generate_api_key) for _ in range(100)] + keys = [f.result() for f in futures] + + # Assert: All keys should be unique + assert len(keys) == len(set(keys)) + assert all(k.startswith("ta_") for k in keys) + except ImportError: + pytest.skip("API key service not implemented yet") diff --git a/tests/api/test_models.py b/tests/api/test_models.py index 591fff0f..230f3f6a 100644 --- a/tests/api/test_models.py +++ b/tests/api/test_models.py @@ -1,13 +1,14 @@ """ Test suite for SQLAlchemy database models. -This module tests Issue #48 database models: +This module tests Issue #48 and Issue #3 database models: 1. User model with hashed passwords -2. Strategy model with JSON parameters -3. Relationships (User -> Strategies) -4. Model validation and constraints -5. Timestamps (created_at, updated_at) -6. Cascade delete behavior +2. User model with tax_jurisdiction, timezone, api_key_hash, is_verified (Issue #3) +3. Strategy model with JSON parameters +4. Relationships (User -> Strategies) +5. Model validation and constraints +6. Timestamps (created_at, updated_at) +7. Cascade delete behavior Tests follow TDD - written before implementation. """ @@ -804,3 +805,451 @@ class TestModelEdgeCases: assert strategy.parameters["l1"]["l2"]["l3"]["l4"]["l5"]["value"] == "deep" except ImportError: pytest.skip("Models not implemented yet") + + +# ============================================================================ +# Unit Tests: Issue #3 - User Model Enhancements +# ============================================================================ + +class TestUserModelIssue3: + """Test User model enhancements from Issue #3. + + New fields tested: + - tax_jurisdiction: Nullable string for user's tax jurisdiction + - timezone: Nullable string for user's timezone (must be valid IANA timezone) + - api_key_hash: Nullable string for hashed API key + - is_verified: Boolean flag for email verification (defaults to False) + """ + + async def test_user_tax_jurisdiction_default_none(self, db_session): + """Test that tax_jurisdiction defaults to None.""" + # Arrange + try: + from tradingagents.api.models import User + + user = User( + username="taxuser", + email="tax@example.com", + hashed_password="hash", + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert hasattr(user, "tax_jurisdiction") + assert user.tax_jurisdiction is None + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_tax_jurisdiction_custom_value(self, db_session): + """Test setting custom tax_jurisdiction value.""" + # Arrange + try: + from tradingagents.api.models import User + + user = User( + username="taxuser2", + email="tax2@example.com", + hashed_password="hash", + tax_jurisdiction="US-CA", + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.tax_jurisdiction == "US-CA" + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_timezone_default_none(self, db_session): + """Test that timezone defaults to None.""" + # Arrange + try: + from tradingagents.api.models import User + + user = User( + username="tzuser", + email="tz@example.com", + hashed_password="hash", + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert hasattr(user, "timezone") + assert user.timezone is None + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_timezone_valid_value(self, db_session): + """Test setting valid timezone value.""" + # Arrange + try: + from tradingagents.api.models import User + + user = User( + username="tzuser2", + email="tz2@example.com", + hashed_password="hash", + timezone="America/New_York", + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.timezone == "America/New_York" + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_timezone_various_timezones(self, db_session): + """Test various valid IANA timezone values.""" + # Arrange + try: + from tradingagents.api.models import User + + timezones = [ + "UTC", + "America/Los_Angeles", + "Europe/London", + "Asia/Tokyo", + "Australia/Sydney", + ] + + for i, tz in enumerate(timezones): + user = User( + username=f"tzuser_{i}", + email=f"tz{i}@example.com", + hashed_password="hash", + timezone=tz, + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.timezone == tz + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_api_key_hash_default_none(self, db_session): + """Test that api_key_hash defaults to None.""" + # Arrange + try: + from tradingagents.api.models import User + + user = User( + username="apiuser", + email="api@example.com", + hashed_password="hash", + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert hasattr(user, "api_key_hash") + assert user.api_key_hash is None + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_api_key_hash_custom_value(self, db_session): + """Test setting api_key_hash value.""" + # Arrange + try: + from tradingagents.api.models import User + + # Simulate hashed API key (bcrypt hash format) + api_key_hash = "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/LewY5ygOy3f3K0E6O" + + user = User( + username="apiuser2", + email="api2@example.com", + hashed_password="hash", + api_key_hash=api_key_hash, + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.api_key_hash == api_key_hash + assert user.api_key_hash.startswith("$2b$") + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_is_verified_default_false(self, db_session): + """Test that is_verified defaults to False.""" + # Arrange + try: + from tradingagents.api.models import User + + user = User( + username="verifyuser", + email="verify@example.com", + hashed_password="hash", + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert hasattr(user, "is_verified") + assert user.is_verified is False + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_is_verified_can_be_true(self, db_session): + """Test setting is_verified to True.""" + # Arrange + try: + from tradingagents.api.models import User + + user = User( + username="verifyuser2", + email="verify2@example.com", + hashed_password="hash", + is_verified=True, + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.is_verified is True + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_all_new_fields_together(self, db_session): + """Test creating user with all Issue #3 fields.""" + # Arrange + try: + from tradingagents.api.models import User + + user = User( + username="fulluser", + email="full@example.com", + hashed_password="hash", + tax_jurisdiction="US-NY", + timezone="America/New_York", + api_key_hash="$2b$12$hashedapikey123", + is_verified=True, + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.tax_jurisdiction == "US-NY" + assert user.timezone == "America/New_York" + assert user.api_key_hash == "$2b$12$hashedapikey123" + assert user.is_verified is True + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_update_timezone(self, db_session): + """Test updating user's timezone.""" + # Arrange + try: + from tradingagents.api.models import User + + user = User( + username="updatetz", + email="updatetz@example.com", + hashed_password="hash", + timezone="UTC", + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Act: Update timezone + user.timezone = "America/Los_Angeles" + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.timezone == "America/Los_Angeles" + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_update_tax_jurisdiction(self, db_session): + """Test updating user's tax_jurisdiction.""" + # Arrange + try: + from tradingagents.api.models import User + + user = User( + username="updatetax", + email="updatetax@example.com", + hashed_password="hash", + tax_jurisdiction="US-CA", + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Act: Update tax jurisdiction + user.tax_jurisdiction = "US-TX" + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.tax_jurisdiction == "US-TX" + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_verify_email(self, db_session): + """Test verifying user's email.""" + # Arrange + try: + from tradingagents.api.models import User + + user = User( + username="toverify", + email="toverify@example.com", + hashed_password="hash", + is_verified=False, + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.is_verified is False + + # Act: Verify email + user.is_verified = True + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.is_verified is True + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_query_by_timezone(self, db_session): + """Test querying users by timezone.""" + # Arrange + try: + from tradingagents.api.models import User + from sqlalchemy import select + + # Create users with different timezones + user1 = User( + username="user_utc", + email="utc@example.com", + hashed_password="hash", + timezone="UTC", + ) + user2 = User( + username="user_ny", + email="ny@example.com", + hashed_password="hash", + timezone="America/New_York", + ) + user3 = User( + username="user_utc2", + email="utc2@example.com", + hashed_password="hash", + timezone="UTC", + ) + + db_session.add_all([user1, user2, user3]) + await db_session.commit() + + # Act: Query users in UTC timezone + result = await db_session.execute( + select(User).where(User.timezone == "UTC") + ) + utc_users = result.scalars().all() + + # Assert + assert len(utc_users) == 2 + assert all(u.timezone == "UTC" for u in utc_users) + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_query_verified_users(self, db_session): + """Test querying only verified users.""" + # Arrange + try: + from tradingagents.api.models import User + from sqlalchemy import select + + # Create verified and unverified users + verified_user = User( + username="verified", + email="verified@example.com", + hashed_password="hash", + is_verified=True, + ) + unverified_user = User( + username="unverified", + email="unverified@example.com", + hashed_password="hash", + is_verified=False, + ) + + db_session.add_all([verified_user, unverified_user]) + await db_session.commit() + + # Act: Query verified users + result = await db_session.execute( + select(User).where(User.is_verified == True) + ) + verified_users = result.scalars().all() + + # Assert + assert len(verified_users) >= 1 + assert all(u.is_verified for u in verified_users) + except ImportError: + pytest.skip("Models not implemented yet") + + async def test_user_api_key_hash_nullable(self, db_session): + """Test that api_key_hash can be set to None.""" + # Arrange + try: + from tradingagents.api.models import User + + user = User( + username="apinull", + email="apinull@example.com", + hashed_password="hash", + api_key_hash="$2b$12$somehash", + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Act: Remove API key + user.api_key_hash = None + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.api_key_hash is None + except ImportError: + pytest.skip("Models not implemented yet") diff --git a/tests/api/test_user_model.py b/tests/api/test_user_model.py new file mode 100644 index 00000000..93eb91a3 --- /dev/null +++ b/tests/api/test_user_model.py @@ -0,0 +1,604 @@ +"""Unit tests for User model with Issue #3 enhancements. + +Tests for User model fields including: +- tax_jurisdiction +- timezone +- api_key_hash +- is_verified + +Follows TDD principles with comprehensive coverage. +""" + +import pytest +from sqlalchemy import select +from tradingagents.api.models.user import User +from tradingagents.api.services.auth_service import hash_password +from tradingagents.api.services.api_key_service import generate_api_key, hash_api_key + + +class TestUserModelBasicFields: + """Tests for basic User model fields.""" + + @pytest.mark.asyncio + async def test_create_user_with_required_fields(self, db_session): + """Should create user with only required fields.""" + user = User( + username="testuser", + email="test@example.com", + hashed_password=hash_password("SecurePassword123!"), + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.id is not None + assert user.username == "testuser" + assert user.email == "test@example.com" + assert user.hashed_password is not None + + @pytest.mark.asyncio + async def test_user_defaults(self, db_session): + """Should apply default values to optional fields.""" + user = User( + username="testuser", + email="test@example.com", + hashed_password=hash_password("SecurePassword123!"), + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Check defaults + assert user.is_active is True + assert user.is_superuser is False + assert user.tax_jurisdiction == "AU" + assert user.timezone == "Australia/Sydney" + assert user.is_verified is False + assert user.api_key_hash is None + assert user.full_name is None + + @pytest.mark.asyncio + async def test_username_unique_constraint(self, db_session): + """Should enforce unique username constraint.""" + user1 = User( + username="testuser", + email="test1@example.com", + hashed_password=hash_password("Password123!"), + ) + db_session.add(user1) + await db_session.commit() + + # Try to create user with same username + user2 = User( + username="testuser", + email="test2@example.com", + hashed_password=hash_password("Password456!"), + ) + db_session.add(user2) + + with pytest.raises(Exception): # IntegrityError + await db_session.commit() + + @pytest.mark.asyncio + async def test_email_unique_constraint(self, db_session): + """Should enforce unique email constraint.""" + user1 = User( + username="user1", + email="test@example.com", + hashed_password=hash_password("Password123!"), + ) + db_session.add(user1) + await db_session.commit() + + # Try to create user with same email + user2 = User( + username="user2", + email="test@example.com", + hashed_password=hash_password("Password456!"), + ) + db_session.add(user2) + + with pytest.raises(Exception): # IntegrityError + await db_session.commit() + + +class TestUserModelTaxJurisdiction: + """Tests for tax_jurisdiction field (Issue #3).""" + + @pytest.mark.asyncio + async def test_set_us_jurisdiction(self, db_session): + """Should set US tax jurisdiction.""" + user = User( + username="ususer", + email="us@example.com", + hashed_password=hash_password("Password123!"), + tax_jurisdiction="US", + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.tax_jurisdiction == "US" + + @pytest.mark.asyncio + async def test_set_us_state_jurisdiction(self, db_session): + """Should set US state-level tax jurisdiction.""" + user = User( + username="nyuser", + email="ny@example.com", + hashed_password=hash_password("Password123!"), + tax_jurisdiction="US-NY", + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.tax_jurisdiction == "US-NY" + + @pytest.mark.asyncio + async def test_set_canadian_province_jurisdiction(self, db_session): + """Should set Canadian province-level tax jurisdiction.""" + user = User( + username="causer", + email="ca@example.com", + hashed_password=hash_password("Password123!"), + tax_jurisdiction="CA-ON", + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.tax_jurisdiction == "CA-ON" + + @pytest.mark.asyncio + async def test_set_australian_state_jurisdiction(self, db_session): + """Should set Australian state-level tax jurisdiction.""" + user = User( + username="auuser", + email="au@example.com", + hashed_password=hash_password("Password123!"), + tax_jurisdiction="AU-NSW", + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.tax_jurisdiction == "AU-NSW" + + @pytest.mark.asyncio + async def test_tax_jurisdiction_default(self, db_session): + """Should default to AU if not specified.""" + user = User( + username="defaultuser", + email="default@example.com", + hashed_password=hash_password("Password123!"), + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.tax_jurisdiction == "AU" + + @pytest.mark.asyncio + async def test_tax_jurisdiction_max_length(self, db_session): + """Tax jurisdiction should not exceed 10 characters.""" + # This should work (10 chars max) + user = User( + username="testuser", + email="test@example.com", + hashed_password=hash_password("Password123!"), + tax_jurisdiction="AU-NSW", # 6 chars + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.tax_jurisdiction == "AU-NSW" + + +class TestUserModelTimezone: + """Tests for timezone field (Issue #3).""" + + @pytest.mark.asyncio + async def test_set_us_timezone(self, db_session): + """Should set US timezone.""" + user = User( + username="nyuser", + email="ny@example.com", + hashed_password=hash_password("Password123!"), + timezone="America/New_York", + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.timezone == "America/New_York" + + @pytest.mark.asyncio + async def test_set_utc_timezone(self, db_session): + """Should set UTC timezone.""" + user = User( + username="utcuser", + email="utc@example.com", + hashed_password=hash_password("Password123!"), + timezone="UTC", + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.timezone == "UTC" + + @pytest.mark.asyncio + async def test_set_european_timezone(self, db_session): + """Should set European timezone.""" + user = User( + username="londonuser", + email="london@example.com", + hashed_password=hash_password("Password123!"), + timezone="Europe/London", + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.timezone == "Europe/London" + + @pytest.mark.asyncio + async def test_set_asian_timezone(self, db_session): + """Should set Asian timezone.""" + user = User( + username="tokyouser", + email="tokyo@example.com", + hashed_password=hash_password("Password123!"), + timezone="Asia/Tokyo", + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.timezone == "Asia/Tokyo" + + @pytest.mark.asyncio + async def test_timezone_default(self, db_session): + """Should default to Australia/Sydney if not specified.""" + user = User( + username="defaultuser", + email="default@example.com", + hashed_password=hash_password("Password123!"), + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.timezone == "Australia/Sydney" + + @pytest.mark.asyncio + async def test_timezone_max_length(self, db_session): + """Timezone should not exceed 50 characters.""" + # Longest IANA timezone is ~40 characters + user = User( + username="testuser", + email="test@example.com", + hashed_password=hash_password("Password123!"), + timezone="America/Argentina/ComodRivadavia", # 35 chars + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.timezone == "America/Argentina/ComodRivadavia" + + +class TestUserModelApiKey: + """Tests for api_key_hash field (Issue #3).""" + + @pytest.mark.asyncio + async def test_user_without_api_key(self, db_session): + """User without API key should have None api_key_hash.""" + user = User( + username="nokey", + email="nokey@example.com", + hashed_password=hash_password("Password123!"), + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.api_key_hash is None + + @pytest.mark.asyncio + async def test_user_with_api_key(self, db_session): + """User with API key should store hashed key.""" + api_key = generate_api_key() + hashed = hash_api_key(api_key) + + user = User( + username="withkey", + email="withkey@example.com", + hashed_password=hash_password("Password123!"), + api_key_hash=hashed, + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.api_key_hash is not None + assert user.api_key_hash == hashed + assert user.api_key_hash != api_key # Should be hash, not plain key + + @pytest.mark.asyncio + async def test_api_key_hash_unique_constraint(self, db_session): + """Should enforce unique api_key_hash constraint.""" + api_key = generate_api_key() + hashed = hash_api_key(api_key) + + user1 = User( + username="user1", + email="user1@example.com", + hashed_password=hash_password("Password123!"), + api_key_hash=hashed, + ) + db_session.add(user1) + await db_session.commit() + + # Try to create user with same api_key_hash + user2 = User( + username="user2", + email="user2@example.com", + hashed_password=hash_password("Password456!"), + api_key_hash=hashed, + ) + db_session.add(user2) + + with pytest.raises(Exception): # IntegrityError + await db_session.commit() + + @pytest.mark.asyncio + async def test_api_key_hash_indexed(self, db_session): + """api_key_hash should be indexed for fast lookups.""" + # Create users with API keys + for i in range(10): + api_key = generate_api_key() + hashed = hash_api_key(api_key) + + user = User( + username=f"user{i}", + email=f"user{i}@example.com", + hashed_password=hash_password("Password123!"), + api_key_hash=hashed, + ) + db_session.add(user) + + await db_session.commit() + + # Lookup by api_key_hash should work + api_key = generate_api_key() + hashed = hash_api_key(api_key) + + user = User( + username="lookup", + email="lookup@example.com", + hashed_password=hash_password("Password123!"), + api_key_hash=hashed, + ) + db_session.add(user) + await db_session.commit() + + # Query by api_key_hash + result = await db_session.execute( + select(User).where(User.api_key_hash == hashed) + ) + found_user = result.scalar_one_or_none() + + assert found_user is not None + assert found_user.username == "lookup" + + @pytest.mark.asyncio + async def test_update_api_key(self, db_session): + """User should be able to regenerate their API key.""" + # Create user with API key + old_api_key = generate_api_key() + old_hash = hash_api_key(old_api_key) + + user = User( + username="regenerate", + email="regenerate@example.com", + hashed_password=hash_password("Password123!"), + api_key_hash=old_hash, + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Regenerate API key + new_api_key = generate_api_key() + new_hash = hash_api_key(new_api_key) + + user.api_key_hash = new_hash + await db_session.commit() + await db_session.refresh(user) + + assert user.api_key_hash == new_hash + assert user.api_key_hash != old_hash + + @pytest.mark.asyncio + async def test_revoke_api_key(self, db_session): + """User should be able to revoke their API key.""" + # Create user with API key + api_key = generate_api_key() + hashed = hash_api_key(api_key) + + user = User( + username="revoke", + email="revoke@example.com", + hashed_password=hash_password("Password123!"), + api_key_hash=hashed, + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Revoke API key (set to None) + user.api_key_hash = None + await db_session.commit() + await db_session.refresh(user) + + assert user.api_key_hash is None + + +class TestUserModelIsVerified: + """Tests for is_verified field (Issue #3).""" + + @pytest.mark.asyncio + async def test_user_unverified_by_default(self, db_session): + """New users should be unverified by default.""" + user = User( + username="unverified", + email="unverified@example.com", + hashed_password=hash_password("Password123!"), + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.is_verified is False + + @pytest.mark.asyncio + async def test_create_verified_user(self, db_session): + """Should be able to create verified user.""" + user = User( + username="verified", + email="verified@example.com", + hashed_password=hash_password("Password123!"), + is_verified=True, + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.is_verified is True + + @pytest.mark.asyncio + async def test_verify_user(self, db_session): + """Should be able to verify user after creation.""" + # Create unverified user + user = User( + username="toverify", + email="toverify@example.com", + hashed_password=hash_password("Password123!"), + is_verified=False, + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.is_verified is False + + # Verify user + user.is_verified = True + await db_session.commit() + await db_session.refresh(user) + + assert user.is_verified is True + + @pytest.mark.asyncio + async def test_query_verified_users(self, db_session): + """Should be able to query only verified users.""" + # Create mix of verified and unverified users + for i in range(5): + user = User( + username=f"user{i}", + email=f"user{i}@example.com", + hashed_password=hash_password("Password123!"), + is_verified=(i % 2 == 0), # Alternate verified/unverified + ) + db_session.add(user) + + await db_session.commit() + + # Query only verified users + result = await db_session.execute( + select(User).where(User.is_verified == True) + ) + verified_users = result.scalars().all() + + assert len(verified_users) == 3 # users 0, 2, 4 + for user in verified_users: + assert user.is_verified is True + + +class TestUserModelComplete: + """Tests for complete user creation with all Issue #3 fields.""" + + @pytest.mark.asyncio + async def test_create_complete_user(self, db_session): + """Should create user with all fields including Issue #3 additions.""" + api_key = generate_api_key() + hashed_api_key = hash_api_key(api_key) + + user = User( + username="complete", + email="complete@example.com", + hashed_password=hash_password("Password123!"), + full_name="Complete User", + is_active=True, + is_superuser=False, + tax_jurisdiction="US-NY", + timezone="America/New_York", + api_key_hash=hashed_api_key, + is_verified=True, + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + assert user.id is not None + assert user.username == "complete" + assert user.email == "complete@example.com" + assert user.full_name == "Complete User" + assert user.is_active is True + assert user.is_superuser is False + assert user.tax_jurisdiction == "US-NY" + assert user.timezone == "America/New_York" + assert user.api_key_hash == hashed_api_key + assert user.is_verified is True + assert user.created_at is not None + assert user.updated_at is not None + + @pytest.mark.asyncio + async def test_user_repr(self, db_session): + """Should have meaningful string representation.""" + user = User( + username="reprtest", + email="repr@example.com", + hashed_password=hash_password("Password123!"), + ) + + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + repr_str = repr(user) + assert "reprtest" in repr_str + assert "repr@example.com" in repr_str + assert str(user.id) in repr_str diff --git a/tests/api/test_validators.py b/tests/api/test_validators.py new file mode 100644 index 00000000..21d30488 --- /dev/null +++ b/tests/api/test_validators.py @@ -0,0 +1,678 @@ +""" +Test suite for validators (Issue #3). + +This module tests timezone and tax jurisdiction validation: +1. Timezone validation (IANA timezone database) +2. Tax jurisdiction validation (format and valid codes) +3. Edge cases and error handling +4. Integration with User model + +Tests follow TDD - written before implementation. +""" + +import pytest +from typing import Optional + +pytestmark = pytest.mark.unit + + +# ============================================================================ +# Unit Tests: Timezone Validation +# ============================================================================ + +class TestTimezoneValidation: + """Test timezone validation functionality.""" + + def test_validate_timezone_valid_utc(self): + """Test validating UTC timezone.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + result = validate_timezone("UTC") + + # Assert + assert result is True + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_timezone_valid_america_new_york(self): + """Test validating America/New_York timezone.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + result = validate_timezone("America/New_York") + + # Assert + assert result is True + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_timezone_valid_europe_london(self): + """Test validating Europe/London timezone.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + result = validate_timezone("Europe/London") + + # Assert + assert result is True + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_timezone_valid_asia_tokyo(self): + """Test validating Asia/Tokyo timezone.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + result = validate_timezone("Asia/Tokyo") + + # Assert + assert result is True + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_timezone_valid_australia_sydney(self): + """Test validating Australia/Sydney timezone.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + result = validate_timezone("Australia/Sydney") + + # Assert + assert result is True + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_timezone_invalid_fake_timezone(self): + """Test rejecting invalid timezone.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + result = validate_timezone("Invalid/Timezone") + + # Assert + assert result is False + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_timezone_invalid_empty_string(self): + """Test rejecting empty string.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + result = validate_timezone("") + + # Assert + assert result is False + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_timezone_none(self): + """Test handling None value.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + # None should be accepted (nullable field) + result = validate_timezone(None) + + # Assert: None is valid (nullable) + assert result is True + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_timezone_case_sensitive(self): + """Test that timezone validation is case-sensitive.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + # Correct case + result_correct = validate_timezone("America/New_York") + + # Wrong case + result_wrong = validate_timezone("america/new_york") + + # Assert + assert result_correct is True + assert result_wrong is False + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_timezone_various_valid_timezones(self): + """Test validating various valid IANA timezones.""" + # Arrange + try: + from tradingagents.api.services.validators import validate_timezone + + valid_timezones = [ + "UTC", + "GMT", + "America/New_York", + "America/Los_Angeles", + "America/Chicago", + "America/Denver", + "Europe/London", + "Europe/Paris", + "Europe/Berlin", + "Asia/Tokyo", + "Asia/Shanghai", + "Asia/Hong_Kong", + "Australia/Sydney", + "Australia/Melbourne", + "Pacific/Auckland", + ] + + # Act & Assert + for tz in valid_timezones: + result = validate_timezone(tz) + assert result is True, f"Timezone {tz} should be valid" + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_timezone_various_invalid_timezones(self): + """Test rejecting various invalid timezones.""" + # Arrange + try: + from tradingagents.api.services.validators import validate_timezone + + invalid_timezones = [ + "PST", # Abbreviations not valid + "EST", + "CST", + "MST", + "America/InvalidCity", + "Europe/FakePlace", + "Random/Stuff", + "123456", + "!@#$%", + "america/new_york", # Wrong case + ] + + # Act & Assert + for tz in invalid_timezones: + result = validate_timezone(tz) + assert result is False, f"Timezone {tz} should be invalid" + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_timezone_with_underscores(self): + """Test timezones with underscores in city names.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + # Valid timezones with underscores + result1 = validate_timezone("America/New_York") + result2 = validate_timezone("America/Los_Angeles") + result3 = validate_timezone("America/Port-au-Prince") + + # Assert + assert result1 is True + assert result2 is True + assert result3 is True + except ImportError: + pytest.skip("Validators not implemented yet") + + +# ============================================================================ +# Unit Tests: Tax Jurisdiction Validation +# ============================================================================ + +class TestTaxJurisdictionValidation: + """Test tax jurisdiction validation functionality.""" + + def test_validate_tax_jurisdiction_valid_us_state(self): + """Test validating US state tax jurisdiction.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + result = validate_tax_jurisdiction("US-CA") + + # Assert + assert result is True + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_tax_jurisdiction_valid_us_states(self): + """Test validating various US state tax jurisdictions.""" + # Arrange + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + valid_jurisdictions = [ + "US-CA", # California + "US-NY", # New York + "US-TX", # Texas + "US-FL", # Florida + "US-IL", # Illinois + "US-PA", # Pennsylvania + "US-OH", # Ohio + "US-WA", # Washington + ] + + # Act & Assert + for jurisdiction in valid_jurisdictions: + result = validate_tax_jurisdiction(jurisdiction) + assert result is True, f"Jurisdiction {jurisdiction} should be valid" + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_tax_jurisdiction_valid_countries(self): + """Test validating country-level tax jurisdictions.""" + # Arrange + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + valid_jurisdictions = [ + "US", # United States + "CA", # Canada + "GB", # United Kingdom + "DE", # Germany + "FR", # France + "JP", # Japan + "AU", # Australia + "NZ", # New Zealand + ] + + # Act & Assert + for jurisdiction in valid_jurisdictions: + result = validate_tax_jurisdiction(jurisdiction) + assert result is True, f"Jurisdiction {jurisdiction} should be valid" + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_tax_jurisdiction_valid_canadian_provinces(self): + """Test validating Canadian province tax jurisdictions.""" + # Arrange + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + valid_jurisdictions = [ + "CA-ON", # Ontario + "CA-QC", # Quebec + "CA-BC", # British Columbia + "CA-AB", # Alberta + ] + + # Act & Assert + for jurisdiction in valid_jurisdictions: + result = validate_tax_jurisdiction(jurisdiction) + assert result is True, f"Jurisdiction {jurisdiction} should be valid" + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_tax_jurisdiction_invalid_format(self): + """Test rejecting invalid format.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + invalid_jurisdictions = [ + "InvalidFormat", + "US_CA", # Wrong separator + "US/CA", # Wrong separator + "USCA", # No separator + "us-ca", # Lowercase + "123", + "!@#", + ] + + # Act & Assert + for jurisdiction in invalid_jurisdictions: + result = validate_tax_jurisdiction(jurisdiction) + assert result is False, f"Jurisdiction {jurisdiction} should be invalid" + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_tax_jurisdiction_invalid_country_code(self): + """Test rejecting invalid country codes.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + result = validate_tax_jurisdiction("XX-YY") + + # Assert + assert result is False + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_tax_jurisdiction_none(self): + """Test handling None value.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + # None should be accepted (nullable field) + result = validate_tax_jurisdiction(None) + + # Assert + assert result is True + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_tax_jurisdiction_empty_string(self): + """Test rejecting empty string.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + result = validate_tax_jurisdiction("") + + # Assert + assert result is False + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_tax_jurisdiction_case_sensitive(self): + """Test that tax jurisdiction validation is case-sensitive.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + # Correct case (uppercase) + result_correct = validate_tax_jurisdiction("US-CA") + + # Wrong case (lowercase) + result_wrong = validate_tax_jurisdiction("us-ca") + + # Assert + assert result_correct is True + assert result_wrong is False + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_validate_tax_jurisdiction_max_length(self): + """Test rejecting very long jurisdiction strings.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + # Very long string + result = validate_tax_jurisdiction("US-" + "A" * 100) + + # Assert: Should reject overly long jurisdictions + assert result is False + except ImportError: + pytest.skip("Validators not implemented yet") + + +# ============================================================================ +# Integration Tests: Validators with User Model +# ============================================================================ + +class TestValidatorsIntegration: + """Test validators integrated with User model.""" + + @pytest.mark.asyncio + async def test_user_with_valid_timezone(self, db_session): + """Test creating user with validated timezone.""" + # Arrange + try: + from tradingagents.api.models import User + from tradingagents.api.services.validators import validate_timezone + + timezone = "America/New_York" + assert validate_timezone(timezone) is True + + user = User( + username="tzvaliduser", + email="tzvalid@example.com", + hashed_password="hash", + timezone=timezone, + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.timezone == timezone + except ImportError: + pytest.skip("Models or validators not implemented yet") + + @pytest.mark.asyncio + async def test_user_with_invalid_timezone_at_api_level(self, db_session): + """Test that invalid timezone should be caught at API level, not DB.""" + # Arrange + try: + from tradingagents.api.models import User + from tradingagents.api.services.validators import validate_timezone + + invalid_timezone = "Invalid/Timezone" + assert validate_timezone(invalid_timezone) is False + + # Note: Database will accept it, validation happens at API layer + user = User( + username="tzinvalid", + email="tzinvalid@example.com", + hashed_password="hash", + timezone=invalid_timezone, + ) + + # Act: DB should accept it (validation is at API level) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert: DB accepts it, but validator rejects it + assert user.timezone == invalid_timezone + assert validate_timezone(user.timezone) is False + except ImportError: + pytest.skip("Models or validators not implemented yet") + + @pytest.mark.asyncio + async def test_user_with_valid_tax_jurisdiction(self, db_session): + """Test creating user with validated tax jurisdiction.""" + # Arrange + try: + from tradingagents.api.models import User + from tradingagents.api.services.validators import validate_tax_jurisdiction + + jurisdiction = "US-CA" + assert validate_tax_jurisdiction(jurisdiction) is True + + user = User( + username="taxvaliduser", + email="taxvalid@example.com", + hashed_password="hash", + tax_jurisdiction=jurisdiction, + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.tax_jurisdiction == jurisdiction + except ImportError: + pytest.skip("Models or validators not implemented yet") + + @pytest.mark.asyncio + async def test_user_with_both_validators(self, db_session): + """Test creating user with both validated fields.""" + # Arrange + try: + from tradingagents.api.models import User + from tradingagents.api.services.validators import ( + validate_timezone, + validate_tax_jurisdiction, + ) + + timezone = "America/Los_Angeles" + jurisdiction = "US-CA" + + assert validate_timezone(timezone) is True + assert validate_tax_jurisdiction(jurisdiction) is True + + user = User( + username="bothvalid", + email="bothvalid@example.com", + hashed_password="hash", + timezone=timezone, + tax_jurisdiction=jurisdiction, + ) + + # Act + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Assert + assert user.timezone == timezone + assert user.tax_jurisdiction == jurisdiction + except ImportError: + pytest.skip("Models or validators not implemented yet") + + +# ============================================================================ +# Edge Cases: Validators +# ============================================================================ + +class TestValidatorEdgeCases: + """Test edge cases in validators.""" + + def test_timezone_with_special_characters(self): + """Test timezone with special characters.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + # Test various special characters + result1 = validate_timezone("America/Port-au-Prince") # Hyphen + result2 = validate_timezone("America/Indiana/Indianapolis") # Multiple slashes + + # Assert + assert result1 is True + assert result2 is True + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_timezone_whitespace_handling(self): + """Test timezone validation with whitespace.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + # Timezones with leading/trailing whitespace + result1 = validate_timezone(" America/New_York ") + result2 = validate_timezone("America/New_York ") + result3 = validate_timezone(" America/New_York") + + # Assert: Should reject (strict validation) + assert result1 is False + assert result2 is False + assert result3 is False + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_tax_jurisdiction_whitespace_handling(self): + """Test tax jurisdiction validation with whitespace.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + # Jurisdictions with leading/trailing whitespace + result1 = validate_tax_jurisdiction(" US-CA ") + result2 = validate_tax_jurisdiction("US-CA ") + result3 = validate_tax_jurisdiction(" US-CA") + + # Assert: Should reject (strict validation) + assert result1 is False + assert result2 is False + assert result3 is False + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_timezone_numeric_string(self): + """Test timezone validation with numeric strings.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + result = validate_timezone("12345") + + # Assert + assert result is False + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_tax_jurisdiction_only_country(self): + """Test tax jurisdiction with only country code.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + # Two-letter country codes should be valid + result_us = validate_tax_jurisdiction("US") + result_ca = validate_tax_jurisdiction("CA") + result_gb = validate_tax_jurisdiction("GB") + + # Assert + assert result_us is True + assert result_ca is True + assert result_gb is True + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_tax_jurisdiction_single_letter(self): + """Test tax jurisdiction with single letter.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + result = validate_tax_jurisdiction("A") + + # Assert + assert result is False + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_timezone_sql_injection_attempt(self): + """Test timezone validation against SQL injection.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_timezone + + malicious_inputs = [ + "'; DROP TABLE users; --", + "1' OR '1'='1", + "America/New_York'; DELETE FROM users; --", + ] + + # Act & Assert + for malicious in malicious_inputs: + result = validate_timezone(malicious) + assert result is False, f"Should reject malicious input: {malicious}" + except ImportError: + pytest.skip("Validators not implemented yet") + + def test_tax_jurisdiction_sql_injection_attempt(self): + """Test tax jurisdiction validation against SQL injection.""" + # Arrange & Act + try: + from tradingagents.api.services.validators import validate_tax_jurisdiction + + malicious_inputs = [ + "'; DROP TABLE users; --", + "1' OR '1'='1", + "US-CA'; DELETE FROM users; --", + ] + + # Act & Assert + for malicious in malicious_inputs: + result = validate_tax_jurisdiction(malicious) + assert result is False, f"Should reject malicious input: {malicious}" + except ImportError: + pytest.skip("Validators not implemented yet") diff --git a/tradingagents/api/models/user.py b/tradingagents/api/models/user.py index c0f897be..be353c2e 100644 --- a/tradingagents/api/models/user.py +++ b/tradingagents/api/models/user.py @@ -8,18 +8,63 @@ from tradingagents.api.models.base import Base, TimestampMixin class User(Base, TimestampMixin): - """User model for authentication and authorization.""" + """User model for authentication and authorization. + + Attributes: + id: Primary key + username: Unique username for authentication + email: Unique email address + hashed_password: Bcrypt hashed password + full_name: Optional full name + is_active: Whether user account is active + is_superuser: Whether user has admin privileges + tax_jurisdiction: Tax jurisdiction code (e.g., "US", "US-CA", "AU") + timezone: IANA timezone identifier (e.g., "America/New_York", "UTC") + api_key_hash: Bcrypt hash of API key (if user has API key) + is_verified: Whether user email is verified + strategies: Related Strategy objects owned by this user + """ __tablename__ = "users" + # Primary identification id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) username: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True) email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True) hashed_password: Mapped[str] = mapped_column(String(255), nullable=False) full_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + + # User status and permissions is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + # Issue #3: Profile fields + tax_jurisdiction: Mapped[str] = mapped_column( + String(10), + default="AU", + nullable=False, + comment="Tax jurisdiction code (e.g., US, US-CA, AU-NSW)" + ) + timezone: Mapped[str] = mapped_column( + String(50), + default="Australia/Sydney", + nullable=False, + comment="IANA timezone identifier (e.g., America/New_York, UTC)" + ) + api_key_hash: Mapped[Optional[str]] = mapped_column( + String(255), + nullable=True, + index=True, + unique=True, + comment="Bcrypt hash of API key for programmatic access" + ) + is_verified: Mapped[bool] = mapped_column( + Boolean, + default=False, + nullable=False, + comment="Whether user email has been verified" + ) + # Relationship to strategies strategies: Mapped[List["Strategy"]] = relationship( "Strategy", diff --git a/tradingagents/api/services/__init__.py b/tradingagents/api/services/__init__.py index 912bc452..c49bb76c 100644 --- a/tradingagents/api/services/__init__.py +++ b/tradingagents/api/services/__init__.py @@ -6,10 +6,31 @@ from tradingagents.api.services.auth_service import ( create_access_token, decode_access_token, ) +from tradingagents.api.services.api_key_service import ( + generate_api_key, + hash_api_key, + verify_api_key, +) +from tradingagents.api.services.validators import ( + validate_timezone, + validate_tax_jurisdiction, + get_available_timezones, + get_available_tax_jurisdictions, +) __all__ = [ + # Auth service "hash_password", "verify_password", "create_access_token", "decode_access_token", + # API key service + "generate_api_key", + "hash_api_key", + "verify_api_key", + # Validators + "validate_timezone", + "validate_tax_jurisdiction", + "get_available_timezones", + "get_available_tax_jurisdictions", ] diff --git a/tradingagents/api/services/api_key_service.py b/tradingagents/api/services/api_key_service.py new file mode 100644 index 00000000..b93e5785 --- /dev/null +++ b/tradingagents/api/services/api_key_service.py @@ -0,0 +1,115 @@ +"""API key service for secure key generation and hashing. + +This module provides utilities for generating and verifying API keys: +- Generate secure random API keys with 'ta_' prefix +- Hash API keys using bcrypt (via pwdlib) +- Verify plain API keys against hashed values + +Security: +- Never store plain API keys in the database +- Use bcrypt for hashing (via pwdlib PasswordHash) +- API keys are URL-safe base64 encoded (32 bytes) +""" + +import secrets +from pwdlib import PasswordHash + + +# API key hashing with bcrypt (same context as passwords for consistency) +api_key_context = PasswordHash.recommended() + + +def generate_api_key() -> str: + """ + Generate a secure random API key. + + Returns a URL-safe API key with the 'ta_' prefix followed by + 32 bytes of random data encoded as base64. + + Format: ta_ + Example: ta_vK9x8pL2mN3qR5sT7uW1yZ4aB6cD8eF0gH2jK4lM6n + + Returns: + str: Generated API key (plaintext) + + Security: + - Uses secrets.token_urlsafe() for cryptographically strong randomness + - 32 bytes = 256 bits of entropy + - Never store the returned value directly in database + + Example: + >>> api_key = generate_api_key() + >>> api_key.startswith("ta_") + True + >>> len(api_key) > 40 # ta_ + base64(32 bytes) + True + """ + # Generate 32 bytes (256 bits) of cryptographically secure random data + # URL-safe base64 encoding makes it safe for URLs and headers + random_part = secrets.token_urlsafe(32) + + return f"ta_{random_part}" + + +def hash_api_key(api_key: str) -> str: + """ + Hash an API key using bcrypt. + + Uses the same pwdlib PasswordHash context as password hashing + for consistency. The hashed value can be safely stored in the database. + + Args: + api_key: Plain text API key (from generate_api_key()) + + Returns: + str: Bcrypt hash of the API key + + Security: + - Uses bcrypt algorithm (via Argon2 default from pwdlib) + - Hash is one-way and computationally expensive to reverse + - Store this hash in database, not the plain API key + + Example: + >>> api_key = generate_api_key() + >>> hashed = hash_api_key(api_key) + >>> hashed != api_key # Hash is different from plain key + True + >>> len(hashed) > 50 # Bcrypt hashes are long + True + """ + return api_key_context.hash(api_key) + + +def verify_api_key(plain_api_key: str, hashed_api_key: str) -> bool: + """ + Verify a plain API key against a hash. + + Checks if the provided plain API key matches the stored hash. + Uses constant-time comparison to prevent timing attacks. + + Args: + plain_api_key: Plain text API key (from user request) + hashed_api_key: Hashed API key (from database) + + Returns: + bool: True if API key matches hash, False otherwise + + Security: + - Uses constant-time comparison + - Safe against timing attacks + - Computationally expensive to slow down brute force + + Example: + >>> api_key = generate_api_key() + >>> hashed = hash_api_key(api_key) + >>> verify_api_key(api_key, hashed) + True + >>> verify_api_key("wrong_key", hashed) + False + """ + try: + return api_key_context.verify(plain_api_key, hashed_api_key) + except Exception: + # If verification fails for any reason (malformed hash, etc.) + # return False rather than raising an exception + return False diff --git a/tradingagents/api/services/validators.py b/tradingagents/api/services/validators.py new file mode 100644 index 00000000..5f5f189b --- /dev/null +++ b/tradingagents/api/services/validators.py @@ -0,0 +1,303 @@ +"""Validators for user profile fields. + +This module provides validation functions for: +- Timezones (IANA timezone database) +- Tax jurisdictions (country codes and state/province codes) + +All validators return True/False and are designed to be used +in Pydantic models and database constraints. +""" + +from typing import Set +from zoneinfo import ZoneInfo, available_timezones + + +# Valid tax jurisdictions (ISO 3166-1 alpha-2 country codes + state/province) +# Format: "CC" for country-level, "CC-SS" for state/province-level +# This is a comprehensive list covering major jurisdictions +VALID_TAX_JURISDICTIONS: Set[str] = { + # Country-level codes (ISO 3166-1 alpha-2) + "US", # United States + "CA", # Canada + "GB", # United Kingdom + "AU", # Australia + "DE", # Germany + "FR", # France + "IT", # Italy + "ES", # Spain + "NL", # Netherlands + "BE", # Belgium + "CH", # Switzerland + "AT", # Austria + "SE", # Sweden + "NO", # Norway + "DK", # Denmark + "FI", # Finland + "IE", # Ireland + "PT", # Portugal + "GR", # Greece + "PL", # Poland + "CZ", # Czech Republic + "HU", # Hungary + "RO", # Romania + "JP", # Japan + "CN", # China + "KR", # South Korea + "IN", # India + "SG", # Singapore + "HK", # Hong Kong + "NZ", # New Zealand + "MX", # Mexico + "BR", # Brazil + "AR", # Argentina + "CL", # Chile + "ZA", # South Africa + "AE", # United Arab Emirates + "SA", # Saudi Arabia + "IL", # Israel + "TR", # Turkey + "RU", # Russia + "UA", # Ukraine + "TH", # Thailand + "MY", # Malaysia + "ID", # Indonesia + "PH", # Philippines + "VN", # Vietnam + "TW", # Taiwan + + # United States - State level + "US-AL", # Alabama + "US-AK", # Alaska + "US-AZ", # Arizona + "US-AR", # Arkansas + "US-CA", # California + "US-CO", # Colorado + "US-CT", # Connecticut + "US-DE", # Delaware + "US-FL", # Florida + "US-GA", # Georgia + "US-HI", # Hawaii + "US-ID", # Idaho + "US-IL", # Illinois + "US-IN", # Indiana + "US-IA", # Iowa + "US-KS", # Kansas + "US-KY", # Kentucky + "US-LA", # Louisiana + "US-ME", # Maine + "US-MD", # Maryland + "US-MA", # Massachusetts + "US-MI", # Michigan + "US-MN", # Minnesota + "US-MS", # Mississippi + "US-MO", # Missouri + "US-MT", # Montana + "US-NE", # Nebraska + "US-NV", # Nevada + "US-NH", # New Hampshire + "US-NJ", # New Jersey + "US-NM", # New Mexico + "US-NY", # New York + "US-NC", # North Carolina + "US-ND", # North Dakota + "US-OH", # Ohio + "US-OK", # Oklahoma + "US-OR", # Oregon + "US-PA", # Pennsylvania + "US-RI", # Rhode Island + "US-SC", # South Carolina + "US-SD", # South Dakota + "US-TN", # Tennessee + "US-TX", # Texas + "US-UT", # Utah + "US-VT", # Vermont + "US-VA", # Virginia + "US-WA", # Washington + "US-WV", # West Virginia + "US-WI", # Wisconsin + "US-WY", # Wyoming + "US-DC", # District of Columbia + + # Canada - Province/Territory level + "CA-AB", # Alberta + "CA-BC", # British Columbia + "CA-MB", # Manitoba + "CA-NB", # New Brunswick + "CA-NL", # Newfoundland and Labrador + "CA-NS", # Nova Scotia + "CA-NT", # Northwest Territories + "CA-NU", # Nunavut + "CA-ON", # Ontario + "CA-PE", # Prince Edward Island + "CA-QC", # Quebec + "CA-SK", # Saskatchewan + "CA-YT", # Yukon + + # Australia - State/Territory level + "AU-NSW", # New South Wales + "AU-VIC", # Victoria + "AU-QLD", # Queensland + "AU-SA", # South Australia + "AU-WA", # Western Australia + "AU-TAS", # Tasmania + "AU-NT", # Northern Territory + "AU-ACT", # Australian Capital Territory +} + + +def validate_timezone(timezone: str) -> bool: + """ + Validate timezone against IANA timezone database. + + Checks if the provided timezone string is a valid IANA timezone + identifier. Uses Python's zoneinfo module which is based on the + IANA timezone database (tzdata). + + Args: + timezone: Timezone identifier (e.g., "America/New_York", "UTC") + + Returns: + bool: True if valid IANA timezone, False otherwise + + Valid Examples: + - "UTC" + - "GMT" + - "America/New_York" + - "Europe/London" + - "Asia/Tokyo" + - "Australia/Sydney" + + Invalid Examples: + - "PST" (abbreviation, not IANA identifier) + - "EST" (abbreviation) + - "New York" (wrong format) + - "america/new_york" (wrong case) + + Example: + >>> validate_timezone("America/New_York") + True + >>> validate_timezone("UTC") + True + >>> validate_timezone("PST") + False + >>> validate_timezone("Invalid/Zone") + False + + Note: + - Case-sensitive (must match IANA database exactly) + - Use available_timezones() to get full list of valid zones + - Rejects timezone abbreviations (PST, EST, etc.) + """ + if not timezone or not isinstance(timezone, str): + return False + + # Check if timezone exists in IANA database + # This is more efficient than trying to create a ZoneInfo object + return timezone in available_timezones() + + +def validate_tax_jurisdiction(jurisdiction: str) -> bool: + """ + Validate tax jurisdiction code. + + Checks if the provided jurisdiction is in the list of valid + tax jurisdictions. Supports both country-level and state/province-level + jurisdictions. + + Format: + - Country level: "CC" (2-letter ISO 3166-1 alpha-2) + - State/Province level: "CC-SS" (country-state with hyphen) + + Args: + jurisdiction: Tax jurisdiction code + + Returns: + bool: True if valid jurisdiction, False otherwise + + Valid Examples: + - "US" (United States) + - "CA" (Canada) + - "GB" (United Kingdom) + - "US-CA" (California, USA) + - "US-NY" (New York, USA) + - "CA-ON" (Ontario, Canada) + - "AU-NSW" (New South Wales, Australia) + + Invalid Examples: + - "us" (lowercase) + - "USA" (3 letters) + - "US_CA" (underscore instead of hyphen) + - "US/CA" (slash instead of hyphen) + - "XX" (non-existent country) + + Example: + >>> validate_tax_jurisdiction("US") + True + >>> validate_tax_jurisdiction("US-CA") + True + >>> validate_tax_jurisdiction("us") + False + >>> validate_tax_jurisdiction("XX-YY") + False + + Note: + - Case-sensitive (must be uppercase) + - Hyphen separator for state/province codes + - List is comprehensive but not exhaustive + - Add new jurisdictions to VALID_TAX_JURISDICTIONS set as needed + """ + if not jurisdiction or not isinstance(jurisdiction, str): + return False + + return jurisdiction in VALID_TAX_JURISDICTIONS + + +def get_available_timezones() -> Set[str]: + """ + Get set of all available IANA timezones. + + Returns the complete set of valid timezone identifiers from + the IANA timezone database. + + Returns: + Set[str]: Set of valid timezone identifiers + + Example: + >>> timezones = get_available_timezones() + >>> "America/New_York" in timezones + True + >>> len(timezones) > 500 # Hundreds of valid timezones + True + + Note: + - This is a cached call (zoneinfo caches available_timezones) + - Use for populating dropdowns or validation lists + - Contains all IANA timezone database entries + """ + return available_timezones() + + +def get_available_tax_jurisdictions() -> Set[str]: + """ + Get set of all available tax jurisdictions. + + Returns the complete set of valid tax jurisdiction codes. + + Returns: + Set[str]: Set of valid tax jurisdiction codes + + Example: + >>> jurisdictions = get_available_tax_jurisdictions() + >>> "US" in jurisdictions + True + >>> "US-CA" in jurisdictions + True + >>> len(jurisdictions) > 50 # Many jurisdictions supported + True + + Note: + - Returns a copy to prevent external modification + - Use for populating dropdowns or validation lists + - Includes both country and state/province level codes + """ + return VALID_TAX_JURISDICTIONS.copy()