"""Unified data-access facade for the Portfolio Manager. ``PortfolioRepository`` combines ``SupabaseClient`` (transactional data) and ``ReportStore`` (filesystem documents) into a single, business-logic-aware interface. Usage:: from tradingagents.portfolio import PortfolioRepository repo = PortfolioRepository() portfolio = repo.create_portfolio("Main Portfolio", initial_cash=100_000.0) holding = repo.add_holding(portfolio.portfolio_id, "AAPL", shares=50, price=195.50) See ``docs/portfolio/04_repository_api.md`` for full API documentation. """ from __future__ import annotations import uuid from datetime import datetime, timezone from pathlib import Path from typing import Any from tradingagents.portfolio.config import get_portfolio_config from tradingagents.portfolio.exceptions import ( HoldingNotFoundError, InsufficientCashError, InsufficientSharesError, ) from tradingagents.portfolio.models import ( Holding, Portfolio, PortfolioSnapshot, Trade, ) from tradingagents.portfolio.report_store import ReportStore from tradingagents.portfolio.supabase_client import SupabaseClient class PortfolioRepository: """Unified facade over SupabaseClient and ReportStore. Implements business logic for: - Average cost basis updates on repeated buys - Cash deduction / credit on trades - Constraint enforcement (cash, position size) - Snapshot management """ def __init__( self, client: SupabaseClient | None = None, store: ReportStore | None = None, config: dict[str, Any] | None = None, ) -> None: self._cfg = config or get_portfolio_config() self._client = client or SupabaseClient.get_instance() self._store = store or ReportStore(base_dir=self._cfg["data_dir"]) # ------------------------------------------------------------------ # Portfolio lifecycle # ------------------------------------------------------------------ def create_portfolio( self, name: str, initial_cash: float, currency: str = "USD", ) -> Portfolio: """Create a new portfolio with the given starting capital.""" if initial_cash <= 0: raise ValueError(f"initial_cash must be > 0, got {initial_cash}") portfolio = Portfolio( portfolio_id=str(uuid.uuid4()), name=name, cash=initial_cash, initial_cash=initial_cash, currency=currency, ) return self._client.create_portfolio(portfolio) def get_portfolio(self, portfolio_id: str) -> Portfolio: """Fetch a portfolio by ID.""" return self._client.get_portfolio(portfolio_id) def get_portfolio_with_holdings( self, portfolio_id: str, prices: dict[str, float] | None = None, ) -> tuple[Portfolio, list[Holding]]: """Fetch portfolio + all holdings, optionally enriched with current prices.""" portfolio = self._client.get_portfolio(portfolio_id) holdings = self._client.list_holdings(portfolio_id) if prices: # First pass: compute equity for total_value equity = sum( prices.get(h.ticker, 0.0) * h.shares for h in holdings ) total_value = portfolio.cash + equity # Second pass: enrich each holding with weight for h in holdings: if h.ticker in prices: h.enrich(prices[h.ticker], total_value) portfolio.enrich(holdings) return portfolio, holdings # ------------------------------------------------------------------ # Holdings management # ------------------------------------------------------------------ def add_holding( self, portfolio_id: str, ticker: str, shares: float, price: float, sector: str | None = None, industry: str | None = None, ) -> Holding: """Buy shares and update portfolio cash and holdings.""" if shares <= 0: raise ValueError(f"shares must be > 0, got {shares}") if price <= 0: raise ValueError(f"price must be > 0, got {price}") cost = shares * price portfolio = self._client.get_portfolio(portfolio_id) if portfolio.cash < cost: raise InsufficientCashError( f"Need ${cost:.2f} but only ${portfolio.cash:.2f} available" ) # Check for existing holding to update avg cost existing = self._client.get_holding(portfolio_id, ticker) if existing: new_total_shares = existing.shares + shares new_avg_cost = ( (existing.shares * existing.avg_cost + shares * price) / new_total_shares ) existing.shares = new_total_shares existing.avg_cost = new_avg_cost if sector: existing.sector = sector if industry: existing.industry = industry holding = self._client.upsert_holding(existing) else: holding = Holding( holding_id=str(uuid.uuid4()), portfolio_id=portfolio_id, ticker=ticker.upper(), shares=shares, avg_cost=price, sector=sector, industry=industry, ) holding = self._client.upsert_holding(holding) # Deduct cash portfolio.cash -= cost self._client.update_portfolio(portfolio) # Record trade trade = Trade( trade_id=str(uuid.uuid4()), portfolio_id=portfolio_id, ticker=ticker.upper(), action="BUY", shares=shares, price=price, total_value=cost, trade_date=datetime.now(timezone.utc).isoformat(), signal_source="pm_agent", ) self._client.record_trade(trade) return holding def remove_holding( self, portfolio_id: str, ticker: str, shares: float, price: float, ) -> Holding | None: """Sell shares and update portfolio cash and holdings.""" if shares <= 0: raise ValueError(f"shares must be > 0, got {shares}") if price <= 0: raise ValueError(f"price must be > 0, got {price}") existing = self._client.get_holding(portfolio_id, ticker) if not existing: raise HoldingNotFoundError( f"No holding for {ticker} in portfolio {portfolio_id}" ) if existing.shares < shares: raise InsufficientSharesError( f"Hold {existing.shares} shares of {ticker}, cannot sell {shares}" ) proceeds = shares * price portfolio = self._client.get_portfolio(portfolio_id) if existing.shares == shares: # Full sell — delete holding self._client.delete_holding(portfolio_id, ticker) result = None else: existing.shares -= shares result = self._client.upsert_holding(existing) # Credit cash portfolio.cash += proceeds self._client.update_portfolio(portfolio) # Record trade trade = Trade( trade_id=str(uuid.uuid4()), portfolio_id=portfolio_id, ticker=ticker.upper(), action="SELL", shares=shares, price=price, total_value=proceeds, trade_date=datetime.now(timezone.utc).isoformat(), signal_source="pm_agent", ) self._client.record_trade(trade) return result # ------------------------------------------------------------------ # Snapshots # ------------------------------------------------------------------ def take_snapshot( self, portfolio_id: str, prices: dict[str, float], ) -> PortfolioSnapshot: """Take an immutable snapshot of the current portfolio state.""" portfolio, holdings = self.get_portfolio_with_holdings(portfolio_id, prices) snapshot = PortfolioSnapshot( snapshot_id=str(uuid.uuid4()), portfolio_id=portfolio_id, snapshot_date=datetime.now(timezone.utc).isoformat(), total_value=portfolio.total_value or portfolio.cash, cash=portfolio.cash, equity_value=portfolio.equity_value or 0.0, num_positions=len(holdings), holdings_snapshot=[h.to_dict() for h in holdings], ) return self._client.save_snapshot(snapshot) # ------------------------------------------------------------------ # Report convenience methods # ------------------------------------------------------------------ def save_pm_decision( self, portfolio_id: str, date: str, decision: dict[str, Any], markdown: str | None = None, ) -> Path: """Save a PM agent decision and update portfolio.report_path.""" path = self._store.save_pm_decision(date, portfolio_id, decision, markdown) # Update portfolio report_path portfolio = self._client.get_portfolio(portfolio_id) portfolio.report_path = str(self._store._portfolio_dir(date)) self._client.update_portfolio(portfolio) return path def load_pm_decision( self, portfolio_id: str, date: str, ) -> dict[str, Any] | None: """Load a PM decision JSON. Returns None if not found.""" return self._store.load_pm_decision(date, portfolio_id) def save_risk_metrics( self, portfolio_id: str, date: str, metrics: dict[str, Any], ) -> Path: """Save risk computation results.""" return self._store.save_risk_metrics(date, portfolio_id, metrics) def load_risk_metrics( self, portfolio_id: str, date: str, ) -> dict[str, Any] | None: """Load risk metrics. Returns None if not found.""" return self._store.load_risk_metrics(date, portfolio_id)