This commit is contained in:
GitWit 2025-06-22 12:27:13 +08:00 committed by GitHub
commit eb7545c127
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 783 additions and 0 deletions

17
.env.example Normal file
View File

@ -0,0 +1,17 @@
# Rename this file to .env and fill in your actual values.
# For app.core.config.py
# Generate a strong secret key, e.g., using: openssl rand -hex 32
SECRET_KEY=your_very_secret_jwt_key_here
ALGORITHM=HS256
ACCESS_TOKEN_EXPIRE_MINUTES=30
# For Database (app.core.config.py & alembic)
# Example for PostgreSQL:
DATABASE_URL=postgresql://user:password@localhost:5432/tradingdb_dev
# To run with uvicorn (example):
# uvicorn app.main:app --reload
# Ensure the DATABASE_URL points to a running PostgreSQL instance.
# Apply migrations before running the app for the first time:
# alembic upgrade head

113
README.md
View File

@ -210,4 +210,117 @@ Please reference our work if you find *TradingAgents* provides you with some hel
primaryClass={q-fin.TR},
url={https://arxiv.org/abs/2412.20138},
}
---
## FastAPI Backend API (Work in Progress)
This section describes the setup for the FastAPI backend API, which is currently under development.
### Setup and Running
1. **Create and activate a Python virtual environment.**
```bash
python -m venv venv
source venv/bin/activate
# On Windows: venv\Scripts\activate
```
2. **Install dependencies:**
The main project `requirements.txt` includes dependencies for the `TradingAgents` framework. For the API, additional dependencies are also listed there.
```bash
pip install -r requirements.txt
```
3. **Set up environment variables:**
Copy the `.env.example` file to `.env` and update the values, especially `DATABASE_URL` and `SECRET_KEY`.
```bash
cp .env.example .env
# Open .env and edit the variables
```
A strong `SECRET_KEY` can be generated using:
```bash
openssl rand -hex 32
```
4. **Database Setup (PostgreSQL):**
- Ensure you have a PostgreSQL server running and accessible.
- Update `DATABASE_URL` in your `.env` file to point to your PostgreSQL instance (e.g., `postgresql://user:password@localhost:5432/yourdbname`).
- Apply database migrations using Alembic:
```bash
alembic upgrade head
```
This will create the necessary tables (e.g., `users`).
5. **Create an initial user (Optional - for testing login):**
You'll need a user in the database to test the login. You can create one using a Python script that utilizes `app.services.user_service.create_user` or by manually inserting into the database.
Example (run from project root, e.g., `python -m scripts.create_initial_user` after creating such a script):
```python
# Example: scripts/create_initial_user.py
# from app.db.session import SessionLocal
# from app.services.user_service import create_user
# from app.schemas.user import UserCreate
#
# db = SessionLocal()
#
# user_in = UserCreate(email="test@example.com", password="password123")
# try:
# db_user = create_user(db, user_in=user_in)
# print(f"User {db_user.email} created successfully.")
# except Exception as e:
# print(f"Error creating user: {e}")
# finally:
# db.close()
```
6. **Run the FastAPI application:**
```bash
uvicorn app.main:app --reload
```
The API will typically be available at `http://127.0.0.1:8000`.
The OpenAPI documentation (Swagger UI) will be at `http://127.0.0.1:8000/api/v1/openapi.json` (raw JSON) or `http://127.0.0.1:8000/docs` (Swagger UI) if you enable it in `app/main.py`. For now, the openapi_url is set directly. Access `/docs` or `/redoc` for interactive API documentation.
### API Endpoints
- **`POST /api/v1/auth/login`**: Authenticate a user and receive a JWT token.
- Request body: `application/x-www-form-urlencoded` with `username` (email) and `password`.
- **`GET /api/v1/strategies`**: Retrieve a list of (dummy) strategies for the authenticated user. Requires Bearer token authentication.
### Project Structure (API part)
```
app/
├── api/
│ └── v1/
│ ├── api.py # Main v1 router
│ └── endpoints/
│ ├── auth.py # Auth endpoints (login)
│ └── strategies.py # Strategies endpoints
├── core/
│ ├── config.py # Configuration settings (dotenv loading)
│ └── security.py # Password hashing, JWT utils
├── db/
│ ├── base.py # SQLAlchemy Base, imports all models
│ └── session.py # SQLAlchemy engine and session management
├── models/
│ └── user.py # User SQLAlchemy model
├── schemas/
│ ├── strategy.py # Pydantic schemas for Strategy
│ ├── token.py # Pydantic schemas for Token
│ └── user.py # Pydantic schemas for User
├── services/
│ └── user_service.py # User related database operations
├── __init__.py
└── main.py # Main FastAPI application
alembic/ # Alembic migration scripts
├── versions/
└── env.py # Alembic environment config
alembic.ini # Alembic configuration
.env.example # Example environment variables
requirements.txt # Project dependencies (includes API deps)
README.md # This file
... (other TradingAgents files)
```
```

142
alembic.ini Normal file
View File

@ -0,0 +1,142 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
# The actual URL will be read from environment variable or config.py in env.py
sqlalchemy.url = postgresql://user:password@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

1
alembic/README Normal file
View File

@ -0,0 +1 @@
Generic single-database configuration.

94
alembic/env.py Normal file
View File

@ -0,0 +1,94 @@
import os
import sys
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlalchemy.ext.declarative import declarative_base
from alembic import context
# Import your app's models' Base and settings
# Ensure your app is in the Python path
sys.path.append(os.path.join(os.path.dirname(__file__), '..')) # Add project root to sys.path
from app.db.base import Base # Your models' base
from app.core.config import settings # Your application settings
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = Base.metadata
# target_metadata = None # Original line
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def get_url():
return settings.DATABASE_URL
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = get_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
# This line uses the database URL from your application's settings
configuration = config.get_section(config.config_ini_section)
configuration["sqlalchemy.url"] = get_url()
connectable = engine_from_config(
configuration, # Use updated configuration
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

28
alembic/script.py.mako Normal file
View File

@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@ -0,0 +1,41 @@
"""create_users_table
Revision ID: fdbf94c56ad3
Revises:
Create Date: 2025-06-22 03:26:11.493175
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'fdbf94c56ad3'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('users',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('email', sa.String(), nullable=False),
sa.Column('hashed_password', sa.String(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_users_id'), table_name='users')
op.drop_index(op.f('ix_users_email'), table_name='users')
op.drop_table('users')
# ### end Alembic commands ###

1
app/__init__.py Normal file
View File

@ -0,0 +1 @@
# This file makes 'app' a Python package.

1
app/api/__init__.py Normal file
View File

@ -0,0 +1 @@
# This file makes 'api' a Python package.

8
app/api/v1/api.py Normal file
View File

@ -0,0 +1,8 @@
from fastapi import APIRouter
from app.api.v1.endpoints import auth, strategies
api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
api_router.include_router(strategies.router, prefix="/strategies", tags=["strategies"])

View File

@ -0,0 +1,48 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from app.core.security import create_access_token, verify_password
from app.schemas.token import Token
from app.db.session import get_db
# We'll create user_service.py next for get_user_by_email
from app.services.user_service import get_user_by_email
from app.models.user import User as UserModel # To avoid confusion with pydantic User schema
router = APIRouter()
@router.post("/login", response_model=Token)
async def login_for_access_token(
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db)
):
user = get_user_by_email(db, email=form_data.username) # OAuth2 form uses 'username' for email
if not user or not verify_password(form_data.password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token = create_access_token(
data={"sub": user.email} # 'sub' is the subject of the token (user's email)
)
return {"access_token": access_token, "token_type": "bearer"}
# Placeholder for user creation - to be developed further if needed for testing
# This is NOT a production-ready registration endpoint.
# from app.schemas.user import UserCreate
# from app.core.security import get_password_hash
# @router.post("/register", response_model=UserSchema) # Assuming UserSchema is your Pydantic model for User response
# def register_user(user_in: UserCreate, db: Session = Depends(get_db)):
# db_user = get_user_by_email(db, email=user_in.email)
# if db_user:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Email already registered",
# )
# hashed_password = get_password_hash(user_in.password)
# db_user = UserModel(email=user_in.email, hashed_password=hashed_password)
# db.add(db_user)
# db.commit()
# db.refresh(db_user)
# return db_user

View File

@ -0,0 +1,61 @@
from typing import List, Any
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from fastapi.security import OAuth2PasswordBearer
from app.schemas.strategy import Strategy
from app.schemas.token import TokenPayload
from app.db.session import get_db
from app.services.user_service import get_user_by_email
from app.models.user import User as UserModel
from app.core.security import decode_access_token
from app.core.config import settings
router = APIRouter()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")
async def get_current_user(
db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)
) -> UserModel:
payload = decode_access_token(token)
if payload is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials (payload missing)",
headers={"WWW-Authenticate": "Bearer"},
)
email: str | None = payload.get("sub")
if email is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials (email missing in token)",
headers={"WWW-Authenticate": "Bearer"},
)
user = get_user_by_email(db, email=email)
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
return user
@router.get("/", response_model=List[Strategy])
async def read_strategies(
db: Session = Depends(get_db),
current_user: UserModel = Depends(get_current_user)
):
"""
Retrieve strategies for the current user.
This is a dummy endpoint and returns a hardcoded list.
"""
# In a real application, you would fetch strategies for current_user.id from the database
dummy_strategies = [
Strategy(id=1, name="DQN_Long_BTC", algorithm="DQN", user_id=current_user.id, status="running", params={"symbol": "BTC/USDT"}),
Strategy(id=2, name="PPO_Short_ETH", algorithm="PPO", user_id=current_user.id, status="stopped", params={"symbol": "ETH/USDT"}),
]
return dummy_strategies

1
app/core/__init__.py Normal file
View File

@ -0,0 +1 @@
# This file makes 'core' a Python package.

23
app/core/config.py Normal file
View File

@ -0,0 +1,23 @@
import os
from dotenv import load_dotenv
from pydantic import BaseSettings
load_dotenv()
class Settings(BaseSettings):
PROJECT_NAME: str = "Trading Platform API"
API_V1_STR: str = "/api/v1"
# JWT settings
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key_that_should_be_in_env_file")
ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 30))
# Database settings
DATABASE_URL: str = os.getenv("DATABASE_URL", "postgresql://user:password@localhost:5432/tradingdb")
class Config:
case_sensitive = True
# env_file = ".env" # Handled by load_dotenv() for more flexibility
settings = Settings()

36
app/core/security.py Normal file
View File

@ -0,0 +1,36 @@
from datetime import datetime, timedelta, timezone
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from app.core.config import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
ALGORITHM = settings.ALGORITHM
SECRET_KEY = settings.SECRET_KEY
ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
return pwd_context.hash(password)
def decode_access_token(token: str) -> Optional[dict]:
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return payload
except JWTError:
return None

1
app/db/__init__.py Normal file
View File

@ -0,0 +1 @@
# This file makes 'db' a Python package.

16
app/db/base.py Normal file
View File

@ -0,0 +1,16 @@
# Import all the models, so that Base has them before being
# imported by Alembic
from app.models.user import User # Adjust if your User model is elsewhere
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
# You might need to adjust the import path for User above
# depending on your final User model location relative to this file.
# For now, assuming app.models.user.User is correct.
# If you have multiple model files, you would import them all here, e.g.:
# from app.models.item import Item
# from app.models.order import Order
# This Base will be used by Alembic in env.py to know about your models.

14
app/db/session.py Normal file
View File

@ -0,0 +1,14 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.core.config import settings
engine = create_engine(settings.DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Dependency to get DB session
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

14
app/main.py Normal file
View File

@ -0,0 +1,14 @@
from fastapi import FastAPI
from app.core.config import settings
from app.api.v1.api import api_router as api_v1_router
app = FastAPI(
title=settings.PROJECT_NAME,
openapi_url=f"{settings.API_V1_STR}/openapi.json"
)
@app.get("/")
async def root():
return {"message": f"Welcome to {settings.PROJECT_NAME}"}
app.include_router(api_v1_router, prefix=settings.API_V1_STR)

1
app/models/__init__.py Normal file
View File

@ -0,0 +1 @@
# This file makes 'models' a Python package.

13
app/models/user.py Normal file
View File

@ -0,0 +1,13 @@
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base
Base = declarative_base()
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
email = Column(String, unique=True, index=True, nullable=False)
hashed_password = Column(String, nullable=False)
# Add other fields like is_active, is_superuser, role etc. later if needed
# role = Column(String, default="user")

1
app/schemas/__init__.py Normal file
View File

@ -0,0 +1 @@
# This file makes 'schemas' a Python package.

26
app/schemas/strategy.py Normal file
View File

@ -0,0 +1,26 @@
from pydantic import BaseModel
from typing import Optional, Dict, Any
class StrategyBase(BaseModel):
name: str
algorithm: Optional[str] = None
params: Optional[Dict[str, Any]] = None
class StrategyCreate(StrategyBase):
pass
class StrategyUpdate(StrategyBase):
pass
class StrategyInDBBase(StrategyBase):
id: int
user_id: int # Assuming strategies are linked to users
class Config:
orm_mode = True # For Pydantic v1 compatibility
class Strategy(StrategyInDBBase):
status: Optional[str] = "stopped" # Example field
class StrategyInDB(StrategyInDBBase):
pass

13
app/schemas/token.py Normal file
View File

@ -0,0 +1,13 @@
from typing import Optional
from pydantic import BaseModel
class Token(BaseModel):
access_token: str
token_type: str
class TokenPayload(BaseModel):
sub: Optional[str] = None # 'sub' is standard for subject (user identifier)
# Add any other data you want to store in the token payload
# For example, roles, permissions, etc.
# email: Optional[str] = None
# exp: Optional[int] = None # Already handled by create_access_token's expiry logic

29
app/schemas/user.py Normal file
View File

@ -0,0 +1,29 @@
from pydantic import BaseModel, EmailStr
# Shared properties
class UserBase(BaseModel):
email: EmailStr
# Properties to receive via API on creation
class UserCreate(UserBase):
password: str
# Properties to receive via API on update
class UserUpdate(UserBase):
password: Optional[str] = None
# Properties stored in DB
class UserInDBBase(UserBase):
id: int
hashed_password: str
class Config:
orm_mode = True # Changed from from_attributes = True for Pydantic v1 compatibility
# Additional properties to return via API
class User(UserInDBBase):
pass
# Additional properties stored in DB
class UserInDB(UserInDBBase):
pass

1
app/services/__init__.py Normal file
View File

@ -0,0 +1 @@
# This file makes 'services' a Python package.

View File

@ -0,0 +1,30 @@
from sqlalchemy.orm import Session
from app.models.user import User
from app.schemas.user import UserCreate # Will be used for user creation later
from app.core.security import get_password_hash # For user creation
def get_user_by_email(db: Session, *, email: str) -> User | None:
return db.query(User).filter(User.email == email).first()
def create_user(db: Session, *, user_in: UserCreate) -> User:
hashed_password = get_password_hash(user_in.password)
db_user = User(
email=user_in.email,
hashed_password=hashed_password,
)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
# Example of how to create a user (e.g., for a CLI command or initial setup script)
# This is not an API endpoint.
# def init_db_add_user(db: Session, email: str, password: str) -> User:
# user = get_user_by_email(db, email=email)
# if not user:
# user_in = UserCreate(email=email, password=password)
# user = create_user(db, user_in=user_in)
# print(f"User {email} created.")
# return user
# print(f"User {email} already exists.")
# return user

View File

@ -22,3 +22,12 @@ redis
chainlit
rich
questionary
python-jose[cryptography]
passlib[bcrypt]
sqlalchemy
psycopg2-binary
alembic
fastapi
uvicorn[standard]
pydantic[email]
python-dotenv