added dao layer
This commit is contained in:
parent
5d258bf67f
commit
b9838282b3
|
|
@ -0,0 +1 @@
|
||||||
|
"""Data Access Object layer for TradingAgents API."""
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
"""Database configuration and setup."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
# Build DATABASE_URL from environment variables
|
||||||
|
POSTGRES_USER = os.getenv("POSTGRES_USER", "tradingagents")
|
||||||
|
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
|
||||||
|
POSTGRES_DB = os.getenv("POSTGRES_DB", "tradingagents")
|
||||||
|
POSTGRES_HOST = os.getenv("POSTGRES_HOST", "localhost")
|
||||||
|
POSTGRES_PORT = os.getenv("POSTGRES_PORT", "5432")
|
||||||
|
|
||||||
|
if not POSTGRES_PASSWORD:
|
||||||
|
raise ValueError("POSTGRES_PASSWORD environment variable is required")
|
||||||
|
|
||||||
|
# Use asyncpg for async PostgreSQL
|
||||||
|
ASYNC_DATABASE_URL = f"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}"
|
||||||
|
|
||||||
|
async_engine = create_async_engine(ASYNC_DATABASE_URL)
|
||||||
|
AsyncSessionLocal = sessionmaker(
|
||||||
|
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db():
|
||||||
|
"""Get async database session."""
|
||||||
|
async with AsyncSessionLocal() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
"""SQLAlchemy models for database tables."""
|
||||||
|
|
||||||
|
from sqlalchemy import Column, Integer, String, Boolean
|
||||||
|
from .database import Base
|
||||||
|
|
||||||
|
|
||||||
|
class UserModel(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
username = Column(String, unique=True, index=True, nullable=False)
|
||||||
|
email = Column(String, unique=True, index=True)
|
||||||
|
full_name = Column(String)
|
||||||
|
hashed_password = Column(String, nullable=False)
|
||||||
|
disabled = Column(Boolean, default=False)
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
"""User Data Access Object for database operations."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from .models import UserModel
|
||||||
|
|
||||||
|
|
||||||
|
class UserDAO:
|
||||||
|
"""Data Access Object for User operations."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_user(
|
||||||
|
db: AsyncSession,
|
||||||
|
username: str,
|
||||||
|
email: Optional[str],
|
||||||
|
full_name: Optional[str],
|
||||||
|
hashed_password: str,
|
||||||
|
disabled: bool = False
|
||||||
|
) -> UserModel:
|
||||||
|
"""Create a new user in the database."""
|
||||||
|
user = UserModel(
|
||||||
|
username=username,
|
||||||
|
email=email,
|
||||||
|
full_name=full_name,
|
||||||
|
hashed_password=hashed_password,
|
||||||
|
disabled=disabled
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_by_username(db: AsyncSession, username: str) -> Optional[UserModel]:
|
||||||
|
"""Get user by username."""
|
||||||
|
result = await db.execute(select(UserModel).where(UserModel.username == username))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
|
||||||
|
"""Get user by email."""
|
||||||
|
result = await db.execute(select(UserModel).where(UserModel.email == email))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_by_id(db: AsyncSession, user_id: int) -> Optional[UserModel]:
|
||||||
|
"""Get user by ID."""
|
||||||
|
result = await db.execute(select(UserModel).where(UserModel.id == user_id))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def update_user(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: int,
|
||||||
|
**kwargs
|
||||||
|
) -> Optional[UserModel]:
|
||||||
|
"""Update user fields."""
|
||||||
|
user = await UserDAO.get_user_by_id(db, user_id)
|
||||||
|
if user:
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(user, key):
|
||||||
|
setattr(user, key, value)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def delete_user(db: AsyncSession, user_id: int) -> bool:
|
||||||
|
"""Delete user by ID."""
|
||||||
|
user = await UserDAO.get_user_by_id(db, user_id)
|
||||||
|
if user:
|
||||||
|
await db.delete(user)
|
||||||
|
await db.commit()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
@ -0,0 +1,78 @@
|
||||||
|
"""User signup endpoint."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
from pydantic import BaseModel, EmailStr
|
||||||
|
|
||||||
|
from ..dao.database import get_db
|
||||||
|
from ..dao.user_dao import UserDAO
|
||||||
|
from ..users import User
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/users", tags=["users"])
|
||||||
|
|
||||||
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
|
class UserCreate(BaseModel):
|
||||||
|
username: str
|
||||||
|
email: Optional[EmailStr] = None
|
||||||
|
full_name: Optional[str] = None
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
id: int
|
||||||
|
username: str
|
||||||
|
email: Optional[str] = None
|
||||||
|
full_name: Optional[str] = None
|
||||||
|
disabled: bool
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
"""Hash a password using bcrypt."""
|
||||||
|
return pwd_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/signup", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def signup(
|
||||||
|
user_data: UserCreate,
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""Create a new user account."""
|
||||||
|
# Check if username already exists
|
||||||
|
existing_user = await UserDAO.get_user_by_username(db, user_data.username)
|
||||||
|
if existing_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Username already registered"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if email already exists (if provided)
|
||||||
|
if user_data.email:
|
||||||
|
existing_email = await UserDAO.get_user_by_email(db, user_data.email)
|
||||||
|
if existing_email:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Email already registered"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hash the password
|
||||||
|
hashed_password = hash_password(user_data.password)
|
||||||
|
|
||||||
|
# Create the user
|
||||||
|
user = await UserDAO.create_user(
|
||||||
|
db=db,
|
||||||
|
username=user_data.username,
|
||||||
|
email=user_data.email,
|
||||||
|
full_name=user_data.full_name,
|
||||||
|
hashed_password=hashed_password
|
||||||
|
)
|
||||||
|
|
||||||
|
return UserResponse.from_orm(user)
|
||||||
Loading…
Reference in New Issue