"""PostgreSQL database client for the Portfolio Manager.

Uses ``psycopg2`` with the ``SUPABASE_CONNECTION_STRING`` env var to talk
directly to the Supabase-hosted PostgreSQL database. No ORM — see
``docs/agent/decisions/012-portfolio-no-orm.md`` for rationale.

Usage::

    from tradingagents.portfolio.supabase_client import SupabaseClient

    client = SupabaseClient.get_instance()
    portfolio = client.get_portfolio("some-uuid")
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
|
|
import psycopg2
|
|
import psycopg2.extras
|
|
|
|
from tradingagents.portfolio.config import get_portfolio_config
|
|
from tradingagents.portfolio.exceptions import (
|
|
DuplicatePortfolioError,
|
|
HoldingNotFoundError,
|
|
PortfolioError,
|
|
PortfolioNotFoundError,
|
|
)
|
|
from tradingagents.portfolio.models import (
|
|
Holding,
|
|
Portfolio,
|
|
PortfolioSnapshot,
|
|
Trade,
|
|
)
|
|
|
|
|
|
class SupabaseClient:
    """Singleton PostgreSQL CRUD client for portfolio data.

    Public methods return typed model instances. Lookups/updates/deletes of a
    missing row raise the matching domain exception (``PortfolioNotFoundError``
    or ``HoldingNotFoundError``), and inserting a duplicate portfolio raises
    ``DuplicatePortfolioError``. Any other driver error (``psycopg2.Error``)
    propagates unchanged.

    The client holds a single autocommit connection and reopens it when it
    has been closed or broken, so a long-lived singleton recovers on the next
    call instead of failing permanently.
    """

    _instance: SupabaseClient | None = None

    # PostgreSQL SQLSTATE code for ``unique_violation``.
    _UNIQUE_VIOLATION = "23505"

    def __init__(self, connection_string: str) -> None:
        """Normalise *connection_string* and open the database connection.

        Raises:
            psycopg2.Error: if the connection cannot be established.
        """
        self._dsn = self._fix_dsn(connection_string)
        self._conn = self._connect()

    def _connect(self):
        """Open a new autocommit connection to ``self._dsn``."""
        conn = psycopg2.connect(self._dsn)
        # Every statement commits on its own; no explicit transactions are used.
        conn.autocommit = True
        return conn

    @staticmethod
    def _fix_dsn(dsn: str) -> str:
        """Percent-encode the password component of a URL-style DSN.

        Raw passwords containing characters such as ``@``, ``:`` or ``/``
        would otherwise break URL parsing. The password is decoded before it
        is re-encoded, so an already-encoded password (``p%40ss``) is kept
        as-is instead of being double-encoded (``p%2540ss``). Key=value DSNs
        and URLs without a ``user:password@`` section are returned unchanged.

        Note: a raw password that happens to contain a valid ``%XX`` escape is
        indistinguishable from an encoded one and will be treated as encoded.
        """
        from urllib.parse import quote, unquote

        if "://" not in dsn:
            return dsn  # already key=value format
        scheme, rest = dsn.split("://", 1)
        # The password may itself contain '@', so split on the LAST one.
        userinfo, at_sep, hostinfo = rest.rpartition("@")
        if not at_sep:
            return dsn
        user, colon_sep, password = userinfo.partition(":")
        if not colon_sep:
            return dsn
        encoded = quote(unquote(password), safe="")
        return f"{scheme}://{user}:{encoded}@{hostinfo}"

    @classmethod
    def get_instance(cls) -> SupabaseClient:
        """Return the singleton instance, creating it if necessary.

        Raises:
            PortfolioError: if ``supabase_connection_string`` is not configured.
        """
        if cls._instance is None:
            cfg = get_portfolio_config()
            dsn = cfg["supabase_connection_string"]
            if not dsn:
                raise PortfolioError(
                    "SUPABASE_CONNECTION_STRING not configured. "
                    "Set it in .env or as an environment variable."
                )
            cls._instance = cls(dsn)
        return cls._instance

    @classmethod
    def reset_instance(cls) -> None:
        """Close and reset the singleton (for testing)."""
        if cls._instance is not None:
            try:
                cls._instance._conn.close()
            except Exception:
                # Best-effort close: the singleton is discarded either way.
                pass
            cls._instance = None

    def _cursor(self):
        """Return a RealDictCursor, reopening the connection if it was closed."""
        # ``closed`` is non-zero once the connection is closed or broken;
        # without this check the singleton would fail on every later call.
        if self._conn.closed:
            self._conn = self._connect()
        return self._conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)

    # ------------------------------------------------------------------
    # Portfolio CRUD
    # ------------------------------------------------------------------

    def create_portfolio(self, portfolio: Portfolio) -> Portfolio:
        """Insert a new portfolio row and return it as stored.

        A fresh UUID is used when ``portfolio.portfolio_id`` is empty.

        Raises:
            DuplicatePortfolioError: if the insert violates a unique constraint.
        """
        pid = portfolio.portfolio_id or str(uuid.uuid4())
        try:
            with self._cursor() as cur:
                cur.execute(
                    """INSERT INTO portfolios
                       (portfolio_id, name, cash, initial_cash, currency, report_path, metadata)
                       VALUES (%s, %s, %s, %s, %s, %s, %s)
                       RETURNING *""",
                    (pid, portfolio.name, portfolio.cash, portfolio.initial_cash,
                     portfolio.currency, portfolio.report_path,
                     json.dumps(portfolio.metadata)),
                )
                row = cur.fetchone()
        except psycopg2.IntegrityError as exc:
            # Match on SQLSTATE instead of ``psycopg2.errors.UniqueViolation``:
            # the ``errors`` submodule is not imported by this file, so that
            # attribute lookup may itself fail while handling the error.
            if exc.pgcode == self._UNIQUE_VIOLATION:
                raise DuplicatePortfolioError(f"Portfolio already exists: {pid}") from exc
            raise
        return self._row_to_portfolio(row)

    def get_portfolio(self, portfolio_id: str) -> Portfolio:
        """Fetch a portfolio by ID.

        Raises:
            PortfolioNotFoundError: if no row matches *portfolio_id*.
        """
        with self._cursor() as cur:
            cur.execute("SELECT * FROM portfolios WHERE portfolio_id = %s", (portfolio_id,))
            row = cur.fetchone()
        if not row:
            raise PortfolioNotFoundError(f"Portfolio not found: {portfolio_id}")
        return self._row_to_portfolio(row)

    def list_portfolios(self) -> list[Portfolio]:
        """Return all portfolios ordered by created_at DESC."""
        with self._cursor() as cur:
            cur.execute("SELECT * FROM portfolios ORDER BY created_at DESC")
            rows = cur.fetchall()
        return [self._row_to_portfolio(r) for r in rows]

    def update_portfolio(self, portfolio: Portfolio) -> Portfolio:
        """Update mutable portfolio fields (cash, report_path, metadata).

        Raises:
            PortfolioNotFoundError: if ``portfolio.portfolio_id`` does not exist.
        """
        with self._cursor() as cur:
            cur.execute(
                """UPDATE portfolios
                   SET cash = %s, report_path = %s, metadata = %s
                   WHERE portfolio_id = %s
                   RETURNING *""",
                (portfolio.cash, portfolio.report_path,
                 json.dumps(portfolio.metadata), portfolio.portfolio_id),
            )
            row = cur.fetchone()
        if not row:
            raise PortfolioNotFoundError(f"Portfolio not found: {portfolio.portfolio_id}")
        return self._row_to_portfolio(row)

    def delete_portfolio(self, portfolio_id: str) -> None:
        """Delete a portfolio; associated rows are expected to go via ON DELETE CASCADE.

        Raises:
            PortfolioNotFoundError: if no row matches *portfolio_id*.
        """
        with self._cursor() as cur:
            cur.execute(
                "DELETE FROM portfolios WHERE portfolio_id = %s RETURNING portfolio_id",
                (portfolio_id,),
            )
            row = cur.fetchone()
        if not row:
            raise PortfolioNotFoundError(f"Portfolio not found: {portfolio_id}")

    # ------------------------------------------------------------------
    # Holdings CRUD
    # ------------------------------------------------------------------

    def upsert_holding(self, holding: Holding) -> Holding:
        """Insert or update a holding row (upsert on portfolio_id + ticker).

        The ticker is stored upper-cased; a fresh UUID is used when
        ``holding.holding_id`` is empty.
        """
        hid = holding.holding_id or str(uuid.uuid4())
        with self._cursor() as cur:
            cur.execute(
                """INSERT INTO holdings
                   (holding_id, portfolio_id, ticker, shares, avg_cost, sector, industry)
                   VALUES (%s, %s, %s, %s, %s, %s, %s)
                   ON CONFLICT ON CONSTRAINT holdings_portfolio_ticker_unique
                   DO UPDATE SET shares = EXCLUDED.shares,
                                 avg_cost = EXCLUDED.avg_cost,
                                 sector = EXCLUDED.sector,
                                 industry = EXCLUDED.industry
                   RETURNING *""",
                (hid, holding.portfolio_id, holding.ticker.upper(),
                 holding.shares, holding.avg_cost, holding.sector, holding.industry),
            )
            row = cur.fetchone()
        return self._row_to_holding(row)

    def get_holding(self, portfolio_id: str, ticker: str) -> Holding | None:
        """Return the holding for (portfolio_id, ticker), or None."""
        with self._cursor() as cur:
            cur.execute(
                "SELECT * FROM holdings WHERE portfolio_id = %s AND ticker = %s",
                (portfolio_id, ticker.upper()),
            )
            row = cur.fetchone()
        return self._row_to_holding(row) if row else None

    def list_holdings(self, portfolio_id: str) -> list[Holding]:
        """Return all holdings for a portfolio ordered by cost basis (shares * avg_cost) DESC."""
        with self._cursor() as cur:
            cur.execute(
                """SELECT * FROM holdings
                   WHERE portfolio_id = %s
                   ORDER BY shares * avg_cost DESC""",
                (portfolio_id,),
            )
            rows = cur.fetchall()
        return [self._row_to_holding(r) for r in rows]

    def delete_holding(self, portfolio_id: str, ticker: str) -> None:
        """Delete the holding for (portfolio_id, ticker).

        Raises:
            HoldingNotFoundError: if no such holding exists.
        """
        with self._cursor() as cur:
            cur.execute(
                "DELETE FROM holdings WHERE portfolio_id = %s AND ticker = %s RETURNING holding_id",
                (portfolio_id, ticker.upper()),
            )
            row = cur.fetchone()
        if not row:
            raise HoldingNotFoundError(
                f"Holding not found: {ticker} in portfolio {portfolio_id}"
            )

    # ------------------------------------------------------------------
    # Trades
    # ------------------------------------------------------------------

    def record_trade(self, trade: Trade) -> Trade:
        """Insert a new trade record and return it as stored.

        The ticker is upper-cased, matching holdings and the ``ticker``
        filter in ``list_trades``; a fresh UUID is used when
        ``trade.trade_id`` is empty.
        """
        tid = trade.trade_id or str(uuid.uuid4())
        with self._cursor() as cur:
            cur.execute(
                """INSERT INTO trades
                   (trade_id, portfolio_id, ticker, action, shares, price,
                    total_value, rationale, signal_source, metadata)
                   VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                   RETURNING *""",
                (tid, trade.portfolio_id, trade.ticker.upper(), trade.action,
                 trade.shares, trade.price, trade.total_value,
                 trade.rationale, trade.signal_source,
                 json.dumps(trade.metadata)),
            )
            row = cur.fetchone()
        return self._row_to_trade(row)

    def list_trades(
        self,
        portfolio_id: str,
        ticker: str | None = None,
        limit: int = 100,
    ) -> list[Trade]:
        """Return up to *limit* trades for a portfolio, newest first.

        When *ticker* is given (case-insensitive), only that ticker's trades
        are returned.
        """
        if ticker:
            query = """SELECT * FROM trades
                       WHERE portfolio_id = %s AND ticker = %s
                       ORDER BY trade_date DESC LIMIT %s"""
            params = (portfolio_id, ticker.upper(), limit)
        else:
            query = """SELECT * FROM trades
                       WHERE portfolio_id = %s
                       ORDER BY trade_date DESC LIMIT %s"""
            params = (portfolio_id, limit)
        with self._cursor() as cur:
            cur.execute(query, params)
            rows = cur.fetchall()
        return [self._row_to_trade(r) for r in rows]

    # ------------------------------------------------------------------
    # Snapshots
    # ------------------------------------------------------------------

    def save_snapshot(self, snapshot: PortfolioSnapshot) -> PortfolioSnapshot:
        """Insert a new portfolio snapshot row and return it as stored."""
        sid = snapshot.snapshot_id or str(uuid.uuid4())
        with self._cursor() as cur:
            cur.execute(
                """INSERT INTO snapshots
                   (snapshot_id, portfolio_id, total_value, cash, equity_value,
                    num_positions, holdings_snapshot, metadata)
                   VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                   RETURNING *""",
                (sid, snapshot.portfolio_id, snapshot.total_value,
                 snapshot.cash, snapshot.equity_value, snapshot.num_positions,
                 json.dumps(snapshot.holdings_snapshot),
                 json.dumps(snapshot.metadata)),
            )
            row = cur.fetchone()
        return self._row_to_snapshot(row)

    def get_latest_snapshot(self, portfolio_id: str) -> PortfolioSnapshot | None:
        """Return the most recent snapshot for a portfolio, or None."""
        with self._cursor() as cur:
            cur.execute(
                """SELECT * FROM snapshots
                   WHERE portfolio_id = %s
                   ORDER BY snapshot_date DESC LIMIT 1""",
                (portfolio_id,),
            )
            row = cur.fetchone()
        return self._row_to_snapshot(row) if row else None

    def list_snapshots(
        self,
        portfolio_id: str,
        limit: int = 30,
    ) -> list[PortfolioSnapshot]:
        """Return snapshots newest-first up to limit."""
        with self._cursor() as cur:
            cur.execute(
                """SELECT * FROM snapshots
                   WHERE portfolio_id = %s
                   ORDER BY snapshot_date DESC LIMIT %s""",
                (portfolio_id, limit),
            )
            rows = cur.fetchall()
        return [self._row_to_snapshot(r) for r in rows]

    # ------------------------------------------------------------------
    # Row -> Model helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _json_field(row: dict, key: str, default):
        """Return ``row[key]`` decoded from JSON, or *default* when missing/falsy.

        Values that are already decoded (non-strings) are returned as-is.
        """
        value = row.get(key) or default
        if isinstance(value, str):
            value = json.loads(value)
        return value

    @staticmethod
    def _row_to_portfolio(row: dict) -> Portfolio:
        """Build a ``Portfolio`` from a ``portfolios`` row dict."""
        return Portfolio(
            portfolio_id=str(row["portfolio_id"]),
            name=row["name"],
            cash=float(row["cash"]),
            initial_cash=float(row["initial_cash"]),
            # strip() removes padding (e.g. from a fixed-width column type).
            currency=row["currency"].strip(),
            created_at=str(row["created_at"]),
            updated_at=str(row["updated_at"]),
            report_path=row.get("report_path"),
            metadata=SupabaseClient._json_field(row, "metadata", {}),
        )

    @staticmethod
    def _row_to_holding(row: dict) -> Holding:
        """Build a ``Holding`` from a ``holdings`` row dict."""
        return Holding(
            holding_id=str(row["holding_id"]),
            portfolio_id=str(row["portfolio_id"]),
            ticker=row["ticker"],
            shares=float(row["shares"]),
            avg_cost=float(row["avg_cost"]),
            sector=row.get("sector"),
            industry=row.get("industry"),
            created_at=str(row["created_at"]),
            updated_at=str(row["updated_at"]),
        )

    @staticmethod
    def _row_to_trade(row: dict) -> Trade:
        """Build a ``Trade`` from a ``trades`` row dict."""
        return Trade(
            trade_id=str(row["trade_id"]),
            portfolio_id=str(row["portfolio_id"]),
            ticker=row["ticker"],
            action=row["action"],
            shares=float(row["shares"]),
            price=float(row["price"]),
            total_value=float(row["total_value"]),
            trade_date=str(row["trade_date"]),
            rationale=row.get("rationale"),
            signal_source=row.get("signal_source"),
            metadata=SupabaseClient._json_field(row, "metadata", {}),
        )

    @staticmethod
    def _row_to_snapshot(row: dict) -> PortfolioSnapshot:
        """Build a ``PortfolioSnapshot`` from a ``snapshots`` row dict."""
        return PortfolioSnapshot(
            snapshot_id=str(row["snapshot_id"]),
            portfolio_id=str(row["portfolio_id"]),
            snapshot_date=str(row["snapshot_date"]),
            total_value=float(row["total_value"]),
            cash=float(row["cash"]),
            equity_value=float(row["equity_value"]),
            num_positions=int(row["num_positions"]),
            holdings_snapshot=SupabaseClient._json_field(row, "holdings_snapshot", []),
            metadata=SupabaseClient._json_field(row, "metadata", {}),
        )
|