TradingAgents/tradingagents/portfolio/supabase_client.py

392 lines
15 KiB
Python

"""PostgreSQL database client for the Portfolio Manager.
Uses ``psycopg2`` with the ``SUPABASE_CONNECTION_STRING`` env var to talk
directly to the Supabase-hosted PostgreSQL database. No ORM — see
``docs/agent/decisions/012-portfolio-no-orm.md`` for rationale.

Usage::

    from tradingagents.portfolio.supabase_client import SupabaseClient

    client = SupabaseClient.get_instance()
    portfolio = client.get_portfolio("some-uuid")
"""
from __future__ import annotations
import json
import uuid
import psycopg2
import psycopg2.extras
from tradingagents.portfolio.config import get_portfolio_config
from tradingagents.portfolio.exceptions import (
DuplicatePortfolioError,
HoldingNotFoundError,
PortfolioError,
PortfolioNotFoundError,
)
from tradingagents.portfolio.models import (
Holding,
Portfolio,
PortfolioSnapshot,
Trade,
)
class SupabaseClient:
    """Singleton PostgreSQL CRUD client for portfolio data.

    Every public method returns typed model instances (``Portfolio``,
    ``Holding``, ``Trade``, ``PortfolioSnapshot``). Fetching, updating or
    deleting a portfolio that does not exist raises
    ``PortfolioNotFoundError``; deleting a missing holding raises
    ``HoldingNotFoundError``; inserting a portfolio whose ID already exists
    raises ``DuplicatePortfolioError``. Other database errors propagate as
    ``psycopg2`` exceptions.

    The connection runs in autocommit mode, so each statement is committed
    as soon as it executes. If the connection has been closed or marked
    broken, the next call opens a fresh one.
    """

    _instance: SupabaseClient | None = None

    def __init__(self, connection_string: str) -> None:
        """Open an autocommit connection to *connection_string*.

        The DSN is first normalised with ``_fix_dsn`` so that URI-style
        connection strings with special characters in the password parse.
        """
        self._dsn = self._fix_dsn(connection_string)
        self._conn = self._connect(self._dsn)

    @staticmethod
    def _connect(dsn: str):
        """Open a new psycopg2 connection to *dsn* in autocommit mode."""
        conn = psycopg2.connect(dsn)
        # Autocommit: every CRUD statement stands alone, and a failed
        # statement cannot leave the shared connection stuck in an aborted
        # transaction.
        conn.autocommit = True
        return conn

    @staticmethod
    def _fix_dsn(dsn: str) -> str:
        """Percent-encode the password of a URI-style DSN.

        Passwords containing characters such as ``@``, ``#`` or ``/`` break
        URI parsing unless they are percent-encoded. The password is first
        percent-decoded and then re-encoded, so an already-encoded URI
        passes through unchanged instead of being double-encoded
        (``%40`` -> ``%2540``).

        Key=value DSNs, and URIs without a ``user:password@`` part, are
        returned unchanged.

        Note: a *raw* password that happens to contain a ``%XX`` hex
        sequence cannot be told apart from an encoded one and will be
        decoded. Percent-encode such passwords yourself.
        """
        from urllib.parse import quote, unquote

        if "://" not in dsn:
            return dsn  # already key=value format
        scheme, rest = dsn.split("://", 1)
        # rfind: the password may itself contain "@"; the host part is
        # assumed not to.
        at_idx = rest.rfind("@")
        if at_idx == -1:
            return dsn
        userinfo = rest[:at_idx]
        hostinfo = rest[at_idx + 1:]
        colon_idx = userinfo.find(":")
        if colon_idx == -1:
            return dsn
        user = userinfo[:colon_idx]
        password = userinfo[colon_idx + 1:]
        # Decode-then-encode makes this idempotent for valid URIs.
        encoded = quote(unquote(password), safe="")
        return f"{scheme}://{user}:{encoded}@{hostinfo}"

    @classmethod
    def get_instance(cls) -> SupabaseClient:
        """Return the singleton instance, creating it if necessary.

        Raises:
            PortfolioError: if ``supabase_connection_string`` is empty in
                the portfolio config.
        """
        if cls._instance is None:
            cfg = get_portfolio_config()
            dsn = cfg["supabase_connection_string"]
            if not dsn:
                raise PortfolioError(
                    "SUPABASE_CONNECTION_STRING not configured. "
                    "Set it in .env or as an environment variable."
                )
            cls._instance = cls(dsn)
        return cls._instance

    @classmethod
    def reset_instance(cls) -> None:
        """Close and reset the singleton (for testing)."""
        if cls._instance is not None:
            try:
                cls._instance._conn.close()
            except Exception:
                # Best effort: the instance is discarded either way.
                pass
            cls._instance = None

    def _cursor(self):
        """Return a ``RealDictCursor`` (rows come back as dicts by column).

        If the connection is closed or broken (psycopg2 reports both as a
        non-zero ``connection.closed``), reconnect first. Without this, a
        single dropped connection would disable the singleton for the rest
        of the process.
        """
        if self._conn.closed:
            self._conn = self._connect(self._dsn)
        return self._conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)

    # ------------------------------------------------------------------
    # Portfolio CRUD
    # ------------------------------------------------------------------

    def create_portfolio(self, portfolio: Portfolio) -> Portfolio:
        """Insert a new portfolio row and return it as stored.

        A random UUID is generated when ``portfolio.portfolio_id`` is empty.

        Raises:
            DuplicatePortfolioError: if a portfolio with that ID exists.
        """
        pid = portfolio.portfolio_id or str(uuid.uuid4())
        try:
            with self._cursor() as cur:
                cur.execute(
                    """INSERT INTO portfolios
                       (portfolio_id, name, cash, initial_cash, currency, report_path, metadata)
                       VALUES (%s, %s, %s, %s, %s, %s, %s)
                       RETURNING *""",
                    (pid, portfolio.name, portfolio.cash, portfolio.initial_cash,
                     portfolio.currency, portfolio.report_path,
                     json.dumps(portfolio.metadata)),
                )
                row = cur.fetchone()
        except psycopg2.errors.UniqueViolation as exc:
            raise DuplicatePortfolioError(f"Portfolio already exists: {pid}") from exc
        return self._row_to_portfolio(row)

    def get_portfolio(self, portfolio_id: str) -> Portfolio:
        """Fetch a portfolio by ID.

        Raises:
            PortfolioNotFoundError: if no such portfolio exists.
        """
        with self._cursor() as cur:
            cur.execute("SELECT * FROM portfolios WHERE portfolio_id = %s", (portfolio_id,))
            row = cur.fetchone()
        if not row:
            raise PortfolioNotFoundError(f"Portfolio not found: {portfolio_id}")
        return self._row_to_portfolio(row)

    def list_portfolios(self) -> list[Portfolio]:
        """Return all portfolios ordered by created_at DESC."""
        with self._cursor() as cur:
            cur.execute("SELECT * FROM portfolios ORDER BY created_at DESC")
            rows = cur.fetchall()
        return [self._row_to_portfolio(r) for r in rows]

    def update_portfolio(self, portfolio: Portfolio) -> Portfolio:
        """Update mutable portfolio fields (cash, report_path, metadata).

        Raises:
            PortfolioNotFoundError: if ``portfolio.portfolio_id`` does not exist.
        """
        with self._cursor() as cur:
            cur.execute(
                """UPDATE portfolios
                   SET cash = %s, report_path = %s, metadata = %s
                   WHERE portfolio_id = %s
                   RETURNING *""",
                (portfolio.cash, portfolio.report_path,
                 json.dumps(portfolio.metadata), portfolio.portfolio_id),
            )
            row = cur.fetchone()
        if not row:
            raise PortfolioNotFoundError(f"Portfolio not found: {portfolio.portfolio_id}")
        return self._row_to_portfolio(row)

    def delete_portfolio(self, portfolio_id: str) -> None:
        """Delete a portfolio and all associated data (CASCADE).

        Raises:
            PortfolioNotFoundError: if no such portfolio exists.
        """
        with self._cursor() as cur:
            cur.execute(
                "DELETE FROM portfolios WHERE portfolio_id = %s RETURNING portfolio_id",
                (portfolio_id,),
            )
            row = cur.fetchone()
        if not row:
            raise PortfolioNotFoundError(f"Portfolio not found: {portfolio_id}")

    # ------------------------------------------------------------------
    # Holdings CRUD
    # ------------------------------------------------------------------

    def upsert_holding(self, holding: Holding) -> Holding:
        """Insert or update a holding row (upsert on portfolio_id + ticker).

        The ticker is stored upper-cased. On conflict, shares, avg_cost,
        sector and industry are overwritten with the new values.
        """
        hid = holding.holding_id or str(uuid.uuid4())
        with self._cursor() as cur:
            cur.execute(
                """INSERT INTO holdings
                   (holding_id, portfolio_id, ticker, shares, avg_cost, sector, industry)
                   VALUES (%s, %s, %s, %s, %s, %s, %s)
                   ON CONFLICT ON CONSTRAINT holdings_portfolio_ticker_unique
                   DO UPDATE SET shares = EXCLUDED.shares,
                                 avg_cost = EXCLUDED.avg_cost,
                                 sector = EXCLUDED.sector,
                                 industry = EXCLUDED.industry
                   RETURNING *""",
                (hid, holding.portfolio_id, holding.ticker.upper(),
                 holding.shares, holding.avg_cost, holding.sector, holding.industry),
            )
            row = cur.fetchone()
        return self._row_to_holding(row)

    def get_holding(self, portfolio_id: str, ticker: str) -> Holding | None:
        """Return the holding for (portfolio_id, ticker), or None."""
        with self._cursor() as cur:
            cur.execute(
                "SELECT * FROM holdings WHERE portfolio_id = %s AND ticker = %s",
                (portfolio_id, ticker.upper()),
            )
            row = cur.fetchone()
        return self._row_to_holding(row) if row else None

    def list_holdings(self, portfolio_id: str) -> list[Holding]:
        """Return a portfolio's holdings, largest cost basis (shares * avg_cost) first."""
        with self._cursor() as cur:
            cur.execute(
                """SELECT * FROM holdings
                   WHERE portfolio_id = %s
                   ORDER BY shares * avg_cost DESC""",
                (portfolio_id,),
            )
            rows = cur.fetchall()
        return [self._row_to_holding(r) for r in rows]

    def delete_holding(self, portfolio_id: str, ticker: str) -> None:
        """Delete the holding for (portfolio_id, ticker).

        Raises:
            HoldingNotFoundError: if there is no such holding.
        """
        with self._cursor() as cur:
            cur.execute(
                "DELETE FROM holdings WHERE portfolio_id = %s AND ticker = %s RETURNING holding_id",
                (portfolio_id, ticker.upper()),
            )
            row = cur.fetchone()
        if not row:
            raise HoldingNotFoundError(
                f"Holding not found: {ticker} in portfolio {portfolio_id}"
            )

    # ------------------------------------------------------------------
    # Trades
    # ------------------------------------------------------------------

    def record_trade(self, trade: Trade) -> Trade:
        """Insert a new trade record.

        The ticker is stored upper-cased, like holdings, so that
        ``list_trades(..., ticker=...)`` (which filters on the upper-cased
        ticker) finds it regardless of the caller's casing.
        """
        tid = trade.trade_id or str(uuid.uuid4())
        with self._cursor() as cur:
            cur.execute(
                """INSERT INTO trades
                   (trade_id, portfolio_id, ticker, action, shares, price,
                    total_value, rationale, signal_source, metadata)
                   VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                   RETURNING *""",
                (tid, trade.portfolio_id, trade.ticker.upper(), trade.action,
                 trade.shares, trade.price, trade.total_value,
                 trade.rationale, trade.signal_source,
                 json.dumps(trade.metadata)),
            )
            row = cur.fetchone()
        return self._row_to_trade(row)

    def list_trades(
        self,
        portfolio_id: str,
        ticker: str | None = None,
        limit: int = 100,
    ) -> list[Trade]:
        """Return up to *limit* trades for a portfolio, newest first.

        When *ticker* is given (case-insensitive), only that ticker's
        trades are returned.
        """
        if ticker:
            query = """SELECT * FROM trades
                       WHERE portfolio_id = %s AND ticker = %s
                       ORDER BY trade_date DESC LIMIT %s"""
            params = (portfolio_id, ticker.upper(), limit)
        else:
            query = """SELECT * FROM trades
                       WHERE portfolio_id = %s
                       ORDER BY trade_date DESC LIMIT %s"""
            params = (portfolio_id, limit)
        with self._cursor() as cur:
            cur.execute(query, params)
            rows = cur.fetchall()
        return [self._row_to_trade(r) for r in rows]

    # ------------------------------------------------------------------
    # Snapshots
    # ------------------------------------------------------------------

    def save_snapshot(self, snapshot: PortfolioSnapshot) -> PortfolioSnapshot:
        """Insert a new immutable portfolio snapshot."""
        sid = snapshot.snapshot_id or str(uuid.uuid4())
        with self._cursor() as cur:
            cur.execute(
                """INSERT INTO snapshots
                   (snapshot_id, portfolio_id, total_value, cash, equity_value,
                    num_positions, holdings_snapshot, metadata)
                   VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                   RETURNING *""",
                (sid, snapshot.portfolio_id, snapshot.total_value,
                 snapshot.cash, snapshot.equity_value, snapshot.num_positions,
                 json.dumps(snapshot.holdings_snapshot),
                 json.dumps(snapshot.metadata)),
            )
            row = cur.fetchone()
        return self._row_to_snapshot(row)

    def get_latest_snapshot(self, portfolio_id: str) -> PortfolioSnapshot | None:
        """Return the most recent snapshot for a portfolio, or None."""
        with self._cursor() as cur:
            cur.execute(
                """SELECT * FROM snapshots
                   WHERE portfolio_id = %s
                   ORDER BY snapshot_date DESC LIMIT 1""",
                (portfolio_id,),
            )
            row = cur.fetchone()
        return self._row_to_snapshot(row) if row else None

    def list_snapshots(
        self,
        portfolio_id: str,
        limit: int = 30,
    ) -> list[PortfolioSnapshot]:
        """Return snapshots newest-first up to limit."""
        with self._cursor() as cur:
            cur.execute(
                """SELECT * FROM snapshots
                   WHERE portfolio_id = %s
                   ORDER BY snapshot_date DESC LIMIT %s""",
                (portfolio_id, limit),
            )
            rows = cur.fetchall()
        return [self._row_to_snapshot(r) for r in rows]

    # ------------------------------------------------------------------
    # Row -> Model helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _decode_json(value: object, empty: object) -> object:
        """Return a JSON column value as a Python object.

        If the column is ``json``/``jsonb``, psycopg2 already hands back a
        decoded object, which is returned as-is. A string is parsed with
        ``json.loads``. Falsy values (``None``, ``""``, empty containers)
        are replaced by *empty*.
        """
        if not value:
            return empty
        if isinstance(value, str):
            return json.loads(value)
        return value

    @staticmethod
    def _row_to_portfolio(row: dict) -> Portfolio:
        """Convert a ``portfolios`` row dict into a ``Portfolio``."""
        return Portfolio(
            portfolio_id=str(row["portfolio_id"]),
            name=row["name"],
            cash=float(row["cash"]),
            initial_cash=float(row["initial_cash"]),
            # strip(): removes any padding around the stored currency code.
            currency=row["currency"].strip(),
            created_at=str(row["created_at"]),
            updated_at=str(row["updated_at"]),
            report_path=row.get("report_path"),
            metadata=SupabaseClient._decode_json(row.get("metadata"), {}),
        )

    @staticmethod
    def _row_to_holding(row: dict) -> Holding:
        """Convert a ``holdings`` row dict into a ``Holding``."""
        return Holding(
            holding_id=str(row["holding_id"]),
            portfolio_id=str(row["portfolio_id"]),
            ticker=row["ticker"],
            shares=float(row["shares"]),
            avg_cost=float(row["avg_cost"]),
            sector=row.get("sector"),
            industry=row.get("industry"),
            created_at=str(row["created_at"]),
            updated_at=str(row["updated_at"]),
        )

    @staticmethod
    def _row_to_trade(row: dict) -> Trade:
        """Convert a ``trades`` row dict into a ``Trade``."""
        return Trade(
            trade_id=str(row["trade_id"]),
            portfolio_id=str(row["portfolio_id"]),
            ticker=row["ticker"],
            action=row["action"],
            shares=float(row["shares"]),
            price=float(row["price"]),
            total_value=float(row["total_value"]),
            trade_date=str(row["trade_date"]),
            rationale=row.get("rationale"),
            signal_source=row.get("signal_source"),
            metadata=SupabaseClient._decode_json(row.get("metadata"), {}),
        )

    @staticmethod
    def _row_to_snapshot(row: dict) -> PortfolioSnapshot:
        """Convert a ``snapshots`` row dict into a ``PortfolioSnapshot``."""
        return PortfolioSnapshot(
            snapshot_id=str(row["snapshot_id"]),
            portfolio_id=str(row["portfolio_id"]),
            snapshot_date=str(row["snapshot_date"]),
            total_value=float(row["total_value"]),
            cash=float(row["cash"]),
            equity_value=float(row["equity_value"]),
            num_positions=int(row["num_positions"]),
            holdings_snapshot=SupabaseClient._decode_json(row.get("holdings_snapshot"), []),
            metadata=SupabaseClient._decode_json(row.get("metadata"), {}),
        )