feat: complete portfolio data foundation — psycopg2 client, repository, tests

Replace supabase-py stubs with working psycopg2 implementation using
Supabase pooler connection string. Implement full business logic in
repository (avg cost basis, cash accounting, trade recording, snapshots).
Add 12 unit tests + 4 integration tests (51 total portfolio tests pass).
Fix cash_pct bug in models.py, update docs for psycopg2 + pooler pattern.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Ahmet Guzererler 2026-03-20 14:06:50 +01:00
parent aa4dcdeb80
commit a17e5f3707
9 changed files with 816 additions and 498 deletions

View File

@ -1,20 +1,21 @@
# Current Milestone
Daily digest consolidation and Google NotebookLM sync shipped (PR open: `feat/daily-digest-notebooklm`). All analyses now append to a single `daily_digest.md` per day and auto-upload to NotebookLM via `nlm` CLI. Next: PR review and merge.
Portfolio Manager Phase 1 (data foundation) complete and merged. All 4 Supabase tables live, 51 tests passing (including integration tests against live DB).
# Recent Progress
- **PR #32 merged**: Portfolio Manager data foundation — models, SQL schema, module scaffolding
- `tradingagents/portfolio/` — full module: models, config, exceptions, supabase_client (psycopg2), report_store, repository
- `migrations/001_initial_schema.sql` — 4 tables (portfolios, holdings, trades, snapshots) with constraints, indexes, triggers
- `tests/portfolio/` — 51 tests: 20 model, 15 report_store, 12 repository unit, 4 integration
- Uses `psycopg2` direct PostgreSQL via Supabase pooler (`aws-1-eu-west-1.pooler.supabase.com:6543`)
- Business logic: avg cost basis, cash accounting, trade recording, snapshots
- **PR #22 merged**: Unified report paths, structured observability logging, memory system update
- **feat/daily-digest-notebooklm** (shipped): Daily digest consolidation + NotebookLM source sync
- `tradingagents/daily_digest.py``append_to_digest()` appends timestamped entries to `reports/daily/{date}/daily_digest.md`
- `tradingagents/notebook_sync.py``sync_to_notebooklm()` deletes existing "Daily Trading Digest" source then uploads new content via `nlm source add --text --wait`.
- `tradingagents/report_paths.py` — added `get_digest_path(date)`
- `cli/main.py``analyze` and `scan` commands both call digest + sync after each run
- `.env.example` — fixed consistency, removed duplicates, aligned with `NOTEBOOKLM_ID`
- **Verification**: 220+ offline tests passing + 5 new unit tests for `notebook_sync.py` + live integration test passed.
# In Progress
- Portfolio Manager Phase 2: Holding Reviewer Agent (next)
- Refinement of macro scan synthesis prompts (ongoing)
# Active Blockers

View File

@ -72,28 +72,26 @@ investment portfolio end-to-end. It performs the following actions in sequence:
The `report_path` column in the `portfolios` table points to the daily portfolio
subdirectory on disk: `reports/daily/{date}/portfolio/`.
### Data Access Layer: raw `supabase-py` (no ORM)
### Data Access Layer: raw `psycopg2` (no ORM)
The Python code talks to Supabase through the raw `supabase-py` client — **no
ORM** (Prisma, SQLAlchemy, etc.) is used.
The Python code talks to Supabase PostgreSQL directly via `psycopg2` using the
**pooler connection string** (`SUPABASE_CONNECTION_STRING`). No ORM (Prisma,
SQLAlchemy) and no `supabase-py` REST client is used.
**Why not Prisma?**
- `prisma-client-py` requires a Node.js runtime for code generation — an
extra non-Python dependency in a Python-only project.
- Prisma's `prisma migrate` conflicts with Supabase's own SQL migration tooling
(we use `.sql` files in `tradingagents/portfolio/migrations/`).
- 4 tables with straightforward CRUD don't benefit from a code-generated ORM.
**Why `psycopg2` over `supabase-py`?**
- Direct SQL gives full control — transactions, upserts, `RETURNING *`, CTEs.
- No dependency on Supabase's PostgREST schema cache or API key types.
- `psycopg2-binary` is a single pip install with zero non-Python dependencies.
- 4 tables with straightforward CRUD don't benefit from an ORM or REST wrapper.
**Why not SQLAlchemy?**
- Supabase is accessed via PostgREST (HTTP API), not a direct TCP database
connection. SQLAlchemy is designed for direct connections and would bypass
Supabase's Row Level Security.
- Extra dependency overhead for a non-DB-heavy feature.
**Connection:**
- Uses `SUPABASE_CONNECTION_STRING` env var (pooler URI format).
- Passwords with special characters are auto-URL-encoded by `SupabaseClient._fix_dsn()`.
- Typical pooler URI: `postgresql://postgres.<ref>:<password>@aws-1-<region>.pooler.supabase.com:6543/postgres`
**`supabase-py` is sufficient because:**
- Its builder-pattern API (`client.table("holdings").select("*").eq(...)`)
covers all needed queries cleanly.
- Our own dataclasses handle type-safety via `to_dict()` / `from_dict()`.
**Why not Prisma / SQLAlchemy?**
- Prisma requires Node.js runtime — extra non-Python dependency.
- SQLAlchemy adds dependency overhead for 4 simple tables.
- Plain SQL migration files are readable, versionable, and Supabase-native.
> Full rationale: `docs/agent/decisions/012-portfolio-no-orm.md`

View File

@ -89,14 +89,13 @@ All are optional and default to the values shown:
| Env Var | Default | Description |
|---------|---------|-------------|
| `SUPABASE_URL` | `""` | Supabase project URL |
| `SUPABASE_KEY` | `""` | Supabase anon/service key |
| `PORTFOLIO_DATA_DIR` | `"reports"` | Root dir for filesystem reports |
| `PM_MAX_POSITIONS` | `15` | Max number of open positions |
| `PM_MAX_POSITION_PCT` | `0.15` | Max single-position weight |
| `PM_MAX_SECTOR_PCT` | `0.35` | Max sector weight |
| `PM_MIN_CASH_PCT` | `0.05` | Minimum cash reserve |
| `PM_DEFAULT_BUDGET` | `100000.0` | Default starting cash (USD) |
| `SUPABASE_CONNECTION_STRING` | `""` | PostgreSQL pooler connection URI |
| `TRADINGAGENTS_PORTFOLIO_DATA_DIR` | `"reports"` | Root dir for filesystem reports |
| `TRADINGAGENTS_PM_MAX_POSITIONS` | `15` | Max number of open positions |
| `TRADINGAGENTS_PM_MAX_POSITION_PCT` | `0.15` | Max single-position weight |
| `TRADINGAGENTS_PM_MAX_SECTOR_PCT` | `0.35` | Max sector weight |
| `TRADINGAGENTS_PM_MIN_CASH_PCT` | `0.05` | Minimum cash reserve |
| `TRADINGAGENTS_PM_DEFAULT_BUDGET` | `100000.0` | Default starting cash (USD) |
### Acceptance Criteria
@ -172,7 +171,7 @@ Methods that query a single row raise `PortfolioNotFoundError` when no row is fo
- Singleton — only one Supabase connection per process
- All public methods fully type-annotated
- Supabase integration tests auto-skip when `SUPABASE_URL` is unset
- Supabase integration tests auto-skip when `SUPABASE_CONNECTION_STRING` is unset
---
@ -328,7 +327,7 @@ require PG functions — deferred to Phase 3+).
### Coverage Target
90 %+ for `models.py` and `report_store.py`.
Integration tests (`test_repository.py`) auto-skip when Supabase is unavailable.
Integration tests (`test_repository.py`) auto-skip when `SUPABASE_CONNECTION_STRING` is unset.
---

View File

