36 lines
1.2 KiB
Python
36 lines
1.2 KiB
Python
"""Database configuration and setup."""
|
|
|
|
import os
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
|
|
|
|
# PostgreSQL connection settings, read from environment variables.
# Every setting except the password falls back to a default value;
# the password has no default and must be supplied.
POSTGRES_USER = os.environ.get("POSTGRES_USER", "tradingagents")
POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD")
POSTGRES_DB = os.environ.get("POSTGRES_DB", "tradingagents")
POSTGRES_HOST = os.environ.get("POSTGRES_HOST", "localhost")
POSTGRES_PORT = os.environ.get("POSTGRES_PORT", "5432")

# Fail fast at import time; an unset and an empty password are both rejected.
if POSTGRES_PASSWORD in (None, ""):
    raise ValueError("POSTGRES_PASSWORD environment variable is required")
|
|
|
|
# Use asyncpg for async PostgreSQL.
# The user name and password are percent-encoded (quote with safe="") so that
# characters such as "@", ":", "/", "#" or "%" in a credential cannot break
# the URL structure; SQLAlchemy unquotes them again when it parses the URL.
ASYNC_DATABASE_URL = (
    "postgresql+asyncpg://"
    f"{quote(POSTGRES_USER, safe='')}:{quote(POSTGRES_PASSWORD, safe='')}"
    f"@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}"
)

# Single module-level async engine built from the URL above.
async_engine = create_async_engine(ASYNC_DATABASE_URL)

# Factory producing AsyncSession objects bound to async_engine.
# expire_on_commit=False keeps loaded attribute values usable after commit()
# instead of expiring them (which would require a new load on next access).
AsyncSessionLocal = sessionmaker(
    async_engine, class_=AsyncSession, expire_on_commit=False
)

# Declarative base class for ORM model definitions.
Base = declarative_base()
|
|
|
|
|
|
async def get_db():
    """Get async database session.

    Async generator that opens a session from ``AsyncSessionLocal``, yields
    it to the caller, and closes it when the generator is finalized.
    NOTE(review): presumably consumed as a framework dependency (e.g. a
    FastAPI ``Depends``) — confirm against callers.

    Yields:
        AsyncSession: a session created by ``AsyncSessionLocal``.
    """
    # ``async with`` closes the session on exit -- normal completion, an
    # exception thrown into the generator, or ``aclose()`` -- so an explicit
    # ``await session.close()`` in a ``finally`` would only close it twice.
    async with AsyncSessionLocal() as session:
        yield session