# TradingAgents/tradingagents/portfolio/repository.py
#
# 302 lines
# 10 KiB
# Python

"""Unified data-access facade for the Portfolio Manager.

``PortfolioRepository`` combines ``SupabaseClient`` (transactional data) and
``ReportStore`` (filesystem documents) into a single, business-logic-aware
interface.

Usage::

    from tradingagents.portfolio import PortfolioRepository

    repo = PortfolioRepository()
    portfolio = repo.create_portfolio("Main Portfolio", initial_cash=100_000.0)
    holding = repo.add_holding(portfolio.portfolio_id, "AAPL", shares=50, price=195.50)

See ``docs/portfolio/04_repository_api.md`` for full API documentation.
"""
from __future__ import annotations

import math
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from tradingagents.portfolio.config import get_portfolio_config
from tradingagents.portfolio.exceptions import (
    HoldingNotFoundError,
    InsufficientCashError,
    InsufficientSharesError,
)
from tradingagents.portfolio.models import (
    Holding,
    Portfolio,
    PortfolioSnapshot,
    Trade,
)
from tradingagents.portfolio.report_store import ReportStore
from tradingagents.portfolio.supabase_client import SupabaseClient
class PortfolioRepository:
    """Unified facade over SupabaseClient and ReportStore.

    Implements business logic for:

    - Average cost basis updates on repeated buys
    - Cash deduction / credit on trades
    - Constraint enforcement (available cash on buys, shares held on sells)
    - Snapshot management

    Tickers passed to :meth:`add_holding` / :meth:`remove_holding` are
    upper-cased before any lookup, insert or delete, so ``"aapl"`` and
    ``"AAPL"`` always address the same holding.

    NOTE(review): every trade is persisted through several independent
    client calls (holding upsert/delete, portfolio update, trade insert).
    Nothing here wraps them in a transaction, so a failure part-way through
    can leave holdings, cash and the trade log out of sync.
    """

    # Float tolerances for comparing share counts and cash amounts. Repeated
    # fractional trades accumulate binary rounding error (0.1 + 0.2 != 0.3),
    # so exact comparisons would leave "dust" positions behind on a full
    # sell, or reject a buy that spends exactly the available cash.
    _REL_TOL = 1e-12
    _ABS_TOL = 1e-12

    def __init__(
        self,
        client: SupabaseClient | None = None,
        store: ReportStore | None = None,
        config: dict[str, Any] | None = None,
    ) -> None:
        """Create the repository.

        Args:
            client: Transactional data client. Defaults to
                ``SupabaseClient.get_instance()``.
            store: Filesystem report store. Defaults to a ``ReportStore``
                rooted at ``config["data_dir"]``.
            config: Portfolio configuration. A falsy value (``None`` or an
                empty dict) falls back to ``get_portfolio_config()``.
        """
        self._cfg = config or get_portfolio_config()
        self._client = client or SupabaseClient.get_instance()
        self._store = store or ReportStore(base_dir=self._cfg["data_dir"])

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _require_positive(name: str, value: float) -> None:
        """Raise ``ValueError`` unless *value* is a finite number > 0."""
        # ``value <= 0`` alone lets NaN through (every NaN comparison is
        # False), which would silently poison cash and cost-basis math.
        if not (math.isfinite(value) and value > 0):
            raise ValueError(f"{name} must be > 0, got {value}")

    @classmethod
    def _nearly_equal(cls, a: float, b: float) -> bool:
        """Return True if *a* and *b* differ only by float rounding noise."""
        return math.isclose(a, b, rel_tol=cls._REL_TOL, abs_tol=cls._ABS_TOL)

    def _record_trade(
        self,
        portfolio_id: str,
        ticker: str,
        action: str,
        shares: float,
        price: float,
        total_value: float,
    ) -> None:
        """Append a trade to the trade log, timestamped now in UTC."""
        trade = Trade(
            trade_id=str(uuid.uuid4()),
            portfolio_id=portfolio_id,
            ticker=ticker,
            action=action,
            shares=shares,
            price=price,
            total_value=total_value,
            trade_date=datetime.now(timezone.utc).isoformat(),
            signal_source="pm_agent",
        )
        self._client.record_trade(trade)

    # ------------------------------------------------------------------
    # Portfolio lifecycle
    # ------------------------------------------------------------------

    def create_portfolio(
        self,
        name: str,
        initial_cash: float,
        currency: str = "USD",
    ) -> Portfolio:
        """Create a new portfolio with the given starting capital.

        Args:
            name: Display name of the portfolio.
            initial_cash: Starting cash; also stored as ``initial_cash``.
            currency: Currency code stored on the portfolio.

        Returns:
            The portfolio as returned by ``SupabaseClient.create_portfolio``.

        Raises:
            ValueError: If ``initial_cash`` is not a finite number > 0.
        """
        self._require_positive("initial_cash", initial_cash)
        portfolio = Portfolio(
            portfolio_id=str(uuid.uuid4()),
            name=name,
            cash=initial_cash,
            initial_cash=initial_cash,
            currency=currency,
        )
        return self._client.create_portfolio(portfolio)

    def get_portfolio(self, portfolio_id: str) -> Portfolio:
        """Fetch a portfolio by ID (delegates to ``SupabaseClient``)."""
        return self._client.get_portfolio(portfolio_id)

    def get_portfolio_with_holdings(
        self,
        portfolio_id: str,
        prices: dict[str, float] | None = None,
    ) -> tuple[Portfolio, list[Holding]]:
        """Fetch a portfolio and all of its holdings.

        When ``prices`` is a non-empty ticker -> price mapping, every holding
        that has a price is enriched via ``Holding.enrich(price, total_value)``,
        where ``total_value`` is cash plus the market value of the priced
        holdings (unpriced holdings count as 0). ``Portfolio.enrich(holdings)``
        is then called. Price keys are matched exactly against
        ``Holding.ticker``, which :meth:`add_holding` stores upper-cased.
        With ``prices`` omitted or empty, both are returned as stored.

        Returns:
            ``(portfolio, holdings)``.
        """
        portfolio = self._client.get_portfolio(portfolio_id)
        holdings = self._client.list_holdings(portfolio_id)
        if prices:
            # First pass: total value is needed before any weight is computed.
            equity = sum(prices.get(h.ticker, 0.0) * h.shares for h in holdings)
            total_value = portfolio.cash + equity
            # Second pass: enrich each priced holding against that total.
            for h in holdings:
                if h.ticker in prices:
                    h.enrich(prices[h.ticker], total_value)
            portfolio.enrich(holdings)
        return portfolio, holdings

    # ------------------------------------------------------------------
    # Holdings management
    # ------------------------------------------------------------------

    def add_holding(
        self,
        portfolio_id: str,
        ticker: str,
        shares: float,
        price: float,
        sector: str | None = None,
        industry: str | None = None,
    ) -> Holding:
        """Buy shares and update portfolio cash and holdings.

        Buying into an existing position recomputes its average cost as the
        share-weighted mean of the old cost basis and ``price``. Non-empty
        ``sector`` / ``industry`` values overwrite the stored ones. The cost
        (``shares * price``) is deducted from cash and a ``BUY`` trade is
        recorded.

        Args:
            portfolio_id: Portfolio to buy into.
            ticker: Symbol to buy; upper-cased before use.
            shares: Number of shares; must be a finite number > 0.
            price: Price per share; must be a finite number > 0.
            sector: Optional sector label.
            industry: Optional industry label.

        Returns:
            The holding as returned by ``SupabaseClient.upsert_holding``.

        Raises:
            ValueError: If ``shares`` or ``price`` is not a finite number > 0.
            InsufficientCashError: If the cost exceeds available cash by more
                than float rounding noise.
        """
        self._require_positive("shares", shares)
        self._require_positive("price", price)
        # Normalise once so lookup, insert and trade log all agree on the key.
        ticker = ticker.upper()
        cost = shares * price

        portfolio = self._client.get_portfolio(portfolio_id)
        # A cost equal to cash up to rounding noise is allowed, so that an
        # "invest everything" order (shares = cash / price) is not rejected.
        if cost > portfolio.cash and not self._nearly_equal(cost, portfolio.cash):
            raise InsufficientCashError(
                f"Need ${cost:.2f} but only ${portfolio.cash:.2f} available"
            )

        existing = self._client.get_holding(portfolio_id, ticker)
        if existing is not None:
            new_total_shares = existing.shares + shares
            # Weighted average uses the share count from *before* this buy.
            existing.avg_cost = (
                existing.shares * existing.avg_cost + shares * price
            ) / new_total_shares
            existing.shares = new_total_shares
            if sector:
                existing.sector = sector
            if industry:
                existing.industry = industry
            holding = self._client.upsert_holding(existing)
        else:
            holding = self._client.upsert_holding(
                Holding(
                    holding_id=str(uuid.uuid4()),
                    portfolio_id=portfolio_id,
                    ticker=ticker,
                    shares=shares,
                    avg_cost=price,
                    sector=sector,
                    industry=industry,
                )
            )

        # The cash check above admits at most a rounding-noise overdraft;
        # clamp it so cash never goes (negligibly) negative.
        portfolio.cash = max(portfolio.cash - cost, 0.0)
        self._client.update_portfolio(portfolio)

        self._record_trade(portfolio_id, ticker, "BUY", shares, price, cost)
        return holding

    def remove_holding(
        self,
        portfolio_id: str,
        ticker: str,
        shares: float,
        price: float,
    ) -> Holding | None:
        """Sell shares and update portfolio cash and holdings.

        Selling the whole position (up to float rounding noise) deletes the
        holding. Otherwise the share count is reduced and the holding
        upserted. The proceeds (``shares * price``) are credited to cash and
        a ``SELL`` trade is recorded.

        Args:
            portfolio_id: Portfolio to sell from.
            ticker: Symbol to sell; upper-cased before use.
            shares: Number of shares; must be a finite number > 0.
            price: Price per share; must be a finite number > 0.

        Returns:
            The updated holding, or ``None`` if the position was closed.

        Raises:
            ValueError: If ``shares`` or ``price`` is not a finite number > 0.
            HoldingNotFoundError: If the portfolio holds no ``ticker``.
            InsufficientSharesError: If ``shares`` exceeds the position by
                more than float rounding noise.
        """
        self._require_positive("shares", shares)
        self._require_positive("price", price)
        ticker = ticker.upper()

        existing = self._client.get_holding(portfolio_id, ticker)
        if existing is None:
            raise HoldingNotFoundError(
                f"No holding for {ticker} in portfolio {portfolio_id}"
            )

        # Treat "equal up to rounding noise" as a full exit, so that
        # accumulated fractional lots do not leave a dust position behind.
        full_exit = self._nearly_equal(existing.shares, shares)
        if shares > existing.shares and not full_exit:
            raise InsufficientSharesError(
                f"Hold {existing.shares} shares of {ticker}, cannot sell {shares}"
            )

        proceeds = shares * price
        portfolio = self._client.get_portfolio(portfolio_id)

        if full_exit:
            # Delete by the stored ticker so the row that was found is the
            # one removed.
            self._client.delete_holding(portfolio_id, existing.ticker)
            result = None
        else:
            existing.shares -= shares
            result = self._client.upsert_holding(existing)

        portfolio.cash += proceeds
        self._client.update_portfolio(portfolio)

        self._record_trade(portfolio_id, ticker, "SELL", shares, price, proceeds)
        return result

    # ------------------------------------------------------------------
    # Snapshots
    # ------------------------------------------------------------------

    def take_snapshot(
        self,
        portfolio_id: str,
        prices: dict[str, float],
    ) -> PortfolioSnapshot:
        """Take a snapshot of the current portfolio state and persist it.

        The portfolio is priced with :meth:`get_portfolio_with_holdings`.
        If ``portfolio.total_value`` is falsy (e.g. not enriched), cash is
        recorded as the total value. A falsy ``equity_value`` is recorded
        as ``0.0``.

        Returns:
            The snapshot as returned by ``SupabaseClient.save_snapshot``.
        """
        portfolio, holdings = self.get_portfolio_with_holdings(portfolio_id, prices)
        snapshot = PortfolioSnapshot(
            snapshot_id=str(uuid.uuid4()),
            portfolio_id=portfolio_id,
            snapshot_date=datetime.now(timezone.utc).isoformat(),
            total_value=portfolio.total_value or portfolio.cash,
            cash=portfolio.cash,
            equity_value=portfolio.equity_value or 0.0,
            num_positions=len(holdings),
            holdings_snapshot=[h.to_dict() for h in holdings],
        )
        return self._client.save_snapshot(snapshot)

    # ------------------------------------------------------------------
    # Report convenience methods
    # ------------------------------------------------------------------

    def save_pm_decision(
        self,
        portfolio_id: str,
        date: str,
        decision: dict[str, Any],
        markdown: str | None = None,
    ) -> Path:
        """Save a PM agent decision and record its location on the portfolio.

        The decision (and optional markdown) is written via
        ``ReportStore.save_pm_decision``. ``portfolio.report_path`` is then
        set to the store's portfolio directory for ``date``, not to the
        returned file path.

        Returns:
            The path returned by the store.
        """
        path = self._store.save_pm_decision(date, portfolio_id, decision, markdown)
        portfolio = self._client.get_portfolio(portfolio_id)
        # NOTE(review): relies on the private ReportStore._portfolio_dir; a
        # public accessor on ReportStore would make this coupling explicit.
        portfolio.report_path = str(self._store._portfolio_dir(date))
        self._client.update_portfolio(portfolio)
        return path

    def load_pm_decision(
        self,
        portfolio_id: str,
        date: str,
    ) -> dict[str, Any] | None:
        """Load a PM decision JSON. Returns None if not found."""
        return self._store.load_pm_decision(date, portfolio_id)

    def save_risk_metrics(
        self,
        portfolio_id: str,
        date: str,
        metrics: dict[str, Any],
    ) -> Path:
        """Save risk computation results via ``ReportStore``; return its path."""
        return self._store.save_risk_metrics(date, portfolio_id, metrics)

    def load_risk_metrics(
        self,
        portfolio_id: str,
        date: str,
    ) -> dict[str, Any] | None:
        """Load risk metrics. Returns None if not found."""
        return self._store.load_risk_metrics(date, portfolio_id)