@ -2,16 +2,16 @@
Fixtures provided:
- ``tmp_reports`` temporary directory used as ReportStore base_dir
- ``sample_portfolio`` a Portfolio instance for testing (not persisted)
- ``sample_holding`` a Holding instance for testing (not persisted)
- ``sample_trade`` a Trade instance for testing (not persisted)
- ``sample_snapshot`` a PortfolioSnapshot instance for testing
- ``report_store`` a ReportStore instance backed by tmp_reports
- ``mock_supabase_client`` MagicMock of SupabaseClient for unit tests
- ``tmp_reports`` -- temporary directory used as ReportStore base_dir
- ``sample_portfolio`` -- a Portfolio instance for testing (not persisted)
- ``sample_holding`` -- a Holding instance for testing (not persisted)
- ``sample_trade`` -- a Trade instance for testing (not persisted)
- ``sample_snapshot`` -- a PortfolioSnapshot instance for testing
- ``report_store`` -- a ReportStore instance backed by tmp_reports
- ``mock_supabase_client`` -- MagicMock of SupabaseClient for unit tests
Supabase integration tests use ``pytest.mark.skipif`` to auto-skip when
``SUPABASE_URL`` is not set in the environment.
``SUPABASE_CONNECTION_STRING`` is not set in the environment.
"""
from __future__ import annotations
@ -36,8 +36,8 @@ from tradingagents.portfolio.supabase_client import SupabaseClient
# ---------------------------------------------------------------------------
requires_supabase = pytest.mark.skipif(
not os.getenv("SUPABASE_URL"),
reason="SUPABASE_URL not set — skipping Supabase integration tests",
not os.getenv("SUPABASE_CONNECTION_STRING"),
reason="SUPABASE_CONNECTION_STRING not set -- skipping integration tests",
)

View File

@ -1,29 +1,59 @@
"""Tests for tradingagents/portfolio/repository.py.
Tests the PortfolioRepository façade business logic for holdings management,
cash accounting, avg-cost-basis updates, and snapshot creation.
Supabase integration tests are automatically skipped when ``SUPABASE_URL`` is
not set in the environment (use the ``requires_supabase`` fixture marker).
Unit tests use ``mock_supabase_client`` to avoid DB access.
Integration tests auto-skip when ``SUPABASE_CONNECTION_STRING`` is unset.
Run (unit tests only)::
pytest tests/portfolio/test_repository.py -v -k "not integration"
Run (with Supabase)::
SUPABASE_URL=... SUPABASE_KEY=... pytest tests/portfolio/test_repository.py -v
"""
from __future__ import annotations
from unittest.mock import MagicMock, call
import pytest
from tradingagents.portfolio.exceptions import (
HoldingNotFoundError,
InsufficientCashError,
InsufficientSharesError,
)
from tradingagents.portfolio.models import Holding, Portfolio, Trade
from tradingagents.portfolio.repository import PortfolioRepository
from tests.portfolio.conftest import requires_supabase
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_repo(mock_client, report_store):
"""Build a PortfolioRepository with mock client and real report store."""
return PortfolioRepository(
client=mock_client,
store=report_store,
config={"data_dir": "reports", "max_positions": 15,
"max_position_pct": 0.15, "max_sector_pct": 0.35,
"min_cash_pct": 0.05, "default_budget": 100_000.0,
"supabase_connection_string": ""},
)
def _mock_portfolio(pid="pid-1", cash=10_000.0):
return Portfolio(
portfolio_id=pid, name="Test", cash=cash,
initial_cash=100_000.0, currency="USD",
)
def _mock_holding(pid="pid-1", ticker="AAPL", shares=50.0, avg_cost=190.0):
return Holding(
holding_id="hid-1", portfolio_id=pid, ticker=ticker,
shares=shares, avg_cost=avg_cost,
)
# ---------------------------------------------------------------------------
# add_holding — new position
# ---------------------------------------------------------------------------
@ -31,12 +61,19 @@ from tests.portfolio.conftest import requires_supabase
def test_add_holding_new_position(mock_supabase_client, report_store):
"""add_holding() on a ticker not yet held must create a new Holding."""
# TODO: implement
# repo = PortfolioRepository(client=mock_supabase_client, store=report_store)
# mock portfolio with enough cash
# repo.add_holding(portfolio_id, "AAPL", shares=10, price=200.0)
# assert mock_supabase_client.upsert_holding.called
raise NotImplementedError
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=10_000.0)
mock_supabase_client.get_holding.return_value = None
mock_supabase_client.upsert_holding.side_effect = lambda h: h
mock_supabase_client.update_portfolio.side_effect = lambda p: p
mock_supabase_client.record_trade.side_effect = lambda t: t
repo = _make_repo(mock_supabase_client, report_store)
holding = repo.add_holding("pid-1", "AAPL", shares=10, price=200.0)
assert holding.ticker == "AAPL"
assert holding.shares == 10
assert holding.avg_cost == 200.0
assert mock_supabase_client.upsert_holding.called
# ---------------------------------------------------------------------------
@ -45,16 +82,21 @@ def test_add_holding_new_position(mock_supabase_client, report_store):
def test_add_holding_updates_avg_cost(mock_supabase_client, report_store):
"""add_holding() on an existing position must update avg_cost correctly.
"""add_holding() on existing position must update avg_cost correctly."""
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=10_000.0)
existing = _mock_holding(shares=50.0, avg_cost=190.0)
mock_supabase_client.get_holding.return_value = existing
mock_supabase_client.upsert_holding.side_effect = lambda h: h
mock_supabase_client.update_portfolio.side_effect = lambda p: p
mock_supabase_client.record_trade.side_effect = lambda t: t
Formula: new_avg_cost = (old_shares * old_avg_cost + new_shares * price)
/ (old_shares + new_shares)
"""
# TODO: implement
# existing holding: 50 shares @ 190.0
# buy 25 more @ 200.0
# expected avg_cost = (50*190 + 25*200) / 75 = 193.33...
raise NotImplementedError
repo = _make_repo(mock_supabase_client, report_store)
holding = repo.add_holding("pid-1", "AAPL", shares=25, price=200.0)
# expected: (50*190 + 25*200) / 75 = 193.333...
expected_avg = (50 * 190.0 + 25 * 200.0) / 75
assert holding.shares == 75
assert holding.avg_cost == pytest.approx(expected_avg)
# ---------------------------------------------------------------------------
@ -64,11 +106,11 @@ def test_add_holding_updates_avg_cost(mock_supabase_client, report_store):
def test_add_holding_raises_insufficient_cash(mock_supabase_client, report_store):
"""add_holding() must raise InsufficientCashError when cash < shares * price."""
# TODO: implement
# portfolio with cash=500.0, try to buy 10 shares @ 200.0 (cost=2000)
# with pytest.raises(InsufficientCashError):
# repo.add_holding(portfolio_id, "AAPL", shares=10, price=200.0)
raise NotImplementedError
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=500.0)
repo = _make_repo(mock_supabase_client, report_store)
with pytest.raises(InsufficientCashError):
repo.add_holding("pid-1", "AAPL", shares=10, price=200.0)
# ---------------------------------------------------------------------------
@ -78,11 +120,17 @@ def test_add_holding_raises_insufficient_cash(mock_supabase_client, report_store
def test_remove_holding_full_position(mock_supabase_client, report_store):
"""remove_holding() selling all shares must delete the holding row."""
# TODO: implement
# holding: 50 shares
# sell 50 shares → holding deleted, cash credited
# assert mock_supabase_client.delete_holding.called
raise NotImplementedError
mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0)
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=5_000.0)
mock_supabase_client.delete_holding.return_value = None
mock_supabase_client.update_portfolio.side_effect = lambda p: p
mock_supabase_client.record_trade.side_effect = lambda t: t
repo = _make_repo(mock_supabase_client, report_store)
result = repo.remove_holding("pid-1", "AAPL", shares=50.0, price=200.0)
assert result is None
assert mock_supabase_client.delete_holding.called
# ---------------------------------------------------------------------------
@ -92,10 +140,17 @@ def test_remove_holding_full_position(mock_supabase_client, report_store):
def test_remove_holding_partial_position(mock_supabase_client, report_store):
"""remove_holding() selling a subset must reduce shares, not delete."""
# TODO: implement
# holding: 50 shares
# sell 20 → holding.shares == 30, avg_cost unchanged
raise NotImplementedError
mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0)
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=5_000.0)
mock_supabase_client.upsert_holding.side_effect = lambda h: h
mock_supabase_client.update_portfolio.side_effect = lambda p: p
mock_supabase_client.record_trade.side_effect = lambda t: t
repo = _make_repo(mock_supabase_client, report_store)
result = repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0)
assert result is not None
assert result.shares == 30.0
# ---------------------------------------------------------------------------
@ -105,16 +160,20 @@ def test_remove_holding_partial_position(mock_supabase_client, report_store):
def test_remove_holding_raises_insufficient_shares(mock_supabase_client, report_store):
"""remove_holding() must raise InsufficientSharesError when shares > held."""
# TODO: implement
# holding: 10 shares
# try sell 20 → InsufficientSharesError
raise NotImplementedError
mock_supabase_client.get_holding.return_value = _mock_holding(shares=10.0)
repo = _make_repo(mock_supabase_client, report_store)
with pytest.raises(InsufficientSharesError):
repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0)
def test_remove_holding_raises_when_ticker_not_held(mock_supabase_client, report_store):
"""remove_holding() must raise HoldingNotFoundError for unknown tickers."""
# TODO: implement
raise NotImplementedError
mock_supabase_client.get_holding.return_value = None
repo = _make_repo(mock_supabase_client, report_store)
with pytest.raises(HoldingNotFoundError):
repo.remove_holding("pid-1", "ZZZZ", shares=10.0, price=100.0)
# ---------------------------------------------------------------------------
@ -124,15 +183,35 @@ def test_remove_holding_raises_when_ticker_not_held(mock_supabase_client, report
def test_add_holding_deducts_cash(mock_supabase_client, report_store):
"""add_holding() must reduce portfolio.cash by shares * price."""
# TODO: implement
# portfolio.cash = 10_000, buy 10 @ 200 → cash should be 8_000
raise NotImplementedError
portfolio = _mock_portfolio(cash=10_000.0)
mock_supabase_client.get_portfolio.return_value = portfolio
mock_supabase_client.get_holding.return_value = None
mock_supabase_client.upsert_holding.side_effect = lambda h: h
mock_supabase_client.update_portfolio.side_effect = lambda p: p
mock_supabase_client.record_trade.side_effect = lambda t: t
repo = _make_repo(mock_supabase_client, report_store)
repo.add_holding("pid-1", "AAPL", shares=10, price=200.0)
# Check the portfolio passed to update_portfolio had cash deducted
updated = mock_supabase_client.update_portfolio.call_args[0][0]
assert updated.cash == pytest.approx(8_000.0)
def test_remove_holding_credits_cash(mock_supabase_client, report_store):
"""remove_holding() must increase portfolio.cash by shares * price."""
# TODO: implement
raise NotImplementedError
portfolio = _mock_portfolio(cash=5_000.0)
mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0)
mock_supabase_client.get_portfolio.return_value = portfolio
mock_supabase_client.upsert_holding.side_effect = lambda h: h
mock_supabase_client.update_portfolio.side_effect = lambda p: p
mock_supabase_client.record_trade.side_effect = lambda t: t
repo = _make_repo(mock_supabase_client, report_store)
repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0)
updated = mock_supabase_client.update_portfolio.call_args[0][0]
assert updated.cash == pytest.approx(9_000.0)
# ---------------------------------------------------------------------------
@ -142,14 +221,37 @@ def test_remove_holding_credits_cash(mock_supabase_client, report_store):
def test_add_holding_records_buy_trade(mock_supabase_client, report_store):
"""add_holding() must call client.record_trade() with action='BUY'."""
# TODO: implement
raise NotImplementedError
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=10_000.0)
mock_supabase_client.get_holding.return_value = None
mock_supabase_client.upsert_holding.side_effect = lambda h: h
mock_supabase_client.update_portfolio.side_effect = lambda p: p
mock_supabase_client.record_trade.side_effect = lambda t: t
repo = _make_repo(mock_supabase_client, report_store)
repo.add_holding("pid-1", "AAPL", shares=10, price=200.0)
trade = mock_supabase_client.record_trade.call_args[0][0]
assert trade.action == "BUY"
assert trade.ticker == "AAPL"
assert trade.shares == 10
assert trade.total_value == pytest.approx(2_000.0)
def test_remove_holding_records_sell_trade(mock_supabase_client, report_store):
"""remove_holding() must call client.record_trade() with action='SELL'."""
# TODO: implement
raise NotImplementedError
mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0)
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=5_000.0)
mock_supabase_client.upsert_holding.side_effect = lambda h: h
mock_supabase_client.update_portfolio.side_effect = lambda p: p
mock_supabase_client.record_trade.side_effect = lambda t: t
repo = _make_repo(mock_supabase_client, report_store)
repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0)
trade = mock_supabase_client.record_trade.call_args[0][0]
assert trade.action == "SELL"
assert trade.ticker == "AAPL"
assert trade.shares == 20.0
# ---------------------------------------------------------------------------
@ -159,40 +261,120 @@ def test_remove_holding_records_sell_trade(mock_supabase_client, report_store):
def test_take_snapshot(mock_supabase_client, report_store):
"""take_snapshot() must enrich holdings and persist a PortfolioSnapshot."""
# TODO: implement
# assert mock_supabase_client.save_snapshot.called
# snapshot.total_value == cash + equity
raise NotImplementedError
portfolio = _mock_portfolio(cash=5_000.0)
holding = _mock_holding(shares=50.0, avg_cost=190.0)
mock_supabase_client.get_portfolio.return_value = portfolio
mock_supabase_client.list_holdings.return_value = [holding]
mock_supabase_client.save_snapshot.side_effect = lambda s: s
repo = _make_repo(mock_supabase_client, report_store)
snapshot = repo.take_snapshot("pid-1", prices={"AAPL": 200.0})
assert mock_supabase_client.save_snapshot.called
assert snapshot.cash == 5_000.0
assert snapshot.num_positions == 1
assert snapshot.total_value == pytest.approx(5_000.0 + 50.0 * 200.0)
# ---------------------------------------------------------------------------
# Supabase integration tests (auto-skip without SUPABASE_URL)
# Supabase integration tests (auto-skip without SUPABASE_CONNECTION_STRING)
# ---------------------------------------------------------------------------
@requires_supabase
def test_integration_create_and_get_portfolio():
"""Integration: create a portfolio, retrieve it, verify fields match."""
# TODO: implement
raise NotImplementedError
from tradingagents.portfolio.supabase_client import SupabaseClient
client = SupabaseClient.get_instance()
from tradingagents.portfolio.report_store import ReportStore
store = ReportStore()
repo = PortfolioRepository(client=client, store=store)
portfolio = repo.create_portfolio("Integration Test", initial_cash=50_000.0)
try:
fetched = repo.get_portfolio(portfolio.portfolio_id)
assert fetched.name == "Integration Test"
assert fetched.cash == pytest.approx(50_000.0)
assert fetched.initial_cash == pytest.approx(50_000.0)
finally:
client.delete_portfolio(portfolio.portfolio_id)
@requires_supabase
def test_integration_add_and_remove_holding():
"""Integration: add holding, verify DB row; remove, verify deletion."""
# TODO: implement
raise NotImplementedError
"""Integration: add holding, verify; remove, verify deletion."""
from tradingagents.portfolio.supabase_client import SupabaseClient
client = SupabaseClient.get_instance()
from tradingagents.portfolio.report_store import ReportStore
store = ReportStore()
repo = PortfolioRepository(client=client, store=store)
portfolio = repo.create_portfolio("Hold Test", initial_cash=50_000.0)
try:
holding = repo.add_holding(
portfolio.portfolio_id, "AAPL", shares=10, price=200.0,
sector="Technology",
)
assert holding.ticker == "AAPL"
assert holding.shares == 10
# Verify cash deducted
p = repo.get_portfolio(portfolio.portfolio_id)
assert p.cash == pytest.approx(48_000.0)
# Sell all
result = repo.remove_holding(portfolio.portfolio_id, "AAPL", shares=10, price=210.0)
assert result is None
# Verify cash credited
p = repo.get_portfolio(portfolio.portfolio_id)
assert p.cash == pytest.approx(50_100.0)
finally:
client.delete_portfolio(portfolio.portfolio_id)
@requires_supabase
def test_integration_record_and_list_trades():
"""Integration: record BUY + SELL trades, list them, verify order."""
# TODO: implement
raise NotImplementedError
"""Integration: trades are recorded automatically via add/remove holding."""
from tradingagents.portfolio.supabase_client import SupabaseClient
client = SupabaseClient.get_instance()
from tradingagents.portfolio.report_store import ReportStore
store = ReportStore()
repo = PortfolioRepository(client=client, store=store)
portfolio = repo.create_portfolio("Trade Test", initial_cash=50_000.0)
try:
repo.add_holding(portfolio.portfolio_id, "MSFT", shares=5, price=400.0)
repo.remove_holding(portfolio.portfolio_id, "MSFT", shares=5, price=410.0)
trades = client.list_trades(portfolio.portfolio_id)
assert len(trades) == 2
assert trades[0].action == "SELL" # newest first
assert trades[1].action == "BUY"
finally:
client.delete_portfolio(portfolio.portfolio_id)
@requires_supabase
def test_integration_save_and_load_snapshot():
"""Integration: take snapshot, retrieve latest, verify total_value."""
# TODO: implement
raise NotImplementedError
from tradingagents.portfolio.supabase_client import SupabaseClient
client = SupabaseClient.get_instance()
from tradingagents.portfolio.report_store import ReportStore
store = ReportStore()
repo = PortfolioRepository(client=client, store=store)
portfolio = repo.create_portfolio("Snap Test", initial_cash=50_000.0)
try:
repo.add_holding(portfolio.portfolio_id, "AAPL", shares=10, price=200.0)
snapshot = repo.take_snapshot(portfolio.portfolio_id, prices={"AAPL": 210.0})
assert snapshot.num_positions == 1
assert snapshot.cash == pytest.approx(48_000.0)
assert snapshot.total_value == pytest.approx(48_000.0 + 10 * 210.0)
latest = client.get_latest_snapshot(portfolio.portfolio_id)
assert latest is not None
assert latest.snapshot_id == snapshot.snapshot_id
finally:
client.delete_portfolio(portfolio.portfolio_id)

View File

@ -1,8 +1,7 @@
"""Portfolio Manager configuration.
Reads all portfolio-related settings from environment variables with sensible
defaults. Integrates with the existing ``tradingagents/default_config.py``
pattern.
Integrates with the existing ``tradingagents/default_config.py`` pattern,
reading all portfolio settings from ``TRADINGAGENTS_<KEY>`` env vars.
Usage::
@ -11,34 +10,51 @@ Usage::
cfg = get_portfolio_config()
validate_config(cfg)
print(cfg["max_positions"]) # 15
Environment variables (all optional):
SUPABASE_URL Supabase project URL (default: "")
SUPABASE_KEY Supabase anon/service role key (default: "")
PORTFOLIO_DATA_DIR Root dir for filesystem reports (default: "reports")
PM_MAX_POSITIONS Max open positions (default: 15)
PM_MAX_POSITION_PCT Max single-position weight 01 (default: 0.15)
PM_MAX_SECTOR_PCT Max sector weight 01 (default: 0.35)
PM_MIN_CASH_PCT Minimum cash reserve 01 (default: 0.05)
PM_DEFAULT_BUDGET Default starting cash in USD (default: 100000.0)
"""
from __future__ import annotations
import os
# ---------------------------------------------------------------------------
# Defaults
# ---------------------------------------------------------------------------
from dotenv import load_dotenv
load_dotenv()
def _env(key: str, default=None):
"""Read ``TRADINGAGENTS_<KEY>`` from the environment.
Matches the convention in ``tradingagents/default_config.py``.
"""
val = os.getenv(f"TRADINGAGENTS_{key.upper()}")
if not val:
return default
return val
def _env_float(key: str, default=None):
val = _env(key)
if val is None:
return default
try:
return float(val)
except (ValueError, TypeError):
return default
def _env_int(key: str, default=None):
val = _env(key)
if val is None:
return default
try:
return int(val)
except (ValueError, TypeError):
return default
PORTFOLIO_CONFIG: dict = {
# Supabase connection
"supabase_url": "",
"supabase_key": "",
# Filesystem report root (matches report_paths.py REPORTS_ROOT)
"data_dir": "reports",
# PM constraint defaults
"supabase_connection_string": os.getenv("SUPABASE_CONNECTION_STRING", ""),
"data_dir": _env("PORTFOLIO_DATA_DIR", "reports"),
"max_positions": 15,
"max_position_pct": 0.15,
"max_sector_pct": 0.35,
@ -47,41 +63,44 @@ PORTFOLIO_CONFIG: dict = {
}
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def get_portfolio_config() -> dict:
"""Return the merged portfolio config (defaults overridden by env vars).
Reads ``SUPABASE_URL``, ``SUPABASE_KEY``, ``PORTFOLIO_DATA_DIR``,
``PM_MAX_POSITIONS``, ``PM_MAX_POSITION_PCT``, ``PM_MAX_SECTOR_PCT``,
``PM_MIN_CASH_PCT``, and ``PM_DEFAULT_BUDGET`` from the environment.
Returns:
A dict with all portfolio configuration keys.
"""
# TODO: implement — merge PORTFOLIO_CONFIG with env var overrides
raise NotImplementedError
cfg = dict(PORTFOLIO_CONFIG)
cfg["supabase_connection_string"] = os.getenv("SUPABASE_CONNECTION_STRING", cfg["supabase_connection_string"])
cfg["data_dir"] = _env("PORTFOLIO_DATA_DIR", cfg["data_dir"])
cfg["max_positions"] = _env_int("PM_MAX_POSITIONS", cfg["max_positions"])
cfg["max_position_pct"] = _env_float("PM_MAX_POSITION_PCT", cfg["max_position_pct"])
cfg["max_sector_pct"] = _env_float("PM_MAX_SECTOR_PCT", cfg["max_sector_pct"])
cfg["min_cash_pct"] = _env_float("PM_MIN_CASH_PCT", cfg["min_cash_pct"])
cfg["default_budget"] = _env_float("PM_DEFAULT_BUDGET", cfg["default_budget"])
return cfg
def validate_config(cfg: dict) -> None:
"""Validate a portfolio config dict, raising ValueError on invalid values.
Checks:
- ``max_positions >= 1``
- ``0 < max_position_pct <= 1``
- ``0 < max_sector_pct <= 1``
- ``0 <= min_cash_pct < 1``
- ``default_budget > 0``
- ``min_cash_pct + max_position_pct <= 1`` (can always meet both constraints)
Args:
cfg: Config dict as returned by ``get_portfolio_config()``.
Raises:
ValueError: With a descriptive message on the first failed check.
"""
# TODO: implement
raise NotImplementedError
if cfg["max_positions"] < 1:
raise ValueError(f"max_positions must be >= 1, got {cfg['max_positions']}")
if not (0 < cfg["max_position_pct"] <= 1.0):
raise ValueError(f"max_position_pct must be in (0, 1], got {cfg['max_position_pct']}")
if not (0 < cfg["max_sector_pct"] <= 1.0):
raise ValueError(f"max_sector_pct must be in (0, 1], got {cfg['max_sector_pct']}")
if not (0 <= cfg["min_cash_pct"] < 1.0):
raise ValueError(f"min_cash_pct must be in [0, 1), got {cfg['min_cash_pct']}")
if cfg["default_budget"] <= 0:
raise ValueError(f"default_budget must be > 0, got {cfg['default_budget']}")
if cfg["min_cash_pct"] + cfg["max_position_pct"] > 1.0:
raise ValueError(
f"min_cash_pct ({cfg['min_cash_pct']}) + max_position_pct ({cfg['max_position_pct']}) "
f"must be <= 1.0"
)

View File

@ -106,7 +106,7 @@ class Portfolio:
h.current_value for h in holdings if h.current_value is not None
)
self.total_value = self.cash + self.equity_value
self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 0.0
self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 1.0
return self

View File

@ -1,38 +1,28 @@
"""Unified data-access façade for the Portfolio Manager.
"""Unified data-access facade for the Portfolio Manager.
``PortfolioRepository`` combines ``SupabaseClient`` (transactional data) and
``ReportStore`` (filesystem documents) into a single, business-logic-aware
interface.
Callers should **only** interact with ``PortfolioRepository`` do not use
``SupabaseClient`` or ``ReportStore`` directly from outside this package.
Usage::
from tradingagents.portfolio import PortfolioRepository
repo = PortfolioRepository()
# Create a portfolio
portfolio = repo.create_portfolio("Main Portfolio", initial_cash=100_000.0)
# Buy shares
holding = repo.add_holding(portfolio.portfolio_id, "AAPL", shares=50, price=195.50)
# Sell shares
repo.remove_holding(portfolio.portfolio_id, "AAPL", shares=25, price=200.00)
# Snapshot
snapshot = repo.take_snapshot(portfolio.portfolio_id, prices={"AAPL": 200.00})
See ``docs/portfolio/04_repository_api.md`` for full API documentation.
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from tradingagents.portfolio.config import get_portfolio_config
from tradingagents.portfolio.exceptions import (
HoldingNotFoundError,
InsufficientCashError,
@ -49,7 +39,7 @@ from tradingagents.portfolio.supabase_client import SupabaseClient
class PortfolioRepository:
"""Unified façade over SupabaseClient and ReportStore.
"""Unified facade over SupabaseClient and ReportStore.
Implements business logic for:
- Average cost basis updates on repeated buys
@ -64,16 +54,9 @@ class PortfolioRepository:
store: ReportStore | None = None,
config: dict[str, Any] | None = None,
) -> None:
"""Initialise the repository.
Args:
client: SupabaseClient instance. Uses ``SupabaseClient.get_instance()``
when None.
store: ReportStore instance. Creates a default instance when None.
config: Portfolio config dict. Uses ``get_portfolio_config()`` when None.
"""
# TODO: implement — resolve defaults, store as self._client, self._store, self._cfg
raise NotImplementedError
self._cfg = config or get_portfolio_config()
self._client = client or SupabaseClient.get_instance()
self._store = store or ReportStore(base_dir=self._cfg["data_dir"])
# ------------------------------------------------------------------
# Portfolio lifecycle
@ -85,54 +68,42 @@ class PortfolioRepository:
initial_cash: float,
currency: str = "USD",
) -> Portfolio:
"""Create a new portfolio with the given starting capital.
Generates a UUID for ``portfolio_id``. Sets ``cash = initial_cash``.
Args:
name: Human-readable portfolio name.
initial_cash: Starting capital in USD (or configured currency).
currency: ISO 4217 currency code.
Returns:
Persisted Portfolio instance.
Raises:
DuplicatePortfolioError: If the name is already in use.
ValueError: If ``initial_cash <= 0``.
"""
# TODO: implement
raise NotImplementedError
"""Create a new portfolio with the given starting capital."""
if initial_cash <= 0:
raise ValueError(f"initial_cash must be > 0, got {initial_cash}")
portfolio = Portfolio(
portfolio_id=str(uuid.uuid4()),
name=name,
cash=initial_cash,
initial_cash=initial_cash,
currency=currency,
)
return self._client.create_portfolio(portfolio)
def get_portfolio(self, portfolio_id: str) -> Portfolio:
"""Fetch a portfolio by ID.
Args:
portfolio_id: UUID of the target portfolio.
Raises:
PortfolioNotFoundError: If not found.
"""
# TODO: implement
raise NotImplementedError
"""Fetch a portfolio by ID."""
return self._client.get_portfolio(portfolio_id)
def get_portfolio_with_holdings(
self,
portfolio_id: str,
prices: dict[str, float] | None = None,
) -> tuple[Portfolio, list[Holding]]:
"""Fetch portfolio + all holdings, optionally enriched with current prices.
Args:
portfolio_id: UUID of the target portfolio.
prices: Optional ``{ticker: current_price}`` dict. When provided,
holdings are enriched and ``Portfolio.enrich()`` is called.
Returns:
``(Portfolio, list[Holding])`` enriched when prices are supplied.
"""
# TODO: implement
raise NotImplementedError
"""Fetch portfolio + all holdings, optionally enriched with current prices."""
portfolio = self._client.get_portfolio(portfolio_id)
holdings = self._client.list_holdings(portfolio_id)
if prices:
# First pass: compute equity for total_value
equity = sum(
prices.get(h.ticker, 0.0) * h.shares for h in holdings
)
total_value = portfolio.cash + equity
# Second pass: enrich each holding with weight
for h in holdings:
if h.ticker in prices:
h.enrich(prices[h.ticker], total_value)
portfolio.enrich(holdings)
return portfolio, holdings
# ------------------------------------------------------------------
# Holdings management
@ -147,37 +118,65 @@ class PortfolioRepository:
sector: str | None = None,
industry: str | None = None,
) -> Holding:
"""Buy shares and update portfolio cash and holdings.
"""Buy shares and update portfolio cash and holdings."""
if shares <= 0:
raise ValueError(f"shares must be > 0, got {shares}")
if price <= 0:
raise ValueError(f"price must be > 0, got {price}")
Business logic:
- Raises ``InsufficientCashError`` if ``portfolio.cash < shares * price``
- If holding already exists: updates ``avg_cost`` using weighted average
- ``portfolio.cash -= shares * price``
- Records a BUY trade automatically
cost = shares * price
portfolio = self._client.get_portfolio(portfolio_id)
Avg cost formula::
if portfolio.cash < cost:
raise InsufficientCashError(
f"Need ${cost:.2f} but only ${portfolio.cash:.2f} available"
)
new_avg_cost = (old_shares * old_avg_cost + new_shares * price)
/ (old_shares + new_shares)
# Check for existing holding to update avg cost
existing = self._client.get_holding(portfolio_id, ticker)
if existing:
new_total_shares = existing.shares + shares
new_avg_cost = (
(existing.shares * existing.avg_cost + shares * price) / new_total_shares
)
existing.shares = new_total_shares
existing.avg_cost = new_avg_cost
if sector:
existing.sector = sector
if industry:
existing.industry = industry
holding = self._client.upsert_holding(existing)
else:
holding = Holding(
holding_id=str(uuid.uuid4()),
portfolio_id=portfolio_id,
ticker=ticker.upper(),
shares=shares,
avg_cost=price,
sector=sector,
industry=industry,
)
holding = self._client.upsert_holding(holding)
Args:
portfolio_id: UUID of the target portfolio.
ticker: Ticker symbol.
shares: Number of shares to buy (must be > 0).
price: Execution price per share.
sector: Optional GICS sector name.
industry: Optional GICS industry name.
# Deduct cash
portfolio.cash -= cost
self._client.update_portfolio(portfolio)
Returns:
Updated or created Holding.
# Record trade
trade = Trade(
trade_id=str(uuid.uuid4()),
portfolio_id=portfolio_id,
ticker=ticker.upper(),
action="BUY",
shares=shares,
price=price,
total_value=cost,
trade_date=datetime.now(timezone.utc).isoformat(),
signal_source="pm_agent",
)
self._client.record_trade(trade)
Raises:
InsufficientCashError: If cash would go negative.
PortfolioNotFoundError: If portfolio_id does not exist.
ValueError: If shares <= 0 or price <= 0.
"""
# TODO: implement
raise NotImplementedError
return holding
def remove_holding(
self,
@ -186,33 +185,53 @@ class PortfolioRepository:
shares: float,
price: float,
) -> Holding | None:
"""Sell shares and update portfolio cash and holdings.
"""Sell shares and update portfolio cash and holdings."""
if shares <= 0:
raise ValueError(f"shares must be > 0, got {shares}")
if price <= 0:
raise ValueError(f"price must be > 0, got {price}")
Business logic:
- Raises ``HoldingNotFoundError`` if no holding exists for ticker
- Raises ``InsufficientSharesError`` if ``holding.shares < shares``
- If ``shares == holding.shares``: deletes the holding row, returns None
- Otherwise: decrements ``holding.shares`` (avg_cost unchanged on sell)
- ``portfolio.cash += shares * price``
- Records a SELL trade automatically
existing = self._client.get_holding(portfolio_id, ticker)
if not existing:
raise HoldingNotFoundError(
f"No holding for {ticker} in portfolio {portfolio_id}"
)
Args:
portfolio_id: UUID of the target portfolio.
ticker: Ticker symbol.
shares: Number of shares to sell (must be > 0).
price: Execution price per share.
if existing.shares < shares:
raise InsufficientSharesError(
f"Hold {existing.shares} shares of {ticker}, cannot sell {shares}"
)
Returns:
Updated Holding, or None if the position was fully closed.
proceeds = shares * price
portfolio = self._client.get_portfolio(portfolio_id)
Raises:
HoldingNotFoundError: If no holding exists for this ticker.
InsufficientSharesError: If holding.shares < shares.
PortfolioNotFoundError: If portfolio_id does not exist.
ValueError: If shares <= 0 or price <= 0.
"""
# TODO: implement
raise NotImplementedError
if existing.shares == shares:
# Full sell — delete holding
self._client.delete_holding(portfolio_id, ticker)
result = None
else:
existing.shares -= shares
result = self._client.upsert_holding(existing)
# Credit cash
portfolio.cash += proceeds
self._client.update_portfolio(portfolio)
# Record trade
trade = Trade(
trade_id=str(uuid.uuid4()),
portfolio_id=portfolio_id,
ticker=ticker.upper(),
action="SELL",
shares=shares,
price=price,
total_value=proceeds,
trade_date=datetime.now(timezone.utc).isoformat(),
signal_source="pm_agent",
)
self._client.record_trade(trade)
return result
# ------------------------------------------------------------------
# Snapshots
@ -223,20 +242,19 @@ class PortfolioRepository:
portfolio_id: str,
prices: dict[str, float],
) -> PortfolioSnapshot:
"""Take an immutable snapshot of the current portfolio state.
Fetches all holdings, enriches them with ``prices``, computes
``total_value``, then persists via ``SupabaseClient.save_snapshot()``.
Args:
portfolio_id: UUID of the target portfolio.
prices: ``{ticker: current_price}`` for all held tickers.
Returns:
Persisted PortfolioSnapshot.
"""
# TODO: implement
raise NotImplementedError
"""Take an immutable snapshot of the current portfolio state."""
portfolio, holdings = self.get_portfolio_with_holdings(portfolio_id, prices)
snapshot = PortfolioSnapshot(
snapshot_id=str(uuid.uuid4()),
portfolio_id=portfolio_id,
snapshot_date=datetime.now(timezone.utc).isoformat(),
total_value=portfolio.total_value or portfolio.cash,
cash=portfolio.cash,
equity_value=portfolio.equity_value or 0.0,
num_positions=len(holdings),
holdings_snapshot=[h.to_dict() for h in holdings],
)
return self._client.save_snapshot(snapshot)
# ------------------------------------------------------------------
# Report convenience methods
@ -249,37 +267,21 @@ class PortfolioRepository:
decision: dict[str, Any],
markdown: str | None = None,
) -> Path:
"""Save a PM agent decision and update portfolio.report_path.
Delegates to ``ReportStore.save_pm_decision()`` then updates the
``portfolio.report_path`` column in Supabase to point to the daily
portfolio directory.
Args:
portfolio_id: UUID of the target portfolio.
date: ISO date string, e.g. ``"2026-03-20"``.
decision: PM decision dict.
markdown: Optional human-readable markdown version.
Returns:
Path of the written JSON file.
"""
# TODO: implement
raise NotImplementedError
"""Save a PM agent decision and update portfolio.report_path."""
path = self._store.save_pm_decision(date, portfolio_id, decision, markdown)
# Update portfolio report_path
portfolio = self._client.get_portfolio(portfolio_id)
portfolio.report_path = str(self._store._portfolio_dir(date))
self._client.update_portfolio(portfolio)
return path
def load_pm_decision(
self,
portfolio_id: str,
date: str,
) -> dict[str, Any] | None:
"""Load a PM decision JSON. Returns None if not found.
Args:
portfolio_id: UUID of the target portfolio.
date: ISO date string.
"""
# TODO: implement
raise NotImplementedError
"""Load a PM decision JSON. Returns None if not found."""
return self._store.load_pm_decision(date, portfolio_id)
def save_risk_metrics(
self,
@ -287,29 +289,13 @@ class PortfolioRepository:
date: str,
metrics: dict[str, Any],
) -> Path:
"""Save risk computation results. Delegates to ReportStore.
Args:
portfolio_id: UUID of the target portfolio.
date: ISO date string.
metrics: Risk metrics dict (Sharpe, Sortino, VaR, beta, etc.).
Returns:
Path of the written file.
"""
# TODO: implement
raise NotImplementedError
"""Save risk computation results."""
return self._store.save_risk_metrics(date, portfolio_id, metrics)
def load_risk_metrics(
self,
portfolio_id: str,
date: str,
) -> dict[str, Any] | None:
"""Load risk metrics. Returns None if not found.
Args:
portfolio_id: UUID of the target portfolio.
date: ISO date string.
"""
# TODO: implement
raise NotImplementedError
"""Load risk metrics. Returns None if not found."""
return self._store.load_risk_metrics(date, portfolio_id)

View File

@ -1,34 +1,30 @@
"""Supabase database client for the Portfolio Manager.
"""PostgreSQL database client for the Portfolio Manager.
Thin wrapper around ``supabase-py`` (no ORM) that:
- Provides a singleton connection (one client per process)
- Translates Supabase / HTTP errors into domain exceptions
- Converts raw DB rows into typed model instances via ``Model.from_dict()``
**No ORM is used here by design** see
``docs/agent/decisions/012-portfolio-no-orm.md`` for the full rationale.
In short: ``supabase-py``'s builder-pattern API is sufficient for 4 tables;
Prisma and SQLAlchemy add build-step / runtime complexity that isn't warranted
for this non-DB-heavy feature.
Uses ``psycopg2`` with the ``SUPABASE_CONNECTION_STRING`` env var to talk
directly to the Supabase-hosted PostgreSQL database. No ORM see
``docs/agent/decisions/012-portfolio-no-orm.md`` for rationale.
Usage::
from tradingagents.portfolio.supabase_client import SupabaseClient
from tradingagents.portfolio.models import Portfolio
client = SupabaseClient.get_instance()
portfolio = client.get_portfolio("some-uuid")
Configuration (read from environment via ``get_portfolio_config()``):
SUPABASE_URL Supabase project URL
SUPABASE_KEY Supabase anon or service-role key
"""
from __future__ import annotations
import json
import uuid
import psycopg2
import psycopg2.extras
from tradingagents.portfolio.config import get_portfolio_config
from tradingagents.portfolio.exceptions import (
DuplicatePortfolioError,
HoldingNotFoundError,
PortfolioError,
PortfolioNotFoundError,
)
from tradingagents.portfolio.models import (
@ -40,167 +36,213 @@ from tradingagents.portfolio.models import (
class SupabaseClient:
"""Singleton Supabase CRUD client for portfolio data.
"""Singleton PostgreSQL CRUD client for portfolio data.
All public methods translate Supabase / HTTP errors into domain exceptions
All public methods translate database errors into domain exceptions
and return typed model instances.
Do not instantiate directly use ``SupabaseClient.get_instance()``.
"""
_instance: "SupabaseClient | None" = None
_instance: SupabaseClient | None = None
def __init__(self, url: str, key: str) -> None:
"""Initialise the Supabase client.
def __init__(self, connection_string: str) -> None:
self._dsn = self._fix_dsn(connection_string)
self._conn = psycopg2.connect(self._dsn)
self._conn.autocommit = True
Args:
url: Supabase project URL.
key: Supabase anon or service-role key.
"""
# TODO: implement — create supabase.create_client(url, key)
raise NotImplementedError
@staticmethod
def _fix_dsn(dsn: str) -> str:
"""URL-encode the password if it contains special characters."""
from urllib.parse import quote
if "://" not in dsn:
return dsn # already key=value format
scheme, rest = dsn.split("://", 1)
at_idx = rest.rfind("@")
if at_idx == -1:
return dsn
userinfo = rest[:at_idx]
hostinfo = rest[at_idx + 1:]
colon_idx = userinfo.find(":")
if colon_idx == -1:
return dsn
user = userinfo[:colon_idx]
password = userinfo[colon_idx + 1:]
encoded = quote(password, safe="")
return f"{scheme}://{user}:{encoded}@{hostinfo}"
@classmethod
def get_instance(cls) -> "SupabaseClient":
"""Return the singleton instance, creating it if necessary.
def get_instance(cls) -> SupabaseClient:
"""Return the singleton instance, creating it if necessary."""
if cls._instance is None:
cfg = get_portfolio_config()
dsn = cfg["supabase_connection_string"]
if not dsn:
raise PortfolioError(
"SUPABASE_CONNECTION_STRING not configured. "
"Set it in .env or as an environment variable."
)
cls._instance = cls(dsn)
return cls._instance
Reads SUPABASE_URL and SUPABASE_KEY from ``get_portfolio_config()``.
@classmethod
def reset_instance(cls) -> None:
"""Close and reset the singleton (for testing)."""
if cls._instance is not None:
try:
cls._instance._conn.close()
except Exception:
pass
cls._instance = None
Raises:
PortfolioError: If SUPABASE_URL or SUPABASE_KEY are not configured.
"""
# TODO: implement
raise NotImplementedError
def _cursor(self):
"""Return a RealDictCursor."""
return self._conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
# ------------------------------------------------------------------
# Portfolio CRUD
# ------------------------------------------------------------------
def create_portfolio(self, portfolio: Portfolio) -> Portfolio:
"""Insert a new portfolio row.
Args:
portfolio: Portfolio instance with all required fields set.
Returns:
Portfolio with DB-assigned timestamps.
Raises:
DuplicatePortfolioError: If portfolio_id already exists.
"""
# TODO: implement
raise NotImplementedError
"""Insert a new portfolio row."""
pid = portfolio.portfolio_id or str(uuid.uuid4())
try:
with self._cursor() as cur:
cur.execute(
"""INSERT INTO portfolios
(portfolio_id, name, cash, initial_cash, currency, report_path, metadata)
VALUES (%s, %s, %s, %s, %s, %s, %s)
RETURNING *""",
(pid, portfolio.name, portfolio.cash, portfolio.initial_cash,
portfolio.currency, portfolio.report_path,
json.dumps(portfolio.metadata)),
)
row = cur.fetchone()
except psycopg2.errors.UniqueViolation as exc:
raise DuplicatePortfolioError(f"Portfolio already exists: {pid}") from exc
return self._row_to_portfolio(row)
def get_portfolio(self, portfolio_id: str) -> Portfolio:
"""Fetch a portfolio by ID.
Args:
portfolio_id: UUID of the target portfolio.
Returns:
Portfolio instance.
Raises:
PortfolioNotFoundError: If no portfolio has that ID.
"""
# TODO: implement
raise NotImplementedError
"""Fetch a portfolio by ID."""
with self._cursor() as cur:
cur.execute("SELECT * FROM portfolios WHERE portfolio_id = %s", (portfolio_id,))
row = cur.fetchone()
if not row:
raise PortfolioNotFoundError(f"Portfolio not found: {portfolio_id}")
return self._row_to_portfolio(row)
def list_portfolios(self) -> list[Portfolio]:
"""Return all portfolios ordered by created_at DESC."""
# TODO: implement
raise NotImplementedError
with self._cursor() as cur:
cur.execute("SELECT * FROM portfolios ORDER BY created_at DESC")
rows = cur.fetchall()
return [self._row_to_portfolio(r) for r in rows]
def update_portfolio(self, portfolio: Portfolio) -> Portfolio:
"""Update mutable portfolio fields (cash, report_path, metadata).
Args:
portfolio: Portfolio with updated field values.
Returns:
Updated Portfolio with refreshed updated_at.
Raises:
PortfolioNotFoundError: If portfolio_id does not exist.
"""
# TODO: implement
raise NotImplementedError
"""Update mutable portfolio fields (cash, report_path, metadata)."""
with self._cursor() as cur:
cur.execute(
"""UPDATE portfolios
SET cash = %s, report_path = %s, metadata = %s
WHERE portfolio_id = %s
RETURNING *""",
(portfolio.cash, portfolio.report_path,
json.dumps(portfolio.metadata), portfolio.portfolio_id),
)
row = cur.fetchone()
if not row:
raise PortfolioNotFoundError(f"Portfolio not found: {portfolio.portfolio_id}")
return self._row_to_portfolio(row)
def delete_portfolio(self, portfolio_id: str) -> None:
"""Delete a portfolio and all associated data (CASCADE).
Args:
portfolio_id: UUID of the portfolio to delete.
Raises:
PortfolioNotFoundError: If portfolio_id does not exist.
"""
# TODO: implement
raise NotImplementedError
"""Delete a portfolio and all associated data (CASCADE)."""
with self._cursor() as cur:
cur.execute(
"DELETE FROM portfolios WHERE portfolio_id = %s RETURNING portfolio_id",
(portfolio_id,),
)
row = cur.fetchone()
if not row:
raise PortfolioNotFoundError(f"Portfolio not found: {portfolio_id}")
# ------------------------------------------------------------------
# Holdings CRUD
# ------------------------------------------------------------------
def upsert_holding(self, holding: Holding) -> Holding:
"""Insert or update a holding row (upsert on portfolio_id + ticker).
Args:
holding: Holding instance with all required fields set.
Returns:
Holding with DB-assigned / refreshed timestamps.
"""
# TODO: implement
raise NotImplementedError
"""Insert or update a holding row (upsert on portfolio_id + ticker)."""
hid = holding.holding_id or str(uuid.uuid4())
with self._cursor() as cur:
cur.execute(
"""INSERT INTO holdings
(holding_id, portfolio_id, ticker, shares, avg_cost, sector, industry)
VALUES (%s, %s, %s, %s, %s, %s, %s)
ON CONFLICT ON CONSTRAINT holdings_portfolio_ticker_unique
DO UPDATE SET shares = EXCLUDED.shares,
avg_cost = EXCLUDED.avg_cost,
sector = EXCLUDED.sector,
industry = EXCLUDED.industry
RETURNING *""",
(hid, holding.portfolio_id, holding.ticker.upper(),
holding.shares, holding.avg_cost, holding.sector, holding.industry),
)
row = cur.fetchone()
return self._row_to_holding(row)
def get_holding(self, portfolio_id: str, ticker: str) -> Holding | None:
"""Return the holding for (portfolio_id, ticker), or None if not found.
Args:
portfolio_id: UUID of the target portfolio.
ticker: Ticker symbol (case-insensitive, stored as uppercase).
"""
# TODO: implement
raise NotImplementedError
"""Return the holding for (portfolio_id, ticker), or None."""
with self._cursor() as cur:
cur.execute(
"SELECT * FROM holdings WHERE portfolio_id = %s AND ticker = %s",
(portfolio_id, ticker.upper()),
)
row = cur.fetchone()
return self._row_to_holding(row) if row else None
def list_holdings(self, portfolio_id: str) -> list[Holding]:
"""Return all holdings for a portfolio ordered by cost_basis DESC.
Args:
portfolio_id: UUID of the target portfolio.
"""
# TODO: implement
raise NotImplementedError
"""Return all holdings for a portfolio ordered by cost_basis DESC."""
with self._cursor() as cur:
cur.execute(
"""SELECT * FROM holdings
WHERE portfolio_id = %s
ORDER BY shares * avg_cost DESC""",
(portfolio_id,),
)
rows = cur.fetchall()
return [self._row_to_holding(r) for r in rows]
def delete_holding(self, portfolio_id: str, ticker: str) -> None:
"""Delete the holding for (portfolio_id, ticker).
Args:
portfolio_id: UUID of the target portfolio.
ticker: Ticker symbol.
Raises:
HoldingNotFoundError: If no such holding exists.
"""
# TODO: implement
raise NotImplementedError
"""Delete the holding for (portfolio_id, ticker)."""
with self._cursor() as cur:
cur.execute(
"DELETE FROM holdings WHERE portfolio_id = %s AND ticker = %s RETURNING holding_id",
(portfolio_id, ticker.upper()),
)
row = cur.fetchone()
if not row:
raise HoldingNotFoundError(
f"Holding not found: {ticker} in portfolio {portfolio_id}"
)
# ------------------------------------------------------------------
# Trades
# ------------------------------------------------------------------
def record_trade(self, trade: Trade) -> Trade:
"""Insert a new trade record. Immutable — no update method.
Args:
trade: Trade instance with all required fields set.
Returns:
Trade with DB-assigned trade_id and trade_date.
"""
# TODO: implement
raise NotImplementedError
"""Insert a new trade record."""
tid = trade.trade_id or str(uuid.uuid4())
with self._cursor() as cur:
cur.execute(
"""INSERT INTO trades
(trade_id, portfolio_id, ticker, action, shares, price,
total_value, rationale, signal_source, metadata)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING *""",
(tid, trade.portfolio_id, trade.ticker, trade.action,
trade.shares, trade.price, trade.total_value,
trade.rationale, trade.signal_source,
json.dumps(trade.metadata)),
)
row = cur.fetchone()
return self._row_to_trade(row)
def list_trades(
self,
@ -208,51 +250,142 @@ class SupabaseClient:
ticker: str | None = None,
limit: int = 100,
) -> list[Trade]:
"""Return recent trades for a portfolio, newest first.
Args:
portfolio_id: Filter by portfolio.
ticker: Optional additional filter by ticker symbol.
limit: Maximum number of rows to return.
"""
# TODO: implement
raise NotImplementedError
"""Return recent trades for a portfolio, newest first."""
if ticker:
query = """SELECT * FROM trades
WHERE portfolio_id = %s AND ticker = %s
ORDER BY trade_date DESC LIMIT %s"""
params = (portfolio_id, ticker.upper(), limit)
else:
query = """SELECT * FROM trades
WHERE portfolio_id = %s
ORDER BY trade_date DESC LIMIT %s"""
params = (portfolio_id, limit)
with self._cursor() as cur:
cur.execute(query, params)
rows = cur.fetchall()
return [self._row_to_trade(r) for r in rows]
# ------------------------------------------------------------------
# Snapshots
# ------------------------------------------------------------------
def save_snapshot(self, snapshot: PortfolioSnapshot) -> PortfolioSnapshot:
"""Insert a new immutable portfolio snapshot.
Args:
snapshot: PortfolioSnapshot with all required fields set.
Returns:
Snapshot with DB-assigned snapshot_id and snapshot_date.
"""
# TODO: implement
raise NotImplementedError
"""Insert a new immutable portfolio snapshot."""
sid = snapshot.snapshot_id or str(uuid.uuid4())
with self._cursor() as cur:
cur.execute(
"""INSERT INTO snapshots
(snapshot_id, portfolio_id, total_value, cash, equity_value,
num_positions, holdings_snapshot, metadata)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
RETURNING *""",
(sid, snapshot.portfolio_id, snapshot.total_value,
snapshot.cash, snapshot.equity_value, snapshot.num_positions,
json.dumps(snapshot.holdings_snapshot),
json.dumps(snapshot.metadata)),
)
row = cur.fetchone()
return self._row_to_snapshot(row)
def get_latest_snapshot(self, portfolio_id: str) -> PortfolioSnapshot | None:
"""Return the most recent snapshot for a portfolio, or None.
Args:
portfolio_id: UUID of the target portfolio.
"""
# TODO: implement
raise NotImplementedError
"""Return the most recent snapshot for a portfolio, or None."""
with self._cursor() as cur:
cur.execute(
"""SELECT * FROM snapshots
WHERE portfolio_id = %s
ORDER BY snapshot_date DESC LIMIT 1""",
(portfolio_id,),
)
row = cur.fetchone()
return self._row_to_snapshot(row) if row else None
def list_snapshots(
self,
portfolio_id: str,
limit: int = 30,
) -> list[PortfolioSnapshot]:
"""Return snapshots newest-first up to limit.
"""Return snapshots newest-first up to limit."""
with self._cursor() as cur:
cur.execute(
"""SELECT * FROM snapshots
WHERE portfolio_id = %s
ORDER BY snapshot_date DESC LIMIT %s""",
(portfolio_id, limit),
)
rows = cur.fetchall()
return [self._row_to_snapshot(r) for r in rows]
Args:
portfolio_id: UUID of the target portfolio.
limit: Maximum number of snapshots to return.
"""
# TODO: implement
raise NotImplementedError
# ------------------------------------------------------------------
# Row -> Model helpers
# ------------------------------------------------------------------
@staticmethod
def _row_to_portfolio(row: dict) -> Portfolio:
metadata = row.get("metadata") or {}
if isinstance(metadata, str):
metadata = json.loads(metadata)
return Portfolio(
portfolio_id=str(row["portfolio_id"]),
name=row["name"],
cash=float(row["cash"]),
initial_cash=float(row["initial_cash"]),
currency=row["currency"].strip(),
created_at=str(row["created_at"]),
updated_at=str(row["updated_at"]),
report_path=row.get("report_path"),
metadata=metadata,
)
@staticmethod
def _row_to_holding(row: dict) -> Holding:
return Holding(
holding_id=str(row["holding_id"]),
portfolio_id=str(row["portfolio_id"]),
ticker=row["ticker"],
shares=float(row["shares"]),
avg_cost=float(row["avg_cost"]),
sector=row.get("sector"),
industry=row.get("industry"),
created_at=str(row["created_at"]),
updated_at=str(row["updated_at"]),
)
@staticmethod
def _row_to_trade(row: dict) -> Trade:
metadata = row.get("metadata") or {}
if isinstance(metadata, str):
metadata = json.loads(metadata)
return Trade(
trade_id=str(row["trade_id"]),
portfolio_id=str(row["portfolio_id"]),
ticker=row["ticker"],
action=row["action"],
shares=float(row["shares"]),
price=float(row["price"]),
total_value=float(row["total_value"]),
trade_date=str(row["trade_date"]),
rationale=row.get("rationale"),
signal_source=row.get("signal_source"),
metadata=metadata,
)
@staticmethod
def _row_to_snapshot(row: dict) -> PortfolioSnapshot:
holdings = row.get("holdings_snapshot") or []
if isinstance(holdings, str):
holdings = json.loads(holdings)
metadata = row.get("metadata") or {}
if isinstance(metadata, str):
metadata = json.loads(metadata)
return PortfolioSnapshot(
snapshot_id=str(row["snapshot_id"]),
portfolio_id=str(row["portfolio_id"]),
snapshot_date=str(row["snapshot_date"]),
total_value=float(row["total_value"]),
cash=float(row["cash"]),
equity_value=float(row["equity_value"]),
num_positions=int(row["num_positions"]),
holdings_snapshot=holdings,
metadata=metadata,
)